Unverified Commit 1a636557 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[minor] ShardGradScaler - only wait for the last handle (#382)

* super minor, opportunistic micro optim
parent ce9e7e48
......@@ -49,10 +49,11 @@ class ShardedGradScaler(TorchGradScaler):
# Synchronize the detected inf across the ranks
optimizer_state = self._per_optimizer_states[id(optimizer)]
handles = [
dist.all_reduce(v, async_op=True, group=self.group)
for v in optimizer_state["found_inf_per_device"].values()
]
# Make sure that the calls are done before moving out
_ = list(map(lambda x: x.wait(), handles))
last_handle = None
for v in optimizer_state["found_inf_per_device"].values():
last_handle = dist.all_reduce(v, async_op=True, group=self.group)
# Make sure that the calls are done before moving out.
# The calls are executed in sequence, waiting for the last one is enough
if last_handle is not None:
last_handle.wait()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment