Unverified commit 11beea69 authored by Benjamin Lefaudeux, committed by GitHub

[perf][minor] ShardedDDP micro-optim (#296)

* Minor, not life-changing: precompute 1.0 / world_size once at construction so gradient scaling is an in-place multiply instead of a division at every reduction step
parent 3d02f052
@@ -68,7 +68,7 @@ class ShardedDataParallel(nn.Module):
         # Communication related attributes
         self.process_group = process_group if process_group is not None else dist.group.WORLD
-        self.world_size = dist.get_world_size(self.process_group)
+        self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
         self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
         self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
@@ -185,7 +185,7 @@ class ShardedDataParallel(nn.Module):
             # Make sure that this is not fired twice
             self._grad_to_be_reduced[index] = False
-            param.grad /= self.world_size
+            param.grad.mul_(self.world_size_scaling)

             # Future work includes clearing up the buffer if possible
             def cleanup() -> None:
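
To make the intent of the change concrete, here is a minimal standalone sketch of the same pattern (illustrative only, not the fairscale code itself; the hard-coded world_size stands in for dist.get_world_size(group)):

import torch

# Hypothetical standalone illustration of the pattern in this commit:
# compute the reciprocal of the world size once, then scale each gradient
# with an in-place multiply instead of dividing on every backward pass.

world_size = 8                          # stands in for dist.get_world_size(group)
world_size_scaling = 1.0 / world_size   # computed once at construction; > 0

grad = torch.ones(4)                    # stands in for param.grad

# Before this commit: a tensor division on the hot path
# grad /= world_size

# After this commit: one in-place multiply by the precomputed scalar
grad.mul_(world_size_scaling)

print(grad)  # tensor([0.1250, 0.1250, 0.1250, 0.1250])

The win is small, as the commit message says: the reciprocal is computed once instead of dividing by world_size at every gradient reduction, which is presumably what "removing a dependency on runtime optim" refers to.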