Unverified Commit dd441e9d authored by Benjamin Lefaudeux, committed by GitHub

[perf] ShardedDDP & OSS, small improvements (#321)

* Couple of small improvements, no logic changes
parent bd5d0496
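
The recurring change in the hunks below is the broadcast path: instead of storing one work handle per asynchronous broadcast and waiting on every one of them, only the last handle is kept and waited on, because broadcasts issued back to back on the same process group are enqueued in order on the same CUDA stream, so the last one completing implies all the earlier ones have completed. A minimal sketch of that pattern, outside the library code (the helper name is made up; it assumes an initialized torch.distributed process group):

    from typing import Any, Iterable, Optional

    import torch
    import torch.distributed as dist


    def broadcast_all(tensors: Iterable[torch.Tensor], src_rank: int, group: Any = None) -> None:
        # Fire all broadcasts asynchronously, keeping only the most recent work handle
        last_work_handle = None
        for t in tensors:
            last_work_handle = dist.broadcast(t, src=src_rank, group=group, async_op=True)

        # Collectives on a given process group execute in order, so waiting on the
        # last handle is enough to know that every broadcast has completed.
        if last_work_handle is not None:
            last_work_handle.wait()
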
@@ -149,18 +149,22 @@ class ShardedDataParallel(nn.Module):
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
 
+    @torch.no_grad()
     def sync_buffers(self, blocking: bool = False) -> None:
         """
         Sync all the param buffers in between ranks (including for instance batch norm statistics).
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
-                for buffer in self.module.buffers(recurse=True)
-            ]
-
-            if blocking:
-                _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+
+        for buffer in self.module.buffers(recurse=True):
+            last_work_handle = dist.broadcast(
+                buffer.data, self.reference_global_rank, self.process_group, async_op=True
+            )
+
+        if blocking and last_work_handle:
+            # Only wait for the last coms, they're inlined on the same CUDA stream
+            last_work_handle.wait()
 
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
@@ -177,6 +181,7 @@ class ShardedDataParallel(nn.Module):
         yield
         self.should_accumulate_grads = old_should_accumulate_grads
 
+    @torch.no_grad()
     def _clear_counters(self) -> None:
         """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
@@ -199,6 +204,7 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """
 
+        @torch.no_grad()
         def reduce(*_: Any) -> None:
             # Skip gradient reduction, do not alter status flags
             if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
@@ -262,17 +268,22 @@ class ShardedDataParallel(nn.Module):
                 grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
                 self._grad_accs.append(grad_acc)  # keep this function in scope
 
+    @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
         """
         Sync the complete model states in between the ranks
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
-                for t in self.module.state_dict().values()
-            ]
-
-            _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+        for t in self.module.state_dict().values():
+            last_work_handle = dist.broadcast(
+                t, src=self.reference_global_rank, group=self.process_group, async_op=True
+            )
+
+        # Only wait for the last handle, they're inlined in the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
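
The other change repeated throughout this diff is mechanical: torch.no_grad() moves from a with-block inside the method body to a decorator on the method itself, which drops one indentation level without changing behavior. A tiny illustration with made-up functions, not library code:

    import torch


    # Before: the context-manager form
    def scale_with_block(t: torch.Tensor, factor: float) -> None:
        with torch.no_grad():
            t.mul_(factor)


    # After: the decorator form, behaviourally identical
    @torch.no_grad()
    def scale_with_decorator(t: torch.Tensor, factor: float) -> None:
        t.mul_(factor)
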
@@ -534,6 +534,7 @@ class OSS(Optimizer):
             global_rank = dist.distributed_c10d._get_global_rank(group, rank)
         return global_rank
 
+    @torch.no_grad()
     def _sync_param_groups(self, local_to_global: bool = False) -> None:
         """Sync learning rate and other optimizer attributes (needed to support schedulers).
         If the global param groups have been altered, and we want to make sure that the
@@ -548,10 +549,12 @@ class OSS(Optimizer):
                 elif k in global_group.keys():
                     local_group[k] = global_group[k]
 
+    @torch.no_grad()
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
         i_param = 0
+        last_work_handle = None  # Work handles are consumed within this scope, no callback
 
         for (device, device_params,) in self.per_device_params.items():  # all the params on this device (inc all ranks)
             buckets = self.buckets[device]
@@ -562,25 +565,18 @@ class OSS(Optimizer):
                 # Direct broadcasts only
                 for param in params:
                     if not self.should_bucket_param[i_param]:
-                        self.work_handles.append(
-                            Workhandle(
-                                handle=dist.broadcast(
-                                    tensor=param.data, src=global_src_rank, group=self.group, async_op=True
-                                ),
-                                callback=None,
-                            )
-                        )
+                        last_work_handle = dist.broadcast(
+                            tensor=param.data, src=global_src_rank, group=self.group, async_op=True
+                        )
                     i_param += 1
 
                 # Bucket broadcasts
-                self.work_handles.append(
-                    Workhandle(
-                        handle=dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True),
-                        callback=None,
-                    )
-                )
+                last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
 
-        self._consume_work_handles()
+        # Only check on the last handle, they're all inlined on the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
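
For context, a hedged usage sketch of where these methods sit (the import paths and constructor signatures below are assumptions based on fairscale's public API around this release, not something shown in this diff): the wrapper broadcasts the full state at construction via _sync_params_and_buffers, OSS._broadcast_params runs at the end of each optimizer step to share the updated parameter shards, and sync_buffers can be called explicitly, for instance before evaluation, to realign batch-norm statistics.

    import torch
    from fairscale.nn.data_parallel import ShardedDataParallel  # assumed import path
    from fairscale.optim import OSS  # assumed import path

    # Assumes torch.distributed is already initialized (e.g. via init_process_group)
    model = torch.nn.Linear(16, 16)
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
    model = ShardedDataParallel(model, optimizer)  # broadcasts params/buffers once at construction

    # ... training loop: backward() reduces gradients, optimizer.step() ends by broadcasting shards ...

    model.sync_buffers(blocking=True)  # realign buffers (e.g. BatchNorm statistics) before evaluation
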