Unverified Commit dda23993 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS async broadcast (#78)

Changes the broadcast calls in the OSS step() function to make them asynchronous
parent df11eaa2
...@@ -106,11 +106,14 @@ class OSS(Optimizer): ...@@ -106,11 +106,14 @@ class OSS(Optimizer):
# Run the optimizer step on this shard only # Run the optimizer step on this shard only
loss = self.optim.step(closure=closure, **kwargs) # type: ignore loss = self.optim.step(closure=closure, **kwargs) # type: ignore
# Sync all the states # Sync all the states. Broadcast requests are issued async, we check completeness before moving on
requests = []
for rank, param_groups in enumerate(self.partition_parameters()): for rank, param_groups in enumerate(self.partition_parameters()):
for param_group in param_groups: for param_group in param_groups:
for param in param_group["params"]: for param in param_group["params"]:
dist.broadcast(tensor=param, src=rank, group=self.group) requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))
_ = list(map(lambda x: x.wait(), requests))
return loss return loss
def local_state_dict(self) -> dict: def local_state_dict(self) -> dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment