Unverified commit 51625eda authored by Benjamin Lefaudeux, committed by GitHub

[ShardedDDP] Bucketing reduce calls, tensor views (#327)

parent fa11d338
@@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [next rel] - TBD
### Added
- Bucket reduce calls in ShardedDDP, for faster inter-node communication (#327)
- Tensor views for OSS bucketing, reducing CPU use
## [0.1.4] - 2021-01-07
### Fixed
@@ -11,7 +11,7 @@ reduction automatically.
import contextlib
from itertools import chain
import logging
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -19,7 +19,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from fairscale.optim import OSS
from fairscale.optim.utils import Bucket, Workhandle
class ShardedDataParallel(nn.Module):
@@ -44,6 +44,10 @@ class ShardedDataParallel(nn.Module):
sync_models_at_startup (bool):
Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
reduce_buffer_size (int):
    The max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
    This will impact long-term memory consumption, because these buckets correspond to parameters which will not be sharded.
    Set to 0 to remove all bucketing.
.. warning:
@@ -51,15 +55,22 @@
after the backward pass, in order to save memory and some communication bandwidth.
.. warning:
    As a consequence of sharding:
    * in case of gradient clipping, one has to use the `clip_grad_norm` exposed by
      the `optimizer state sharding wrapper <fairscale.optim.OSS>`
    * after loss.backward() (or equivalent) each rank will have `None` in place of some param.grad
    * Pytorch and Apex AMP implementations will hang when used in conjunction with `ShardedDDP`.
      One needs a `shard-aware grad scaler<ShardedGradScaler>`, which is proposed in `fairscale.optim.grad_scaler`,
      compatible with PytorchAMP.
ShardedDDP uses buckets to speed up the network communications. If some parameters require grad but are not actually
used, this can prevent the bucket mechanism from functioning, and it cannot be handled automatically.
In that case ShardedDDP will raise an exception and suggest either removing the unused parameters from your model
(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=unused_parameters is helpful)
or setting `reduce_buffer_size` to 0.
"""
def __init__(
@@ -69,6 +80,7 @@ class ShardedDataParallel(nn.Module):
process_group: Any = None,
broadcast_buffers: bool = True,
sync_models_at_startup: bool = True,
reduce_buffer_size: int = 2 ** 23,
):
super().__init__()
@@ -113,7 +125,25 @@ class ShardedDataParallel(nn.Module):
# - keep track of the grads which have already been reduced
self._reduced_grads: Dict[OSS, int] = {}
self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
self._clear_counters()
# - setup buckets and tensor views
model_size = sum([p.numel() for p in self.module.parameters()])
if dist.get_world_size(self.process_group) <= 8:
    logging.info("Assuming single node environment. De-activating ShardedDDP buckets")
    reduce_buffer_size = 0

self.buffer_max_size = min(reduce_buffer_size, model_size)
logging.info(
    "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
        self.buffer_max_size / 2 ** 20, model_size / 2 ** 20
    )
)
self.use_buckets = self.buffer_max_size > 0
self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
self._should_bucket_grad: List[bool] = []
self._bucket_iterator: Optional[Iterable[Bucket]] = None
self._setup_bucket_strategy()

# - setup backward hooks which will be called by Torch's autograd in due time
self._grad_accs: List[Callable] = []
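As a sizing aside (not part of the diff): with the default of 2 ** 23 elements and fp32 gradients, each bucket buffer allocated above (one per rank and per device) comes to 32 MiB. A quick check:

import torch

reduce_buffer_size = 2 ** 23  # default, in number of elements
element_size = torch.finfo(torch.float32).bits // 8  # 4 bytes for fp32
print(
    "{:.2f}M elements -> {:.0f} MiB per bucket".format(
        reduce_buffer_size / 2 ** 20, reduce_buffer_size * element_size / 2 ** 20
    )
)  # 8.00M elements -> 32 MiB per bucket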
@@ -184,8 +214,19 @@ class ShardedDataParallel(nn.Module):
@torch.no_grad()
def _clear_counters(self) -> None:
    """Reset all the grad reduce and call counters"""
    if not self.should_accumulate_grads:
        self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
        self._reduced_grads = {o: 0 for o in self.sharded_optimizers}

        for o in self.buckets.keys():
            for d in self.buckets[o].keys():
                for bucket in self.buckets[o][d]:
                    assert bucket.sent, (
                        "A bucket failed being sent, probably unused parameters."
                        + " Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
                    )
                    bucket.reset()
def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
""" Look up where this parameter belongs to """
@@ -212,23 +253,44 @@ class ShardedDataParallel(nn.Module):
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False

if not self.use_buckets or not self._should_bucket_grad[index]:
    param.grad.mul_(self.world_size_scaling)

    # Future work includes clearing up the buffer if possible
    def cleanup() -> None:
        if dst_rank != self.global_rank:
            param.grad = None

    # Async reduce for this buffer, log the future
    optimizer.work_handles.append(
        Workhandle(
            handle=dist.reduce(
                tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
            ),
            callback=cleanup,
        )
    )
    self._reduced_grads[optimizer] += 1
else:
    bucket = self.buckets[optimizer][param.device][dst_rank]
    bucket.params_checked_in += 1

    if bucket.full():
        # Normalize the bucket in one go
        bucket.buffer.mul_(self.world_size_scaling)

        # Reduce the bucket
        bucket.sent = True
        optimizer.work_handles.append(
            Workhandle(
                handle=dist.reduce(
                    tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
                ),
                callback=None,
            )
        )
        self._reduced_grads[optimizer] += 1

# Opportunistically try to empty the queue
optimizer._try_consume_work_handle()
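The Workhandle bookkeeping above relies on torch.distributed collectives called with async_op=True returning a handle that can be waited on later. A minimal standalone sketch of that pattern (tensor, ranks and sizes are arbitrary; a process group is assumed to be initialized):

import torch
import torch.distributed as dist

# assumes dist.init_process_group(...) was called beforehand
grad_like = torch.ones(1024)
handle = dist.reduce(tensor=grad_like, dst=0, group=dist.group.WORLD, async_op=True)

# other reduce calls can be queued here, overlapping communication with the rest of the backward pass

handle.wait()  # only block once the reduced gradient is actually needed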
@@ -268,6 +330,32 @@ class ShardedDataParallel(nn.Module):
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
self._grad_accs.append(grad_acc) # keep this function in scope
# Add a hook on the module to flush the buckets, if needed
if self.use_buckets:

    def bucket_flush(*unused: Any) -> None:
        handle = None

        for bucket_optim in self.buckets.values():
            for bucket_rank in bucket_optim.values():
                for bucket in bucket_rank:
                    if not bucket.sent:
                        # Reduce the bucket. Some parameters went unused and this bucket was not flushed
                        bucket.buffer.mul_(self.world_size_scaling)
                        bucket.sent = True
                        handle = dist.reduce(
                            tensor=bucket.buffer,
                            dst=bucket.destination,
                            group=self.process_group,
                            async_op=True,
                        )

        # Only wait on the last handle
        if handle:
            handle.wait()

    self.module.register_backward_hook(bucket_flush)
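The flush above relies on a module-level backward hook firing once gradients for the wrapped module have been computed, i.e. near the end of the backward pass. A minimal standalone illustration of that mechanism, with a placeholder model and hook body (not part of the diff):

import torch

model = torch.nn.Linear(8, 1)

def flush(*unused):
    # stand-in for "reduce any bucket which was not sent yet"
    print("backward reached the wrapped module")

model.register_backward_hook(flush)

model(torch.randn(4, 8)).sum().backward()  # the hook fires during this call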
@torch.no_grad()
def _sync_params_and_buffers(self) -> None:
"""
@@ -296,3 +384,54 @@ class ShardedDataParallel(nn.Module):
# device_id logic has not been handled, assume single-process single-device
# SyncBatchNorm only supports DDP with single-process single-device anyway'
layer._specify_ddp_gpu_num(1) # type: ignore
def _setup_bucket_strategy(self) -> None:
    """Devise a bucketing strategy on a per-rank ownership level.
    These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
    """
    if not self.use_buckets:
        return

    # - Allocate one buffer per rank and per device to group the small parameters
    for sharded_optimizer in self.sharded_optimizers:
        for device, per_device in sharded_optimizer.per_device_params.items():
            self.buckets[sharded_optimizer][device] = [
                Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=per_device[0][0].dtype, device=device))
                for _ in per_device
            ]

    # Devise the bucketing strategy
    for sharded_optimizer in self.sharded_optimizers:
        for device, per_rank_params in sharded_optimizer.per_device_params.items():
            for dst_rank, params in enumerate(per_rank_params):
                offset = 0
                bucket = self.buckets[sharded_optimizer][device][dst_rank]
                bucket.destination = dst_rank

                for param in filter(lambda x: x.requires_grad is True, params):
                    # Criteria to decide whether this parameter is to be bucketed or not:
                    # - enough room in the bucket
                    if (offset + param.numel()) < self.buffer_max_size:
                        self._should_bucket_grad.append(True)

                        # This parameter's gradient becomes a view of the bucket
                        offset_next = offset + param.numel()

                        if param.grad is None:
                            # will be overwritten just below, see next line
                            param.grad = torch.zeros_like(param)

                        param.grad.data = bucket.buffer[offset:offset_next].view_as(param.data)
                        offset = offset_next

                        # Update the bucket
                        self._reduced_grads_max[sharded_optimizer] -= 1  # one less reduce call per bucketed grad
                        self.buckets[sharded_optimizer][device][dst_rank].max_params_checked_in += 1
                    else:
                        self._should_bucket_grad.append(False)

                # Resize the bucket to remove the wasted space at the end
                bucket.buffer.resize_(offset)

                if bucket.max_params_checked_in > 0:
                    self._reduced_grads_max[sharded_optimizer] += 1  # one reduce call per bucket
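The heart of the tensor-view approach above is that each bucketed parameter's .grad is re-pointed at a slice of the flat bucket buffer, so autograd accumulates gradients directly into the bucket and no copy is needed before the reduce. A standalone sketch of that mechanism, with arbitrary shapes and values (not part of the diff):

import torch

param = torch.nn.Parameter(torch.randn(4, 4))
flat_buffer = torch.zeros(64)  # stand-in for a Bucket buffer

# Point the parameter's gradient at a view of the flat buffer
param.grad = torch.zeros_like(param)
param.grad.data = flat_buffer[: param.numel()].view_as(param.data)

# Autograd now accumulates in place, straight into the buffer slice
(param * 2).sum().backward()
print(flat_buffer[: param.numel()].view_as(param))  # filled with 2s, ready to be reduced as one flat tensor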
@@ -76,3 +76,28 @@ def broadcast_object(
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device)
return obj
class Bucket:
    """
    Helper class to simplify the handling of broadcast or reduce buckets
    """

    def __init__(self, buffer: torch.Tensor) -> None:
        # The actual flat tensor
        self.buffer = buffer
        self.max_size = buffer.numel()

        # Current status for this buffer
        self.params_checked_in = 0
        self.max_params_checked_in = 0  # attribute present for convenience purposes
        self.destination = -1
        self.sent = True

    def reset(self) -> None:
        self.params_checked_in = 0
        self.sent = False

    def full(self) -> bool:
        """ Is the bucket full? """
        return self.max_params_checked_in == self.params_checked_in
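For reference, a short sketch of the Bucket lifecycle as ShardedDDP drives it. The values are arbitrary; in real use the fields are filled in by _setup_bucket_strategy and the reduce hooks:

import torch
from fairscale.optim.utils import Bucket

bucket = Bucket(buffer=torch.zeros(1024))
bucket.destination = 0  # rank owning the parameters mapped into this buffer
bucket.max_params_checked_in = 2  # two parameter grads are views of this buffer

bucket.reset()  # start of an iteration: sent=False, no params checked in
bucket.params_checked_in += 1
assert not bucket.full()
bucket.params_checked_in += 1
assert bucket.full()  # all expected grads arrived, the buffer can be reduced in one call
bucket.sent = True  # marked once the async reduce has been queued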