Unverified Commit c084b202 authored by Benjamin Lefaudeux, committed by GitHub

[chore][SDP] privatizing all the things (#611)

parent a77c56f0
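
This commit renames ShardedDataParallel's internal attributes (`self.module` → `self._module`, `self.process_group` → `self._process_group`, and so on) to mark them as private; the public API is unchanged. For context, a minimal usage sketch in the spirit of the fairscale README, assuming a process group has already been initialized:

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Assumption: dist.init_process_group(...) was called beforehand,
# e.g. via torchrun / torch.distributed.launch.
model = torch.nn.Linear(8, 8)
# OSS shards the optimizer state across the ranks of the group
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
# ShardedDataParallel reduces each gradient only to the rank that owns it
model = ShardedDataParallel(model, optimizer)

loss = model(torch.rand(2, 8)).sum()
loss.backward()  # gradients are reduced to their owner ranks here
optimizer.step()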
@@ -102,41 +102,40 @@ class ShardedDataParallel(nn.Module):
     ):
         super().__init__()
-        self.module = module
-        self.sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
-        self.enable_broadcast_buffers = broadcast_buffers
-        self.auto_refresh_trainable = auto_refresh_trainable
-        self.reduce_fp16 = reduce_fp16
+        self._module = module
+        self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
+        self._enable_broadcast_buffers = broadcast_buffers
+        self._auto_refresh_trainable = auto_refresh_trainable
+        self._reduce_fp16 = reduce_fp16
         if reduce_buffer_size > 0 and reduce_fp16:
-            self.reduce_fp16 = False
+            self._reduce_fp16 = False
             logging.warning(
                 "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
             )

         # Handle a no_sync() context which prevents the gradient synchronization,
         # accumulate in place
-        self.should_accumulate_grads = False
-        self.accumulate_grads_flipped = False
+        self._should_accumulate_grads = False
+        self._accumulate_grads_flipped = False

         # Communication related attributes
-        self.process_group = process_group if process_group is not None else dist.group.WORLD
-        self.backend = dist.get_backend(self.process_group)
-        self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
-        self.reference_global_rank = get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
-        self.rank = dist.get_rank(self.process_group)
-        self.global_rank = get_global_rank(self.process_group, self.rank)
+        self._process_group = process_group if process_group is not None else dist.group.WORLD
+        self._backend = dist.get_backend(self._process_group)
+        self._world_size_scaling = 1.0 / dist.get_world_size(self._process_group)  # > 0
+        self._reference_global_rank = get_global_rank(self._process_group, 0)  # picking rank 0 as the reference
+        self._rank = dist.get_rank(self._process_group)
+        self._global_rank = get_global_rank(self._process_group, self._rank)
         self._local_to_global_rank = [
-            get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+            get_global_rank(self._process_group, i) for i in range(dist.get_world_size(self._process_group))
         ]

         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
         # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
         # device_id related logic is not present, this is not handled
-        devices = {p.device for p in self.module.parameters()}
+        devices = {p.device for p in self._module.parameters()}
         self.is_multi_device_module = len(devices) > 1
-        self.device = list(devices)[0]

-        distinct_device_types = {p.device.type for p in self.module.parameters()}
+        distinct_device_types = {p.device.type for p in self._module.parameters()}
         assert len(distinct_device_types) == 1, (
             "ShardedDataParallel's input module must be on "
             "the same type of devices, but input module parameters are located on {} different device types."
@@ -149,7 +148,10 @@ class ShardedDataParallel(nn.Module):
         # - we build an iterator which goes through all the parameters involved globally
         self._all_params = list(
             chain(
-                *[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
+                *[
+                    sum([sum(p, []) for p in optim._per_device_params.values()], [])
+                    for optim in self._sharded_optimizers
+                ]
             )
         )
         self._trainable_params: List[torch.Tensor] = []
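
The reflowed comprehension above flattens `optim._per_device_params`, a dict mapping each device to per-rank lists of parameters, into one flat list via the `sum(list_of_lists, [])` idiom. A small sketch of that idiom on plain lists, with made-up values:

# Made-up stand-in for optim._per_device_params: device -> per-rank lists
per_device_params = {
    "cuda:0": [["p0", "p1"], ["p2"]],
    "cuda:1": [["p3"], ["p4", "p5"]],
}

# sum(xs, []) concatenates a list of lists; applied twice it fully flattens
flat = sum([sum(p, []) for p in per_device_params.values()], [])
assert flat == ["p0", "p1", "p2", "p3", "p4", "p5"]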
@@ -158,21 +160,21 @@ class ShardedDataParallel(nn.Module):
         self._reference_trainable_mask = list(map(_trainable, self._all_params))

         # - setup buckets and tensor views
-        model_size = sum([p.numel() for p in self.module.parameters()])
-        self.buffer_max_size = min(reduce_buffer_size, model_size)
-        if dist.get_world_size(self.process_group) == 1:
-            self.buffer_max_size = 0
+        model_size = sum([p.numel() for p in self._module.parameters()])
+        self._buffer_max_size = min(reduce_buffer_size, model_size)
+        if dist.get_world_size(self._process_group) == 1:
+            self._buffer_max_size = 0
             logging.info("Training is not really distributed, single rank. Deactivating buckets")
         logging.info(
             "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
-                self.buffer_max_size / 2 ** 20, model_size / 2 ** 20
+                self._buffer_max_size / 2 ** 20, model_size / 2 ** 20
             )
         )
-        self.use_buckets = self.buffer_max_size > 0
-        self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
+        self._use_buckets = self._buffer_max_size > 0
+        self._buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
         self._should_bucket_grad: List[bool] = []
         self._bucket_list: List[GradBucket] = []
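
The bucket budget above is the smaller of the requested buffer size and the model size, and is forced to zero on a single rank since there is nothing to reduce. A standalone sketch of that decision (the function name and inputs are hypothetical):

def bucket_budget(reduce_buffer_size: int, model_size: int, world_size: int) -> int:
    budget = min(reduce_buffer_size, model_size)  # no point in a buffer bigger than the model
    if world_size == 1:
        budget = 0  # not really distributed: deactivate buckets
    return budget

assert bucket_budget(2 ** 23, 10 ** 7, world_size=4) == 2 ** 23
assert bucket_budget(2 ** 23, 10 ** 7, world_size=1) == 0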
@@ -182,7 +184,7 @@ class ShardedDataParallel(nn.Module):
         self._manual_reduce: List[Callable] = []

         # passing a handle to torch.nn.SyncBatchNorm layer
-        self._passing_sync_batchnorm_handle(self.module)
+        self._passing_sync_batchnorm_handle(self._module)

         # Make sure that all ranks start with the same model
         if sync_models_at_startup:
@@ -200,13 +202,13 @@ class ShardedDataParallel(nn.Module):
         # Deferred initialization, or change detection
         needs_setup = len(self._grad_hooks) == 0 and self.training
-        if self.auto_refresh_trainable:
+        if self._auto_refresh_trainable:
             needs_setup |= self._detect_train_change()

         if needs_setup:
             self.refresh_trainable()

-        if self.enable_broadcast_buffers:
+        if self._enable_broadcast_buffers:
             # NCCL communications are on a different stream, needs to be blocking
             # for the subsequent FW to be correct
             self.sync_buffers(blocking=True)
@@ -215,7 +217,7 @@ class ShardedDataParallel(nn.Module):
             self._clear_counters()

         # Normal FW on the base model
-        return self.module(*inputs, **kwargs)
+        return self._module(*inputs, **kwargs)

     def to(  # type: ignore
         self,
@@ -252,16 +254,16 @@ class ShardedDataParallel(nn.Module):
             Module: self.
         """

-        assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
+        assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
         assert (
-            len(self.buckets.keys()) == 1
+            len(self._buckets.keys()) == 1
         ), "Several devices specified to begin with, incompatible with setting a single device here"

-        for _device in self.buckets.keys():
-            for bucket in self.buckets[_device].values():
+        for _device in self._buckets.keys():
+            for bucket in self._buckets[_device].values():
                 bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)

-        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self._module.to(device=device, dtype=dtype, non_blocking=non_blocking)

     def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
@@ -276,7 +278,7 @@ class ShardedDataParallel(nn.Module):
         self._trainable_params.sort(key=lambda x: x.numel())

         self._trainable_param_to_rank = {}
-        for optim in self.sharded_optimizers:
+        for optim in self._sharded_optimizers:
             # OSS may need to change the communication pattern
             optim.refresh_trainable()
@@ -320,13 +322,13 @@ class ShardedDataParallel(nn.Module):
         work_handles = []

-        for buffer in self.module.buffers(recurse=True):
+        for buffer in self._module.buffers(recurse=True):
             work_handles.append(
-                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
+                dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
             )

         if blocking and work_handles:
-            if self.backend != dist.Backend.NCCL:
+            if self._backend != dist.Backend.NCCL:
                 _ = list(filter(lambda x: x.wait(), work_handles))
             else:
                 work_handles[-1].wait()
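
sync_buffers fires one async broadcast per buffer and only waits at the end: with NCCL, operations on the communication stream complete in order, so waiting on the last handle suffices, while other backends get an explicit wait per handle. A minimal sketch of the same pattern as a free function, assuming an initialized process group:

import torch.distributed as dist

def broadcast_all(tensors, src_global_rank, group):
    # Launch all broadcasts without blocking
    handles = [dist.broadcast(t, src_global_rank, group=group, async_op=True) for t in tensors]
    if dist.get_backend(group) != dist.Backend.NCCL:
        for handle in handles:  # e.g. gloo: wait on every request
            handle.wait()
    elif handles:
        handles[-1].wait()  # NCCL: in-order completion, the last wait is enough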
@@ -354,16 +356,16 @@ class ShardedDataParallel(nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.module, name)
+            return getattr(self._module, name)

     @contextlib.contextmanager
     def no_sync(self) -> Generator:
         """A context manager to disable gradient synchronization."""
-        old_should_accumulate_grads = self.should_accumulate_grads
-        self.should_accumulate_grads = True
+        old_should_accumulate_grads = self._should_accumulate_grads
+        self._should_accumulate_grads = True
         yield
-        self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
-        self.should_accumulate_grads = old_should_accumulate_grads
+        self._accumulate_grads_flipped = self._should_accumulate_grads != old_should_accumulate_grads
+        self._should_accumulate_grads = old_should_accumulate_grads

     @torch.no_grad()
     def _clear_counters(self) -> None:
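
no_sync() keeps the PyTorch DDP convention: inside the context the backward hooks still fire but skip communication, so gradients accumulate in place, and the flip is recorded so the next synchronized step can be detected. A usage sketch for gradient accumulation (`model`, `criterion`, `batch_a`/`batch_b`, and `optimizer` are assumed to be set up by the caller, as in the earlier example):

# Accumulate over two micro-batches, reduce only once
with model.no_sync():
    criterion(model(batch_a)).backward()  # local accumulation, no reduce
criterion(model(batch_b)).backward()      # this backward reduces the accumulated grads
optimizer.step()
optimizer.zero_grad()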
@@ -372,12 +374,12 @@ class ShardedDataParallel(nn.Module):
         self._grad_to_be_reduced = [True for _ in self._trainable_params]
         self._bucket_flush_callback_set = False

-        if self.use_buckets:
+        if self._use_buckets:
             for bucket in self._bucket_list:
                 bucket.reset_checked_in()

-        if not self.should_accumulate_grads:
-            self.accumulate_grads_flipped = False
+        if not self._should_accumulate_grads:
+            self._accumulate_grads_flipped = False

     def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
         """
@@ -387,12 +389,12 @@ class ShardedDataParallel(nn.Module):
        Either way a delayed action is necessary and is passed as a callback.
        """

-        if not self.use_buckets or not self._should_bucket_grad[index]:
+        if not self._use_buckets or not self._should_bucket_grad[index]:
             # Direct reduction
             @torch.no_grad()
             def reduce(*_: Any) -> None:
                 # Skip gradient reduction, do not alter status flags
-                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
@@ -401,14 +403,14 @@ class ShardedDataParallel(nn.Module):
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
-                    param.grad.mul_(self.world_size_scaling)
+                    param.grad.mul_(self._world_size_scaling)

-                    if self.reduce_fp16:
+                    if self._reduce_fp16:
                         param.grad.data = param.grad.data.half()

                     # Future work includes clearing up the buffer if possible
                     def cleanup() -> None:
-                        if dst_rank != self.global_rank:
+                        if dst_rank != self._global_rank:
                             param.grad = None
                         else:
                             assert param.grad is not None
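
In this direct (non-bucketed) path, the hook pre-scales the gradient by 1/world_size, reduces it to the rank that owns the parameter in the sharded optimizer, and then drops the gradient on every non-owner rank; that cleanup is where ShardedDDP saves memory compared to an all-reduce. A simplified, blocking sketch of the flow (the function and its arguments are hypothetical, and it assumes the default process group where local and global ranks coincide):

import torch
import torch.distributed as dist

def reduce_grad_to_owner(param: torch.Tensor, owner_rank: int) -> None:
    assert param.grad is not None
    param.grad.mul_(1.0 / dist.get_world_size())  # pre-scale so the reduced sum is an average
    dist.reduce(param.grad.data, dst=owner_rank)  # blocking here; the real hook uses async_op=True
    if dist.get_rank() != owner_rank:
        param.grad = None  # non-owners free the gradient, the owner keeps it for its shard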
@@ -420,7 +422,7 @@ class ShardedDataParallel(nn.Module):
                             handle=dist.reduce(
                                 tensor=param.grad.data,
                                 dst=self._local_to_global_rank[dst_rank],
-                                group=self.process_group,
+                                group=self._process_group,
                                 async_op=True,
                             ),
                             callback=cleanup,
@@ -436,7 +438,7 @@ class ShardedDataParallel(nn.Module):
             def reduce(*_: Any) -> None:
                 # Skip gradient reduction, do not alter status flags
-                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
@@ -445,14 +447,14 @@ class ShardedDataParallel(nn.Module):
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
-                    bucket = self.buckets[param.device][dst_rank]
+                    bucket = self._buckets[param.device][dst_rank]
                     bucket.params_checked_in += 1

                     if bucket.all_checked_in:
                         assert bucket.buffer is not None

                         # Normalize the bucket in one go
-                        bucket.buffer.mul_(self.world_size_scaling)
+                        bucket.buffer.mul_(self._world_size_scaling)

                         # Reduce the bucket
                         bucket.sent = True
@@ -461,7 +463,7 @@ class ShardedDataParallel(nn.Module):
                             handle=dist.reduce(
                                 tensor=bucket.buffer,
                                 dst=bucket.destination,
-                                group=self.process_group,
+                                group=self._process_group,
                                 async_op=True,
                             ),
                             callback=None,
@@ -520,13 +522,13 @@ class ShardedDataParallel(nn.Module):
         work_handles = []

-        for t in self.module.state_dict().values():
+        for t in self._module.state_dict().values():
             work_handles.append(
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
+                dist.broadcast(t, src=self._reference_global_rank, group=self._process_group, async_op=True)
             )

         # gloo does not guarantee inlining like NCCL, wait for all requests
-        if self.backend != dist.Backend.NCCL:
+        if self._backend != dist.Backend.NCCL:
             _ = list(filter(lambda x: x.wait(), work_handles))
         elif work_handles:
             work_handles[-1].wait()
@@ -549,25 +551,25 @@ class ShardedDataParallel(nn.Module):
        This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
        """

-        if not self.use_buckets:
+        if not self._use_buckets:
             return

         # Devise the bucketing strategy. Parameters are already sorted, in that:
         # - these are only the trainable parameters, so they should produce grads
         # - they are sorted by increasing size
-        self.buckets = {}
+        self._buckets = {}

         self._should_bucket_grad = [False for _ in self._trainable_params]

         for i, param in enumerate(self._trainable_params):
             device = param.device
             dst_rank = self._trainable_param_to_rank[param]

-            if param.device not in self.buckets.keys():
-                self.buckets[param.device] = {}
+            if param.device not in self._buckets.keys():
+                self._buckets[param.device] = {}

-            if dst_rank not in self.buckets[param.device].keys():
-                self.buckets[param.device][dst_rank] = GradBucket(
-                    self.buffer_max_size,
+            if dst_rank not in self._buckets[param.device].keys():
+                self._buckets[param.device][dst_rank] = GradBucket(
+                    self._buffer_max_size,
                     dtype=param.dtype,
                     device=param.device,
                     destination=self._local_to_global_rank[dst_rank],
# Criteria to decide whether this parameter is to be bucketed or not: # Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket # - enough room in the bucket
if self.buckets[device][dst_rank].can_add_grad_view(param): if self._buckets[device][dst_rank].can_add_grad_view(param):
self.buckets[device][dst_rank].add_grad(param) self._buckets[device][dst_rank].add_grad(param)
self._should_bucket_grad[i] = True self._should_bucket_grad[i] = True
self._bucket_list = list(chain(*[self.buckets[device].values() for device in self.buckets.keys()])) self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))
# Resize the buckets to remove lost space in the end # Resize the buckets to remove lost space in the end
for bucket in self._bucket_list: for bucket in self._bucket_list:
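
Bucketing is a greedy, size-ordered assignment: trainable parameters arrive sorted by increasing size, and each one gets a gradient view in its (device, owner-rank) bucket only while that bucket still has room; everything that does not fit falls back to direct reduction. A standalone sketch of that decision for a single bucket, with hypothetical sizes:

def assign_to_bucket(param_sizes, capacity):
    # param_sizes sorted ascending, mirroring self._trainable_params
    used = 0
    bucketed = []
    for size in param_sizes:
        fits = used + size <= capacity
        bucketed.append(fits)
        if fits:
            used += size
    return bucketed

# Small tensors fill the bucket first; the large one is reduced directly
assert assign_to_bucket([2, 3, 5, 100], capacity=10) == [True, True, True, False]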
@@ -609,13 +611,13 @@ class ShardedDataParallel(nn.Module):
                 assert bucket.buffer is not None

                 # Normalize the bucket in one go
-                bucket.buffer.mul_(self.world_size_scaling)
+                bucket.buffer.mul_(self._world_size_scaling)

                 # Reduce the bucket
                 self._work_handles.append(
                     Workhandle(
                         handle=dist.reduce(
-                            tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                            tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
                         ),
                         callback=None,
                     )
...