Unverified commit c084b202 authored by Benjamin Lefaudeux, committed by GitHub

[chore][SDP] privatizing all the things (#611)

parent a77c56f0
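
For context on the rename below: only internal attributes gain a leading underscore, the constructor arguments are unchanged. A minimal usage sketch, assuming the usual fairscale entry points (fairscale.optim.OSS, fairscale.nn.data_parallel.ShardedDataParallel) and an already-initialized torch.distributed process group; the model, learning rate and buffer size are illustrative only:

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

def wrap(model: torch.nn.Module) -> ShardedDataParallel:
    # OSS shards the optimizer state across the ranks of the default process group.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
    # Only the constructor arguments stay public; the renamed attributes
    # (_module, _buckets, _process_group, ...) are internal after this change.
    return ShardedDataParallel(
        model,
        optimizer,
        broadcast_buffers=True,      # module buffers are broadcast from rank 0 before each forward
        reduce_fp16=False,           # optionally cast gradients to fp16 before the reduce
        reduce_buffer_size=2 ** 23,  # bucketing threshold, counted in elements
    )
```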
@@ -102,41 +102,40 @@ class ShardedDataParallel(nn.Module):
):
super().__init__()
self.module = module
self.sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
self.enable_broadcast_buffers = broadcast_buffers
self.auto_refresh_trainable = auto_refresh_trainable
self.reduce_fp16 = reduce_fp16
self._module = module
self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
self._enable_broadcast_buffers = broadcast_buffers
self._auto_refresh_trainable = auto_refresh_trainable
self._reduce_fp16 = reduce_fp16
if reduce_buffer_size > 0 and reduce_fp16:
self.reduce_fp16 = False
self._reduce_fp16 = False
logging.warning(
"fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
)
# Handle a no_sync() context which prevents the gradient synchronization,
# accumulate in place
self.should_accumulate_grads = False
self.accumulate_grads_flipped = False
self._should_accumulate_grads = False
self._accumulate_grads_flipped = False
# Communication related attributes
self.process_group = process_group if process_group is not None else dist.group.WORLD
self.backend = dist.get_backend(self.process_group)
self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group) # > 0
self.reference_global_rank = get_global_rank(self.process_group, 0) # picking rank 0 as the reference
self.rank = dist.get_rank(self.process_group)
self.global_rank = get_global_rank(self.process_group, self.rank)
self._process_group = process_group if process_group is not None else dist.group.WORLD
self._backend = dist.get_backend(self._process_group)
self._world_size_scaling = 1.0 / dist.get_world_size(self._process_group) # > 0
self._reference_global_rank = get_global_rank(self._process_group, 0) # picking rank 0 as the reference
self._rank = dist.get_rank(self._process_group)
self._global_rank = get_global_rank(self._process_group, self._rank)
self._local_to_global_rank = [
get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
get_global_rank(self._process_group, i) for i in range(dist.get_world_size(self._process_group))
]
# Expose some of the PytorchDDP attributes, some frameworks rely on them.
# See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
# device_id related logic is not present, this is not handled
devices = {p.device for p in self.module.parameters()}
devices = {p.device for p in self._module.parameters()}
self.is_multi_device_module = len(devices) > 1
self.device = list(devices)[0]
distinct_device_types = {p.device.type for p in self.module.parameters()}
distinct_device_types = {p.device.type for p in self._module.parameters()}
assert len(distinct_device_types) == 1, (
"ShardedDataParallel's input module must be on "
"the same type of devices, but input module parameters are located on {} different device types."
@@ -149,7 +148,10 @@ class ShardedDataParallel(nn.Module):
# - we build an iterator which goes through all the parameters involved globally
self._all_params = list(
chain(
*[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
*[
sum([sum(p, []) for p in optim._per_device_params.values()], [])
for optim in self._sharded_optimizers
]
)
)
self._trainable_params: List[torch.Tensor] = []
@@ -158,21 +160,21 @@ class ShardedDataParallel(nn.Module):
self._reference_trainable_mask = list(map(_trainable, self._all_params))
# - setup buckets and tensor views
model_size = sum([p.numel() for p in self.module.parameters()])
self.buffer_max_size = min(reduce_buffer_size, model_size)
model_size = sum([p.numel() for p in self._module.parameters()])
self._buffer_max_size = min(reduce_buffer_size, model_size)
if dist.get_world_size(self.process_group) == 1:
self.buffer_max_size = 0
if dist.get_world_size(self._process_group) == 1:
self._buffer_max_size = 0
logging.info("Training is not really distributed, single rank. Deactivating buckets")
logging.info(
"ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
self.buffer_max_size / 2 ** 20, model_size / 2 ** 20
self._buffer_max_size / 2 ** 20, model_size / 2 ** 20
)
)
self.use_buckets = self.buffer_max_size > 0
self._use_buckets = self._buffer_max_size > 0
self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
self._buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
self._should_bucket_grad: List[bool] = []
self._bucket_list: List[GradBucket] = []
@@ -182,7 +184,7 @@ class ShardedDataParallel(nn.Module):
self._manual_reduce: List[Callable] = []
# passing a handle to torch.nn.SyncBatchNorm layer
self._passing_sync_batchnorm_handle(self.module)
self._passing_sync_batchnorm_handle(self._module)
# Make sure that all ranks start with the same model
if sync_models_at_startup:
@@ -200,13 +202,13 @@ class ShardedDataParallel(nn.Module):
# Deferred initialization, or change detection
needs_setup = len(self._grad_hooks) == 0 and self.training
if self.auto_refresh_trainable:
if self._auto_refresh_trainable:
needs_setup |= self._detect_train_change()
if needs_setup:
self.refresh_trainable()
if self.enable_broadcast_buffers:
if self._enable_broadcast_buffers:
# NCCL communications are on a different stream, needs to be blocking
# for the subsequent FW to be correct
self.sync_buffers(blocking=True)
@@ -215,7 +217,7 @@ class ShardedDataParallel(nn.Module):
self._clear_counters()
# Normal FW on the base model
return self.module(*inputs, **kwargs)
return self._module(*inputs, **kwargs)
def to( # type: ignore
self,
@@ -252,16 +254,16 @@ class ShardedDataParallel(nn.Module):
Module: self.
"""
assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
assert (
len(self.buckets.keys()) == 1
len(self._buckets.keys()) == 1
), "Several devices specified to begin with, incompatible with setting a single device here"
for _device in self.buckets.keys():
for bucket in self.buckets[_device].values():
for _device in self._buckets.keys():
for bucket in self._buckets[_device].values():
bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)
self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
self._module.to(device=device, dtype=dtype, non_blocking=non_blocking)
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
@@ -276,7 +278,7 @@ class ShardedDataParallel(nn.Module):
self._trainable_params.sort(key=lambda x: x.numel())
self._trainable_param_to_rank = {}
for optim in self.sharded_optimizers:
for optim in self._sharded_optimizers:
# OSS may need to change the communication pattern
optim.refresh_trainable()
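
refresh_trainable() is the public entry point for the change detection above; when auto_refresh_trainable is disabled it has to be called by hand after freezing or unfreezing parameters. A sketch, where base_model stands for the module that was wrapped and "head" is a hypothetical sub-module, neither of which comes from this diff:

```python
def freeze_backbone(base_model, ddp_model) -> None:
    # Freeze everything but the last layer (illustrative: assumes a "head" sub-module).
    for p in base_model.parameters():
        p.requires_grad = False
    for p in base_model.head.parameters():
        p.requires_grad = True
    # With auto_refresh_trainable=False the wrapper does not detect the change itself:
    # rebuild the grad hooks, the buckets and the OSS communication pattern explicitly.
    ddp_model.refresh_trainable()
```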
@@ -320,13 +322,13 @@ class ShardedDataParallel(nn.Module):
work_handles = []
for buffer in self.module.buffers(recurse=True):
for buffer in self._module.buffers(recurse=True):
work_handles.append(
dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
)
if blocking and work_handles:
if self.backend != dist.Backend.NCCL:
if self._backend != dist.Backend.NCCL:
_ = list(filter(lambda x: x.wait(), work_handles))
else:
work_handles[-1].wait()
@@ -354,16 +356,16 @@ class ShardedDataParallel(nn.Module):
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.module, name)
return getattr(self._module, name)
@contextlib.contextmanager
def no_sync(self) -> Generator:
"""A context manager to disable gradient synchronization."""
old_should_accumulate_grads = self.should_accumulate_grads
self.should_accumulate_grads = True
old_should_accumulate_grads = self._should_accumulate_grads
self._should_accumulate_grads = True
yield
self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
self.should_accumulate_grads = old_should_accumulate_grads
self._accumulate_grads_flipped = self._should_accumulate_grads != old_should_accumulate_grads
self._should_accumulate_grads = old_should_accumulate_grads
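
no_sync() only flips the accumulation flag, so the per-parameter reduce hooks skip the cross-rank reduction while the context is active. A sketch of gradient accumulation built on top of it; ddp_model, optimizer, batches and accumulation_steps are illustrative names, not part of this change:

```python
def accumulate_and_step(ddp_model, optimizer, batches, accumulation_steps=4):
    for i, batch in enumerate(batches):
        if (i + 1) % accumulation_steps != 0:
            # Gradients accumulate locally; the reduce hooks are skipped.
            with ddp_model.no_sync():
                ddp_model(batch).sum().backward()
        else:
            # This backward fires the usual async reduce towards the owning ranks.
            ddp_model(batch).sum().backward()
            optimizer.step()
            optimizer.zero_grad()
```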
@torch.no_grad()
def _clear_counters(self) -> None:
@@ -372,12 +374,12 @@ class ShardedDataParallel(nn.Module):
self._grad_to_be_reduced = [True for _ in self._trainable_params]
self._bucket_flush_callback_set = False
if self.use_buckets:
if self._use_buckets:
for bucket in self._bucket_list:
bucket.reset_checked_in()
if not self.should_accumulate_grads:
self.accumulate_grads_flipped = False
if not self._should_accumulate_grads:
self._accumulate_grads_flipped = False
def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
"""
@@ -387,12 +389,12 @@ class ShardedDataParallel(nn.Module):
Either way a delayed action is necessary and is passed as a callback.
"""
if not self.use_buckets or not self._should_bucket_grad[index]:
if not self._use_buckets or not self._should_bucket_grad[index]:
# Direct reduction
@torch.no_grad()
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
if not self._bucket_flush_callback_set:
@@ -401,14 +403,14 @@ class ShardedDataParallel(nn.Module):
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
param.grad.mul_(self.world_size_scaling)
param.grad.mul_(self._world_size_scaling)
if self.reduce_fp16:
if self._reduce_fp16:
param.grad.data = param.grad.data.half()
# Future work includes clearing up the buffer if possible
def cleanup() -> None:
if dst_rank != self.global_rank:
if dst_rank != self._global_rank:
param.grad = None
else:
assert param.grad is not None
@@ -420,7 +422,7 @@ class ShardedDataParallel(nn.Module):
handle=dist.reduce(
tensor=param.grad.data,
dst=self._local_to_global_rank[dst_rank],
group=self.process_group,
group=self._process_group,
async_op=True,
),
callback=cleanup,
@@ -436,7 +438,7 @@ class ShardedDataParallel(nn.Module):
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
if not self._bucket_flush_callback_set:
@@ -445,14 +447,14 @@ class ShardedDataParallel(nn.Module):
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
bucket = self.buckets[param.device][dst_rank]
bucket = self._buckets[param.device][dst_rank]
bucket.params_checked_in += 1
if bucket.all_checked_in:
assert bucket.buffer is not None
# Normalize the bucket in one go
bucket.buffer.mul_(self.world_size_scaling)
bucket.buffer.mul_(self._world_size_scaling)
# Reduce the bucket
bucket.sent = True
@@ -461,7 +463,7 @@ class ShardedDataParallel(nn.Module):
handle=dist.reduce(
tensor=bucket.buffer,
dst=bucket.destination,
group=self.process_group,
group=self._process_group,
async_op=True,
),
callback=None,
@@ -520,13 +522,13 @@ class ShardedDataParallel(nn.Module):
work_handles = []
for t in self.module.state_dict().values():
for t in self._module.state_dict().values():
work_handles.append(
dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
dist.broadcast(t, src=self._reference_global_rank, group=self._process_group, async_op=True)
)
# gloo does not guarantee inlining like NCCL, wait for all requests
if self.backend != dist.Backend.NCCL:
if self._backend != dist.Backend.NCCL:
_ = list(filter(lambda x: x.wait(), work_handles))
elif work_handles:
work_handles[-1].wait()
@@ -549,25 +551,25 @@ class ShardedDataParallel(nn.Module):
This method can be slow for big models, but it is not typically called often (not for every forward pass, for instance)
"""
if not self.use_buckets:
if not self._use_buckets:
return
# Devise the bucketing strategy. Parameters are already sorted, in that:
# - these are only the trainable parameters, so they should produce grads
# - they are sorted by increasing size
self.buckets = {}
self._buckets = {}
self._should_bucket_grad = [False for _ in self._trainable_params]
for i, param in enumerate(self._trainable_params):
device = param.device
dst_rank = self._trainable_param_to_rank[param]
if param.device not in self.buckets.keys():
self.buckets[param.device] = {}
if param.device not in self._buckets.keys():
self._buckets[param.device] = {}
if dst_rank not in self.buckets[param.device].keys():
self.buckets[param.device][dst_rank] = GradBucket(
self.buffer_max_size,
if dst_rank not in self._buckets[param.device].keys():
self._buckets[param.device][dst_rank] = GradBucket(
self._buffer_max_size,
dtype=param.dtype,
device=param.device,
destination=self._local_to_global_rank[dst_rank],
@@ -575,11 +577,11 @@ class ShardedDataParallel(nn.Module):
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
if self.buckets[device][dst_rank].can_add_grad_view(param):
self.buckets[device][dst_rank].add_grad(param)
if self._buckets[device][dst_rank].can_add_grad_view(param):
self._buckets[device][dst_rank].add_grad(param)
self._should_bucket_grad[i] = True
self._bucket_list = list(chain(*[self.buckets[device].values() for device in self.buckets.keys()]))
self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))
# Resize the buckets to remove lost space in the end
for bucket in self._bucket_list:
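
The bucketing pass above is a greedy fill: the trainable parameters are already sorted by increasing size, and each one becomes a view into the flat buffer of its (device, destination rank) pair as long as that buffer has room. A standalone illustration of the decision logic, plain Python rather than fairscale code, with parameter sizes and owning ranks as assumed inputs:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def plan_buckets(
    numels: List[int],        # numel of each trainable parameter, sorted ascending
    dst_ranks: List[int],     # owning (sharded-optimizer) rank of each parameter
    buffer_max_size: int,     # same unit as numel, mirroring _buffer_max_size
) -> Tuple[Dict[int, List[int]], List[bool]]:
    # One device assumed for simplicity; fairscale also keys its buckets by device.
    fill: Dict[int, int] = defaultdict(int)
    buckets: Dict[int, List[int]] = defaultdict(list)
    bucketed = [False] * len(numels)
    for i, (numel, rank) in enumerate(zip(numels, dst_ranks)):
        if fill[rank] + numel <= buffer_max_size:   # "enough room in the bucket"
            buckets[rank].append(i)                 # grad i would become a view in this buffer
            fill[rank] += numel
            bucketed[i] = True
    return buckets, bucketed
```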
@@ -609,13 +611,13 @@ class ShardedDataParallel(nn.Module):
assert bucket.buffer is not None
# Normalize the bucket in one go
bucket.buffer.mul_(self.world_size_scaling)
bucket.buffer.mul_(self._world_size_scaling)
# Reduce the bucket
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
),
callback=None,
)