            self._grad_accs.append(grad_acc)  # keep this function in scope

        # Add a hook on the module to flush the buckets, if needed
        if self.use_buckets:

            def bucket_flush(*unused: Any) -> None:
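                # This hook fires once the module's backward pass is done: any bucket still marked
                # as not sent had parameters whose gradient hooks never ran, so reduce it here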
                handle = None

                for bucket_optim in self.buckets.values():
                    for bucket_rank in bucket_optim.values():
                        for bucket in bucket_rank:
                            if not bucket.sent:
                                # Reduce the bucket. Some parameters went unused and this bucket was not flushed
                                bucket.buffer.mul_(self.world_size_scaling)
                                bucket.sent = True
                                handle = dist.reduce(
                                    tensor=bucket.buffer,
                                    dst=bucket.destination,
                                    group=self.process_group,
                                    async_op=True,
                                )
                # Only wait on the last handle
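                # (assuming the backend executes reduces on this process group in submission order,
                # the last handle completing implies the earlier ones have completed as well)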
                if handle:
                    handle.wait()

            self.module.register_backward_hook(bucket_flush)
    @torch.no_grad()
    def _sync_params_and_buffers(self) -> None:
        """
...
...
@@ -296,3 +384,54 @@ class ShardedDataParallel(nn.Module):
                # device_id logic has not been handled, assume single-process single-device
                # SyncBatchNorm only supports DDP with single-process single-device anyway
                layer._specify_ddp_gpu_num(1)  # type: ignore

    def _setup_bucket_strategy(self) -> None:
        """Devise a bucketing strategy on a per-rank ownership level.
        These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
        """
        if not self.use_buckets:
            return

        # - Allocate one buffer per rank and per device to group the small parameters