"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "53043d26d837816d125ca8eb02ccfe5eaf0892e0"
Unverified commit dd441e9d, authored by Benjamin Lefaudeux, committed by GitHub

[perf] ShardedDDP & OSS, small improvements (#321)

* Couple of small improvements, no logic changes
parent bd5d0496
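The diff below applies one pattern throughout: instead of collecting every async broadcast work handle and waiting on each of them, the code keeps only the handle of the last broadcast and waits on that one, since (as the in-diff comments note) the collectives are inlined on the same CUDA stream and therefore complete in order. A minimal sketch of the pattern outside fairscale; broadcast_all and its arguments are illustrative names, and it assumes an already-initialized NCCL process group:

    from typing import Any, Iterable

    import torch
    import torch.distributed as dist

    def broadcast_all(tensors: Iterable[torch.Tensor], src_rank: int, group: Any = None) -> None:
        # Kick off one async broadcast per tensor, keeping only the last work handle.
        last_work_handle = None
        for t in tensors:
            last_work_handle = dist.broadcast(t, src=src_rank, group=group, async_op=True)

        # Waiting on the last handle is enough here: the broadcasts were enqueued on the
        # same communication stream, so earlier ones have finished once it completes.
        if last_work_handle is not None:
            last_work_handle.wait()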
@@ -149,18 +149,22 @@ class ShardedDataParallel(nn.Module):
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
 
+    @torch.no_grad()
     def sync_buffers(self, blocking: bool = False) -> None:
         """
         Sync all the param buffers in between ranks (including for instance batch norm statistics).
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
-                for buffer in self.module.buffers(recurse=True)
-            ]
-
-            if blocking:
-                _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+        for buffer in self.module.buffers(recurse=True):
+            last_work_handle = dist.broadcast(
+                buffer.data, self.reference_global_rank, self.process_group, async_op=True
+            )
+
+        if blocking and last_work_handle:
+            # Only wait for the last coms, they're inlined on the same CUDA stream
+            last_work_handle.wait()
 
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
@@ -177,6 +181,7 @@ class ShardedDataParallel(nn.Module):
         yield
         self.should_accumulate_grads = old_should_accumulate_grads
 
+    @torch.no_grad()
     def _clear_counters(self) -> None:
         """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
@@ -199,6 +204,7 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """
 
+        @torch.no_grad()
         def reduce(*_: Any) -> None:
             # Skip gradient reduction, do not alter status flags
             if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
@@ -262,17 +268,22 @@ class ShardedDataParallel(nn.Module):
             grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
             self._grad_accs.append(grad_acc)  # keep this function in scope
 
+    @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
         """
         Sync the complete model states in between the ranks
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
-                for t in self.module.state_dict().values()
-            ]
-
-            _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+        for t in self.module.state_dict().values():
+            last_work_handle = dist.broadcast(
+                t, src=self.reference_global_rank, group=self.process_group, async_op=True
+            )
+
+        # Only wait for the last handle, they're inlined in the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
...
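In the same spirit, the inline `with torch.no_grad():` blocks above become a `@torch.no_grad()` decorator on the whole method; the behavior is identical, it just removes one level of nesting. A tiny illustration, with a hypothetical function that is not part of fairscale:

    import torch

    @torch.no_grad()
    def scale_(t: torch.Tensor, factor: float) -> None:
        # Same effect as wrapping the body in `with torch.no_grad():`; no autograd history is recorded.
        t.mul_(factor)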
@@ -534,6 +534,7 @@ class OSS(Optimizer):
             global_rank = dist.distributed_c10d._get_global_rank(group, rank)
         return global_rank
 
+    @torch.no_grad()
     def _sync_param_groups(self, local_to_global: bool = False) -> None:
         """Sync learning rate and other optimizer attributes (needed to support schedulers).
         If the global param groups have been altered, and we want to make sure that the
@@ -548,10 +549,12 @@ class OSS(Optimizer):
                 elif k in global_group.keys():
                     local_group[k] = global_group[k]
 
+    @torch.no_grad()
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
         i_param = 0
+        last_work_handle = None  # Work handles are consumed within this scope, no callback
 
         for (device, device_params,) in self.per_device_params.items():  # all the params on this device (inc all ranks)
             buckets = self.buckets[device]
@@ -562,25 +565,18 @@ class OSS(Optimizer):
                 # Direct broadcasts only
                 for param in params:
                     if not self.should_bucket_param[i_param]:
-                        self.work_handles.append(
-                            Workhandle(
-                                handle=dist.broadcast(
-                                    tensor=param.data, src=global_src_rank, group=self.group, async_op=True
-                                ),
-                                callback=None,
-                            )
-                        )
+                        last_work_handle = dist.broadcast(
+                            tensor=param.data, src=global_src_rank, group=self.group, async_op=True
+                        )
                     i_param += 1
 
                 # Bucket broadcasts
-                self.work_handles.append(
-                    Workhandle(
-                        handle=dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True),
-                        callback=None,
-                    )
-                )
+                last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
 
-        self._consume_work_handles()
+        # Only check on the last handle, they're all inlined on the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
...
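To close, a minimal usage sketch showing where the methods touched above are reached from user code. It assumes torch.distributed is already initialized with an NCCL process group; the model, shapes, and hyperparameters are illustrative, and the call-site comments reflect a reading of the surrounding fairscale code rather than anything changed by this commit:

    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS

    device = torch.device("cuda", torch.cuda.current_device())
    model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32)).to(device)

    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)  # shards optimizer state across ranks
    model = ShardedDataParallel(model, optimizer)  # wrapping broadcasts the initial params and buffers

    loss = model(torch.randn(8, 32, device=device)).sum()
    loss.backward()       # per-parameter hooks reduce gradients to the rank that owns them
    optimizer.step()      # steps the local shard, then broadcasts the updated parameters back out
    model.sync_buffers()  # optional: re-broadcast buffers such as BatchNorm running statistics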