Unverified Commit 4dc605c9 authored by Benjamin Lefaudeux, committed by GitHub

[perf] ShardedDDP - small memory use reduction - minor speedup (#366)

* minor

* minor
parent 42e44149
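
The diff below trades the nested optimizer → device → rank dictionary walks for a single flattened `_bucket_list`, built once at the end of `_setup_bucket_strategy()` with `itertools.chain`, and allocates each rank's bucket buffer inside the bucketing loop rather than in a separate pre-allocation pass. As a reading aid, here is a minimal, self-contained sketch of that flattening pattern; `Bucket` and the nested `buckets` mapping are simplified stand-ins, not the fairscale classes:

```python
from dataclasses import dataclass
from itertools import chain
from typing import Dict, List

import torch


@dataclass
class Bucket:
    # Simplified stand-in for the ShardedDDP bucket: a flat gradient buffer plus
    # bookkeeping about which rank it is reduced to and whether it was already sent.
    buffer: torch.Tensor
    destination: int = 0
    sent: bool = False


# Nested layout, one list of buckets per (optimizer, device), mirroring self.buckets.
buckets: Dict[str, Dict[torch.device, List[Bucket]]] = {
    "sharded_optimizer_0": {
        torch.device("cpu"): [Bucket(torch.zeros(4), destination=rank) for rank in range(2)],
    },
}

# Flatten once so that the hot paths (grad hooks, flush, counter reset) iterate a
# plain list instead of walking two dictionary levels on every backward pass.
bucket_list: List[Bucket] = list(
    chain(*[per_device for per_optim in buckets.values() for per_device in per_optim.values()])
)

for bucket in bucket_list:
    print(bucket.destination, tuple(bucket.buffer.shape))
```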
@@ -142,6 +142,7 @@ class ShardedDataParallel(nn.Module):
         self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
         self._should_bucket_grad: List[bool] = []
+        self._bucket_list: Optional[List[Bucket]] = None
         self._setup_bucket_strategy()

         # - setup backward hooks which will be called by Torch's autograd in due time
@@ -155,6 +156,8 @@ class ShardedDataParallel(nn.Module):
         if sync_models_at_startup:
             self._sync_params_and_buffers()

+        self._clear_counters()
+
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         """
         Module forward pass, handles any DDP-specific work in the background. Primes the
@@ -256,9 +259,10 @@ class ShardedDataParallel(nn.Module):
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}

-        for optimizer in self.buckets.keys():
-            for device in self.buckets[optimizer].keys():
-                for bucket in self.buckets[optimizer][device]:
+        if self.use_buckets:
+            assert self._bucket_list is not None
+
+            for bucket in self._bucket_list:
                 assert bucket.sent, (
                     "A bucket failed to be sent, probably unused parameters."
                     + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
@@ -374,18 +378,15 @@ class ShardedDataParallel(nn.Module):
         def bucket_flush(*unused: Any) -> None:
             handle = None

-            for bucket_optim in self.buckets.values():
-                for bucket_rank in bucket_optim.values():
-                    for bucket in bucket_rank:
-                        if not bucket.sent:
-                            # Reduce the bucket. Some parameters went unused and this bucket was not flushed
-                            bucket.buffer.mul_(self.world_size_scaling)
-                            bucket.sent = True
-                            handle = dist.reduce(
-                                tensor=bucket.buffer,
-                                dst=bucket.destination,
-                                group=self.process_group,
-                                async_op=True,
-                            )
+            assert self._bucket_list is not None
+
+            for bucket in self._bucket_list:
+                if not bucket.sent:
+                    # Reduce the bucket. Some parameters went unused and this bucket was not flushed
+                    bucket.buffer.mul_(self.world_size_scaling)
+                    bucket.sent = True
+                    handle = dist.reduce(
+                        tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                    )

             # Only wait on the last handle
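
For reference, a self-contained sketch of the flush pattern used in `bucket_flush` above: fire one asynchronous reduce per bucket that has not been sent yet, and only wait on the last handle, as the original comment notes. It assumes `torch.distributed` is already initialized; `flush_buckets`, `world_size_scaling` and the `Bucket` attributes are stand-in names rather than the fairscale API:

```python
from typing import List, Optional

import torch.distributed as dist


def flush_buckets(buckets: List["Bucket"], world_size_scaling: float,
                  process_group: Optional[dist.ProcessGroup] = None) -> None:
    handle = None
    for bucket in buckets:
        if not bucket.sent:
            # Pre-scale the buffer (ShardedDDP uses 1/world_size here) so the default
            # sum-reduce on the destination rank ends up holding averaged gradients.
            bucket.buffer.mul_(world_size_scaling)
            bucket.sent = True
            handle = dist.reduce(
                tensor=bucket.buffer, dst=bucket.destination,
                group=process_group, async_op=True,
            )
    # As in the hunk above, only the last returned handle is waited on before returning.
    if handle is not None:
        handle.wait()
```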
@@ -430,19 +431,19 @@ class ShardedDataParallel(nn.Module):
         if not self.use_buckets:
             return

-        # - Allocate one buffer per rank and per device to group the small parameters
-        for sharded_optimizer in self.sharded_optimizers:
-            for device, per_device in sharded_optimizer.per_device_params.items():
-                self.buckets[sharded_optimizer][device] = [
-                    Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=per_device[0][0].dtype, device=device))
-                    for _ in per_device
-                ]
-
         # Devise the bucketing strategy
         for sharded_optimizer in self.sharded_optimizers:
             for device, per_rank_params in sharded_optimizer.per_device_params.items():
+                self.buckets[sharded_optimizer][device] = []
+
                 for dst_rank, params in enumerate(per_rank_params):
                     offset = 0

+                    self.buckets[sharded_optimizer][device].append(
+                        Bucket(
+                            buffer=torch.zeros(self.buffer_max_size, dtype=per_rank_params[0][0].dtype, device=device)
+                        )
+                    )
                     bucket = self.buckets[sharded_optimizer][device][dst_rank]
                     bucket.destination = dst_rank
@@ -473,3 +474,13 @@ class ShardedDataParallel(nn.Module):
                     bucket.buffer.resize_(offset)
                     if bucket.max_params_checked_in > 0:
                         self._reduced_grads_max[sharded_optimizer] += 1  # one reduce call per bucket
+
+        self._bucket_list = list(
+            chain(
+                *[
+                    self.buckets[sharded_optimizer][device]
+                    for sharded_optimizer in self.sharded_optimizers
+                    for device in sharded_optimizer.per_device_params.keys()
+                ]
+            )
+        )
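
The last two hunks move the bucket allocation into the bucketing loop and keep the existing trim step: each buffer starts at `buffer_max_size` and is shrunk with `resize_` to the offset actually claimed, before the flattened `_bucket_list` is cached. A rough, hypothetical illustration of that allocate-at-max-then-trim bookkeeping on a flat tensor (not the fairscale bucketing logic itself):

```python
import torch

# Hypothetical inputs: a cap on the bucket size and a few small parameters whose
# gradients would be packed back-to-back into one flat buffer.
buffer_max_size = 2 ** 16
params = [torch.nn.Parameter(torch.randn(n)) for n in (64, 128, 32)]

# Allocate the buffer at its maximum size first...
buffer = torch.zeros(buffer_max_size, dtype=params[0].dtype)

# ...reserve a slice per parameter at a running offset...
offset = 0
for p in params:
    if offset + p.numel() > buffer_max_size:
        break  # would not fit: such a parameter gets reduced on its own instead
    offset += p.numel()

# ...then trim the tensor to the length actually claimed.
buffer.resize_(offset)
print(buffer.shape)  # torch.Size([224])
```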