Unverified commit 4dc605c9, authored by Benjamin Lefaudeux, committed by GitHub

[perf] ShardedDDP - small memory use reduction - minor speedup (#366)

* minor

* minor
parent 42e44149
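
In short: the per-step hot paths (`_clear_counters` and the `bucket_flush` hook) previously walked the nested `Dict[OSS, Dict[torch.device, List[Bucket]]]` with three levels of loops on every backward pass. This commit builds a flat `_bucket_list` once at setup time, so those paths iterate a single list. A schematic before/after of the loop shape, using toy dicts and illustrative names rather than fairscale's real types:

```python
from typing import Dict, List

# Toy stand-in for self.buckets: {optimizer: {device: [buckets]}}
buckets: Dict[str, Dict[str, List[str]]] = {
    "opt0": {"cuda:0": ["b0", "b1"], "cuda:1": ["b2"]},
    "opt1": {"cuda:0": ["b3"]},
}

# Before: three nested loops on every backward pass
for per_device in buckets.values():
    for per_rank in per_device.values():
        for bucket in per_rank:
            ...  # reduce / reset / assert on each bucket

# After: flatten once at setup time, then iterate a plain list per step
bucket_list: List[str] = [
    b for per_device in buckets.values() for per_rank in per_device.values() for b in per_rank
]
for bucket in bucket_list:
    ...  # same work, no dict traversal in the hot path
```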
@@ -142,6 +142,7 @@ class ShardedDataParallel(nn.Module):
         self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
         self._should_bucket_grad: List[bool] = []
+        self._bucket_list: Optional[List[Bucket]] = None
         self._setup_bucket_strategy()
 
         # - setup backward hooks which will be called by Torch's autograd in due time
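
For reference, the `Bucket` objects referenced throughout this diff carry at least the fields the changed code touches: a flat `buffer` tensor, a `sent` flag, a `destination` rank, and a `max_params_checked_in` counter. A rough reconstruction of that shape, inferring fields strictly from their uses in this diff (the real fairscale helper may differ):

```python
from dataclasses import dataclass

import torch


@dataclass
class Bucket:
    # Flat staging buffer that small gradients are copied into before one fused reduce
    buffer: torch.Tensor
    # Rank that owns the shard this bucket reduces to
    destination: int = 0
    # Whether this bucket's reduce was already fired during this step
    sent: bool = False
    # How many params can check into this bucket (0 means the bucket is unused)
    max_params_checked_in: int = 0
    params_checked_in: int = 0

    def reset(self) -> None:
        # Re-arm the bucket for the next backward pass
        self.sent = False
        self.params_checked_in = 0
```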
@@ -155,6 +156,8 @@ class ShardedDataParallel(nn.Module):
         if sync_models_at_startup:
             self._sync_params_and_buffers()
 
+        self._clear_counters()
+
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         """
         Module forward pass, handles any DDP-specific work in the background. Primes the
@@ -256,15 +259,16 @@ class ShardedDataParallel(nn.Module):
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
 
-        for optimizer in self.buckets.keys():
-            for device in self.buckets[optimizer].keys():
-                for bucket in self.buckets[optimizer][device]:
-                    assert bucket.sent, (
-                        "A bucket failed to be sent, probably unused parameters."
-                        + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
-                    )
-                    bucket.reset()
+        if self.use_buckets:
+            assert self._bucket_list is not None
+            for bucket in self._bucket_list:
+                assert bucket.sent, (
+                    "A bucket failed to be sent, probably unused parameters."
+                    + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
+                )
+                bucket.reset()
 
     def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
         """ Look up where this parameter belongs to """
@@ -374,19 +378,16 @@ class ShardedDataParallel(nn.Module):
         def bucket_flush(*unused: Any) -> None:
             handle = None
 
-            for bucket_optim in self.buckets.values():
-                for bucket_rank in bucket_optim.values():
-                    for bucket in bucket_rank:
-                        if not bucket.sent:
-                            # Reduce the bucket. Some parameters went unused and this bucket was not flushed
-                            bucket.buffer.mul_(self.world_size_scaling)
-                            bucket.sent = True
-                            handle = dist.reduce(
-                                tensor=bucket.buffer,
-                                dst=bucket.destination,
-                                group=self.process_group,
-                                async_op=True,
-                            )
+            assert self._bucket_list is not None
+            for bucket in self._bucket_list:
+                if not bucket.sent:
+                    # Reduce the bucket. Some parameters went unused and this bucket was not flushed
+                    bucket.buffer.mul_(self.world_size_scaling)
+                    bucket.sent = True
+                    handle = dist.reduce(
+                        tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                    )
 
             # Only wait on the last handle
             if handle:
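
The rewritten flush fires one asynchronous reduce per pending bucket and waits only on the last returned handle; collectives queued on a given process group are generally executed in the order they are issued, which is what that single wait relies on. A self-contained sketch of the pattern, with hypothetical names (`flush_pending` is not fairscale API):

```python
from typing import List, Optional

import torch
import torch.distributed as dist


def flush_pending(buffers: List[torch.Tensor], world_size: int, group: Optional[object] = None) -> None:
    """Fire one async reduce per pending buffer, then wait once on the last handle."""
    handle = None
    scale = 1.0 / world_size
    for dst_rank, buf in enumerate(buffers):
        buf.mul_(scale)  # pre-scale so that reduce(SUM) yields an average across ranks
        handle = dist.reduce(tensor=buf, dst=dst_rank, group=group, async_op=True)

    # Only wait on the last handle, mirroring the hook above: earlier reduces on
    # the same group have already been queued ahead of it
    if handle:
        handle.wait()
```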
@@ -430,19 +431,19 @@ class ShardedDataParallel(nn.Module):
         if not self.use_buckets:
             return
 
-        # - Allocate one buffer per rank and per device to group the small parameters
-        for sharded_optimizer in self.sharded_optimizers:
-            for device, per_device in sharded_optimizer.per_device_params.items():
-                self.buckets[sharded_optimizer][device] = [
-                    Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=per_device[0][0].dtype, device=device))
-                    for _ in per_device
-                ]
-
         # Devise the bucketing strategy
         for sharded_optimizer in self.sharded_optimizers:
             for device, per_rank_params in sharded_optimizer.per_device_params.items():
+                self.buckets[sharded_optimizer][device] = []
+
                 for dst_rank, params in enumerate(per_rank_params):
                     offset = 0
+
+                    self.buckets[sharded_optimizer][device].append(
+                        Bucket(
+                            buffer=torch.zeros(self.buffer_max_size, dtype=per_rank_params[0][0].dtype, device=device)
+                        )
+                    )
                     bucket = self.buckets[sharded_optimizer][device][dst_rank]
                     bucket.destination = dst_rank
@@ -473,3 +474,13 @@ class ShardedDataParallel(nn.Module):
                     bucket.buffer.resize_(offset)
                     if bucket.max_params_checked_in > 0:
                         self._reduced_grads_max[sharded_optimizer] += 1  # one reduce call per bucket
+
+        self._bucket_list = list(
+            chain(
+                *[
+                    self.buckets[sharded_optimizer][device]
+                    for sharded_optimizer in self.sharded_optimizers
+                    for device in sharded_optimizer.per_device_params.keys()
+                ]
+            )
+        )
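
The flattening itself is plain `itertools.chain` (the diff's use of `chain` assumes the corresponding import at the top of the file). For readers less familiar with it, `chain(*lists)` concatenates the sub-lists lazily, and `chain.from_iterable` is the equivalent spelling that avoids unpacking an argument tuple:

```python
from itertools import chain

per_device_buckets = [
    ["bucket_a", "bucket_b"],  # optimizer 0, cuda:0
    ["bucket_c"],              # optimizer 0, cuda:1
    ["bucket_d", "bucket_e"],  # optimizer 1, cuda:0
]

# chain(*lists) and chain.from_iterable(lists) yield the same flat sequence
flat = list(chain(*per_device_buckets))
assert flat == list(chain.from_iterable(per_device_buckets))
assert flat == ["bucket_a", "bucket_b", "bucket_c", "bucket_d", "bucket_e"]
```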