Unverified commit 2eee136f, authored by msbaines, committed by GitHub

[fix] re-run black to fix CPU tests on master (#123)

parent 379c6bf0
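The commit title refers to re-running black, the Python code formatter, over the repository, presumably because the CPU test job includes a formatting check; the over-long broadcast_object(...) call in the hunk at line 376 below is the kind of line black wraps. A minimal sketch of that workflow is shown here (the use of subprocess and the target path "." are assumptions for illustration, not taken from this commit):

# Sketch only: reproduce a black formatting check, then fix it in place.
# "--check" exits non-zero when files would be reformatted (the failing-CI
# case); re-running without it rewrites the files, as this commit's title says.
import subprocess
import sys

check = subprocess.run([sys.executable, "-m", "black", "--check", "."])
if check.returncode != 0:
    subprocess.run([sys.executable, "-m", "black", "."], check=True)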
@@ -137,7 +137,7 @@ class OSS(Optimizer):
     @property
     def param_to_rank(self) -> Dict[torch.Tensor, int]:
-        '''param to data parallel rank'''
+        """param to data parallel rank"""
         if len(self._param_rank) == 0:
             for rank, param_groups in enumerate(self.partition_parameters()):
                 for param_group in param_groups:
@@ -145,11 +145,11 @@ class OSS(Optimizer):
                         self._param_rank[param] = rank
         return self._param_rank
 
-    def get_global_rank(self, group, rank):
+    def get_global_rank(self, group: Any, rank: int) -> int:
         if group is dist.group.WORLD:
             return rank
         else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore
             return global_rank
 
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
@@ -376,7 +376,9 @@ class OSS(Optimizer):
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                broadcast_object(self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device)
+                broadcast_object(
+                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
+                )
             else:
                 global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing
@@ -393,4 +395,3 @@ class OSS(Optimizer):
             for p in partition:
                 for t in p["params"]:
                     t.grad = None
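For context on the get_global_rank change above: torch.distributed broadcast primitives take a global source rank, so a rank that is local to self.group has to be translated before it is used as src_rank, which is what the else branch in the hunk at line 376 does. Below is a self-contained sketch of that translation; it mirrors the method body from the diff rather than importing the actual class, and the single-process gloo setup is hypothetical, only there to make the snippet runnable.

import torch.distributed as dist


def get_global_rank(group, rank):
    # Mirrors OSS.get_global_rank above: ranks in the default WORLD group are
    # already global; for a sub-group, translate the group-local rank via the
    # same private c10d helper used in the diff.
    if group is dist.group.WORLD:
        return rank
    return dist.distributed_c10d._get_global_rank(group, rank)


if __name__ == "__main__":
    # Hypothetical single-process setup purely to exercise the helper; real
    # usage runs one process per rank and may pass a sub-group instead.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    print(get_global_rank(dist.group.WORLD, 0))  # -> 0
    dist.destroy_process_group()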