Unverified Commit 13445c55 authored by Benjamin Lefaudeux, committed by GitHub

[feature-fix-refactor][ShardedDDP] Make it possible to change trainability graph on the fly (#369)

* Better unit testing
* Make it possible to refresh the DDP assumptions when the model has changed; make this refresh optional so that some time can be saved
* Enabling accumulation tests
parent 1a636557
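For reference, a minimal sketch of how the new option is intended to be used. This snippet is illustrative only (it is not part of the commit) and assumes that a `torch.distributed` process group has already been initialized:

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Illustrative model and sharded optimizer; any module/optimizer pair works
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.9)

# Default behaviour: trainability changes are detected automatically before each forward
ddp_model = ShardedDataParallel(model, optimizer, auto_refresh_trainable=True)

# Freeze part of the model mid-training
for p in model[0].parameters():
    p.requires_grad = False

# With auto_refresh_trainable=False, the refresh must be requested explicitly
# before the next forward pass:
# ddp_model.refresh_trainable()
```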
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [next rel] - TBD
### Fixed
- ShardedDDP and OSS handle model trainability changes during training ([#369](https://github.com/facebookresearch/fairscale/issues/369))
## [0.1.6] - 2021-02-10
### Added
...
@@ -8,10 +8,12 @@ A nn.Module wrapper to go with a Sharded Optimizer in order to handle targeted g
reduction automatically.
"""
from collections import deque
import contextlib
import functools
from itertools import chain
import logging
from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -22,6 +24,10 @@ from fairscale.optim import OSS
from fairscale.optim.utils import Bucket, Workhandle
def _trainable(param: torch.Tensor) -> bool:
return param.requires_grad
class ShardedDataParallel(nn.Module):
""" Wrap the model, and reduce the gradients to the right rank during the backward pass.
@@ -45,9 +51,13 @@ class ShardedDataParallel(nn.Module):
Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
reduce_buffer_size (int):
The max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
This will impact the long-term memory consumption, because these buckets correspond to parameters which will not be sharded.
Set to 0 to remove all bucketing.
auto_refresh_trainable (bool):
(default: True) Check whether the parameters' trainability (`requires_grad`) has changed and update both ShardedDDP
and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
a parameter is frozen or unfrozen.
.. warning:
@@ -71,6 +81,14 @@ class ShardedDataParallel(nn.Module):
handled. In that case ShardedDDP will raise an exception and suggest to either remove the unused parameters from your model
(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=unused_parameters is helpful)
or set `reduce_buffer_size` to 0
.. warning:
If `auto_refresh_trainable` is set to `True` (the default), any trainability change in the model graph will be handled
automatically.
If `auto_refresh_trainable` is set to `False`, ShardedDDP will not refresh its assumptions with respect to trainable parameters
for every forward pass, in the hope of saving some time. If some parameters are frozen or unfrozen over time, please refresh
ShardedDDP's assumptions by calling `refresh_trainable()` just after said change (before the next forward pass).
"""
def __init__(
@@ -81,12 +99,14 @@ class ShardedDataParallel(nn.Module):
broadcast_buffers: bool = True,
sync_models_at_startup: bool = True,
reduce_buffer_size: int = 2 ** 23,
auto_refresh_trainable: bool = True,
):
super().__init__()
self.module = module
self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
self.enable_broadcast_buffers = broadcast_buffers
self.auto_refresh_trainable = auto_refresh_trainable
# Handle a no_sync() context which prevents the gradient synchronization,
# accumulate in place
@@ -117,14 +137,19 @@ class ShardedDataParallel(nn.Module):
# several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks
# - we build an iterator which goes through all the parameters involved globally
all_param_iterator = chain( self._all_params = list(
*[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers] chain(
*[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
)
) )
self._grad_to_be_reduced = [True for _ in filter(lambda x: x.requires_grad, all_param_iterator)] self._trainable_params: List[torch.Tensor] = []
self._grad_to_be_reduced: List[bool] = []
self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
self._reference_trainable_mask = list(map(_trainable, self._all_params))
# - keep track of the grads which have already been reduced # - keep track of the grads which have already been reduced
self._reduced_grads: Dict[OSS, int] = {} self._reduced_grads = 0
self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers} self._reduced_grads_max = 0
# - setup buckets and tensor views # - setup buckets and tensor views
model_size = sum([p.numel() for p in self.module.parameters()]) model_size = sum([p.numel() for p in self.module.parameters()])
...@@ -140,14 +165,12 @@ class ShardedDataParallel(nn.Module): ...@@ -140,14 +165,12 @@ class ShardedDataParallel(nn.Module):
) )
self.use_buckets = self.buffer_max_size > 0 self.use_buckets = self.buffer_max_size > 0
self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers} self.buckets: Dict[torch.device, List[Bucket]] = {}
self._should_bucket_grad: List[bool] = [] self._should_bucket_grad: List[bool] = []
self._bucket_list: Optional[List[Bucket]] = None self._bucket_list: Optional[List[Bucket]] = None
self._setup_bucket_strategy()
# - setup backward hooks which will be called by Torch's autograd in due time # - setup backward hooks which will be called by Torch's autograd in due time
self._grad_accs: List[Callable] = [] self._grad_accs: List[Callable] = []
self._setup_backward_hooks()
# passing a handle to torch.nn.SyncBatchNorm layer # passing a handle to torch.nn.SyncBatchNorm layer
self._passing_sync_batchnorm_handle(self.module) self._passing_sync_batchnorm_handle(self.module)
...@@ -156,13 +179,25 @@ class ShardedDataParallel(nn.Module): ...@@ -156,13 +179,25 @@ class ShardedDataParallel(nn.Module):
if sync_models_at_startup: if sync_models_at_startup:
self._sync_params_and_buffers() self._sync_params_and_buffers()
self._clear_counters() self._work_handles: Deque[Workhandle] = deque()
self.refresh_trainable()
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
"""
Module forward pass, handles any DDP-specific work in the background. Primes the
backward pass for gradient reduction to the proper ranks.
"""
# Optionally check whether the trainable parameters have changed
if self.auto_refresh_trainable:
trainable_mask = list(map(_trainable, self._all_params))
if trainable_mask != self._reference_trainable_mask:
logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
self.refresh_trainable()
self._reference_trainable_mask = trainable_mask
if self.enable_broadcast_buffers:
# NCCL communications are on a different stream, needs to be blocking
# for the subsequent FW to be correct
...@@ -205,13 +240,38 @@ class ShardedDataParallel(nn.Module): ...@@ -205,13 +240,38 @@ class ShardedDataParallel(nn.Module):
""" """
for optimizer in self.buckets.keys(): for device in self.buckets.keys():
for device in self.buckets[optimizer].keys(): for bucket in self.buckets[device]:
for bucket in self.buckets[optimizer][device]: bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.module.to(device) self.module.to(device)
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), "Grads waiting to be reduced"
self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
self._trainable_params.sort(key=lambda x: x.numel())
self._grad_to_be_reduced = [True for _ in self._trainable_params]
self._trainable_param_to_rank = {}
for optim in self.sharded_optimizers:
# OSS may need to change the communication pattern
optim.refresh_trainable()
# Update ShardedDDP given the new partitions
for (
device_per_rank_params
) in optim.per_device_params.values(): # all the params on this device (inc all ranks)
for device_params in device_per_rank_params:
for param in filter(lambda x: x.requires_grad, device_params):
self._trainable_param_to_rank[param] = optim.param_to_rank[param]
self._setup_bucket_strategy()
self._setup_backward_hooks()
def reduce(self) -> None:
""".. deprecated:: 0.0.4
@@ -223,6 +283,9 @@ class ShardedDataParallel(nn.Module):
def sync_buffers(self, blocking: bool = False) -> None:
"""
Sync all the param buffers in between ranks (including for instance batch norm statistics).
Arguments:
blocking (bool): wait for the operation to conclude.
"""
last_work_handle = None
@@ -236,6 +299,21 @@ class ShardedDataParallel(nn.Module):
# Only wait for the last comms, they're inlined on the same CUDA stream
last_work_handle.wait()
def zero_grad(self, set_to_none: bool = False) -> None:
r"""Sets gradients of all model parameters to zero. See similar function
under :class:`torch.optim.Optimizer` for more context.
Arguments:
set_to_none (bool): instead of setting to zero, set the grads to None.
See :meth:`torch.optim.Optimizer.zero_grad` for details.
"""
for index, trainable_param in enumerate(self._trainable_params):
if set_to_none and not self._should_bucket_grad[index]:
trainable_param.grad = None
elif trainable_param.grad is not None:
trainable_param.grad.zero_()
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
try:
@@ -254,21 +332,20 @@ class ShardedDataParallel(nn.Module):
@torch.no_grad()
def _clear_counters(self) -> None:
"""Reset all the grad reduce and call counters"""
if not self.should_accumulate_grads: self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
self._reduced_grads = 0
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced] # Do not reset the buckets
self._reduced_grads = {o: 0 for o in self.sharded_optimizers} if self.use_buckets:
assert self._bucket_list is not None
if self.use_buckets:
assert self._bucket_list is not None
for bucket in self._bucket_list: for bucket in self._bucket_list:
assert bucket.sent, ( assert self.should_accumulate_grads or bucket.sent, (
"A bucket failed to be sent, probably unused parameters." "A bucket failed to be sent, probably unused parameters."
+ "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-" + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
) )
bucket.reset() bucket.reset()
def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
""" Look up where this parameter belongs to """
@@ -279,7 +356,7 @@ class ShardedDataParallel(nn.Module):
assert False, "This parameter is not present in an optimizer, this should not happen"
return (None, -1)
def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
"""
Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
or contribute to a bucket and reduce when the bucket is full.
@@ -287,16 +364,16 @@
Either way a delayed action is necessary and is passed as a callback.
"""
@torch.no_grad() if not self.use_buckets or not self._should_bucket_grad[index]:
def reduce(*_: Any) -> None: # Direct reduction
# Skip gradient reduction, do not alter status flags @torch.no_grad()
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]: def reduce(*_: Any) -> None:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None" # Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
# Make sure that this is not fired twice # Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False self._grad_to_be_reduced[index] = False
if not self.use_buckets or not self._should_bucket_grad[index]:
param.grad.mul_(self.world_size_scaling) param.grad.mul_(self.world_size_scaling)
# Future work includes clearing up the buffer if possible # Future work includes clearing up the buffer if possible
...@@ -305,7 +382,7 @@ class ShardedDataParallel(nn.Module): ...@@ -305,7 +382,7 @@ class ShardedDataParallel(nn.Module):
param.grad = None param.grad = None
# Async reduce for this buffer, log the future # Async reduce for this buffer, log the future
optimizer.work_handles.append( self._work_handles.append(
Workhandle( Workhandle(
handle=dist.reduce( handle=dist.reduce(
tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
...@@ -313,9 +390,28 @@ class ShardedDataParallel(nn.Module): ...@@ -313,9 +390,28 @@ class ShardedDataParallel(nn.Module):
callback=cleanup, callback=cleanup,
) )
) )
self._reduced_grads[optimizer] += 1 self._reduced_grads += 1
else:
bucket = self.buckets[optimizer][param.device][dst_rank] # Opportunistically try to empty the queue
self._try_consume_work_handle()
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if self._reduced_grads == self._reduced_grads_max:
self._consume_work_handles()
else:
@torch.no_grad()
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
bucket = self.buckets[param.device][dst_rank]
bucket.params_checked_in += 1 bucket.params_checked_in += 1
if bucket.full(): if bucket.full():
...@@ -324,7 +420,7 @@ class ShardedDataParallel(nn.Module): ...@@ -324,7 +420,7 @@ class ShardedDataParallel(nn.Module):
# Reduce the bucket # Reduce the bucket
bucket.sent = True bucket.sent = True
optimizer.work_handles.append( self._work_handles.append(
Workhandle( Workhandle(
handle=dist.reduce( handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True, tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
...@@ -332,16 +428,16 @@ class ShardedDataParallel(nn.Module): ...@@ -332,16 +428,16 @@ class ShardedDataParallel(nn.Module):
callback=None, callback=None,
) )
) )
self._reduced_grads[optimizer] += 1 self._reduced_grads += 1
# Opportunistically try to empty the queue # Opportunistically try to empty the queue
optimizer._try_consume_work_handle() self._try_consume_work_handle()
# If all the reduce operations have been called, # If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on # make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets) # and execute the delayed actions (release gradients, unroll the buckets)
if self._reduced_grads[optimizer] == self._reduced_grads_max[optimizer]: if self._reduced_grads == self._reduced_grads_max:
optimizer._consume_work_handles() self._consume_work_handles()
return reduce return reduce
...@@ -352,33 +448,27 @@ class ShardedDataParallel(nn.Module): ...@@ -352,33 +448,27 @@ class ShardedDataParallel(nn.Module):
""" """
# Go through the parameters, attach the hook # Go through the parameters, attach the hook
for sharded_optimizer in self.sharded_optimizers: self._grad_accs = []
for ( for index, param in enumerate(self._trainable_params):
device_per_rank_params if param.grad is not None and param.grad.requires_grad:
) in sharded_optimizer.per_device_params.values(): # all the params on this device (inc all ranks) raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
for device_params in device_per_rank_params:
for param in filter(lambda x: x.requires_grad, device_params):
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
# Register the hook to the next function in line, # Register the hook to the next function in line,
# so that the hook is fired when this grad has properly been computed # so that the hook is fired when this grad has properly been computed
p_tmp = param.expand_as(param) p_tmp = param.expand_as(param)
assert p_tmp.grad_fn is not None assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = sharded_optimizer.param_to_rank[param] dst_rank = self._trainable_param_to_rank[param]
index = len(self._grad_accs)
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer)) grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
self._grad_accs.append(grad_acc) # keep this function in scope self._grad_accs.append(grad_acc) # keep this function in scope
# Add a hook on the module to flush the buckets, if needed # Add a hook on the module to flush the buckets, if needed
if self.use_buckets: if self.use_buckets:
def bucket_flush(*unused: Any) -> None: def bucket_flush(*_: Any) -> None:
handle = None
assert self._bucket_list is not None assert self._bucket_list is not None
handle = None
for bucket in self._bucket_list: for bucket in self._bucket_list:
if not bucket.sent: if not bucket.sent:
...@@ -425,62 +515,80 @@ class ShardedDataParallel(nn.Module): ...@@ -425,62 +515,80 @@ class ShardedDataParallel(nn.Module):
layer._specify_ddp_gpu_num(1) # type: ignore layer._specify_ddp_gpu_num(1) # type: ignore
def _setup_bucket_strategy(self) -> None:
"""Devise a bucketing strategy on a per-rank ownership level.
These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
This method can be slow for big models, but it is not typically called often (not for every forward pass, for instance)
"""
# A priori, one reduce call per param
self._reduced_grads_max = len(self._trainable_params)
if not self.use_buckets: if not self.use_buckets:
return return
# Devise the bucketing strategy # Devise the bucketing strategy. Parameters are already sorted, in that:
for sharded_optimizer in self.sharded_optimizers: # - these are only the trainable parameters, so they should produce grads
for device, per_rank_params in sharded_optimizer.per_device_params.items(): # - they are sorted by increasing size
self.buckets[sharded_optimizer][device] = [] self.buckets = {}
for dst_rank, params in enumerate(per_rank_params):
offset = 0
self.buckets[sharded_optimizer][device].append(
Bucket(
buffer=torch.zeros(self.buffer_max_size, dtype=per_rank_params[0][0].dtype, device=device)
)
)
bucket = self.buckets[sharded_optimizer][device][dst_rank]
bucket.destination = dst_rank
for param in filter(lambda x: x.requires_grad is True, params):
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
if (offset + param.numel()) < self.buffer_max_size:
self._should_bucket_grad.append(True)
# This parameter gradients becomes a view of the bucket
offset_next = offset + param.numel()
if param.grad is None: for param in self._trainable_params:
# will be overwritten just below, see next line device = param.device
param.grad = torch.zeros_like(param) dst_rank = self._trainable_param_to_rank[param]
param.grad.data = bucket.buffer[offset:offset_next].view_as(param.data) if param.device not in self.buckets.keys():
offset = offset_next self.buckets[param.device] = [
Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=param.dtype, device=device))
# Update the bucket for _ in range(dist.get_world_size(self.process_group))
self._reduced_grads_max[sharded_optimizer] -= 1 # one less reduce call per bucketed grad
self.buckets[sharded_optimizer][device][dst_rank].max_params_checked_in += 1
else:
self._should_bucket_grad.append(False)
# Resize the bucket to remove lost space in the end
bucket.buffer.resize_(offset)
if bucket.max_params_checked_in > 0:
self._reduced_grads_max[sharded_optimizer] += 1 # one reduce call per bucket
self._bucket_list = list(
chain(
*[
self.buckets[sharded_optimizer][device]
for sharded_optimizer in self.sharded_optimizers
for device in sharded_optimizer.per_device_params.keys()
] ]
)
) bucket = self.buckets[device][dst_rank]
bucket.destination = dst_rank
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
if (bucket.fill + param.numel()) < self.buffer_max_size:
self._should_bucket_grad.append(True)
# This parameter's gradient becomes a view of the bucket
fill_next = bucket.fill + param.numel()
if param.grad is None:
# will be overwritten just below, see next line
param.grad = torch.zeros_like(param)
param.grad.data = bucket.buffer[bucket.fill : fill_next].view_as(param.data)
bucket.fill = fill_next
# Update the bucket
self._reduced_grads_max -= 1 # one less reduce call per bucketed grad
self.buckets[device][dst_rank].max_params_checked_in += 1
else:
self._should_bucket_grad.append(False)
self._bucket_list = list(chain(*[self.buckets[device] for device in self.buckets.keys()]))
# Resize the buckets to remove lost space in the end
for bucket in self._bucket_list:
bucket.buffer.resize_(bucket.fill)
bucket.sent = True
if bucket.max_params_checked_in > 0:
self._reduced_grads_max += 1 # one reduce call per bucket
def _consume_work_handles(self) -> None:
"""Consume all the futures which are tied to this optimizer's buckets.
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
while len(self._work_handles) > 0:
work_handle = self._work_handles.popleft()
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
def _try_consume_work_handle(self) -> None:
"""Try to consume the oldest future. This is non blocking, if not ready we'll pass"""
while len(self._work_handles) > 0 and self._work_handles[0].handle.is_completed():
work_handle = self._work_handles.popleft()
if work_handle.callback is not None:
work_handle.callback()
@@ -3,19 +3,19 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict, deque from collections import OrderedDict
import copy import copy
from itertools import chain from itertools import chain
import logging import logging
from math import inf from math import inf
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.nn import Parameter from torch.nn import Parameter
from torch.optim import SGD, Optimizer from torch.optim import SGD, Optimizer
from .utils import Workhandle, broadcast_object, recursive_copy_to_device from .utils import broadcast_object, recursive_copy_to_device
__all__ = ["OSS"] __all__ = ["OSS"]
@@ -52,6 +52,14 @@ class OSS(Optimizer):
torch.distributed group (default: group.WORLD)
broadcast_buffer_size (int):
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
.. warning: the communication patterns that OSS uses depend on the "trainability" graph,
meaning that the parameters whose `requires_grad` flag is set are handled differently. This is
not reevaluated at every step; please call `refresh_trainable()` if your model changed
(a freeze or unfreeze, for instance).
If used with :class:`fairscale.nn.ShardedDDP`, automatic change detection is possible
via the `auto_refresh_trainable` parameter.
"""
#: The optimizer used for a given shard
@@ -81,27 +89,22 @@ class OSS(Optimizer):
self._param_to_index: Dict[int, int] = {} self._param_to_index: Dict[int, int] = {}
self._local_params: Optional[List[torch.Tensor]] = None self._local_params: Optional[List[torch.Tensor]] = None
# Build the wrapped optimizer, responsible for a shard of the params # Default empty values + immutables
self._optim_defaults = default
self._optim_constructor = optim
self.group = group if group is not None else dist.group.WORLD self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group) self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group) self.rank = dist.get_rank(self.group)
self.global_rank = self.get_global_rank(self.group, self.rank) self.global_rank = self.get_global_rank(self.group, self.rank)
self.optim = optim(self.partition_parameters()[self.rank], **default) self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
# - Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for key, value in local_group.items():
if key != "params":
global_group[key] = value
# Optional consolidated optimizer state self._all_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = [] self._default_device = torch.device("cpu")
# Current default device is set by the parameters allocated to this rank # Setup everything which is related to the parameters to be trained
self._device = list(self.per_device_params.keys())[0] # (partition and optimizer for the shard)
self.work_handles: Deque[Workhandle] = deque() self.refresh_trainable()
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
self._setup_flat_buffers()
# Partition helpers
def partition_parameters(self) -> List[List[dict]]:
@@ -277,12 +280,12 @@ class OSS(Optimizer):
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if norm_type == inf: if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params) total_norm = max(p.grad.detach().abs().max().to(self._default_device) for p in local_params)
# all reduce over data parallel and model parallel workers # all reduce over data parallel and model parallel workers
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD) dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
else: else:
local_norm = torch.norm( local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]), # type: ignore input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._default_device) for p in local_params]), # type: ignore
p=norm_type, p=norm_type,
) )
...@@ -412,16 +415,25 @@ class OSS(Optimizer): ...@@ -412,16 +415,25 @@ class OSS(Optimizer):
OSS._sync_param_groups(state_dict["param_groups"], self.param_groups) OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
OSS._sync_param_groups(self.param_groups, self.optim.param_groups) OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
def refresh_trainable(self) -> None:
""" Updates the partitioning and communication patterns if the trainability (`requires_grad`)
of some parameters changed
"""
# Create the optim which will work on the param shard
if not hasattr(self, "optim"):
self._clear_cache()
self._default_device = list(self.per_device_params.keys())[0]
self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
self._setup_flat_buffers()
def _broadcast_state_dict(self) -> None: def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others""" """Broadcast this rank's state shard, discard others"""
# Default to CPU space to gain some memory headroom
local_cpu_state = recursive_copy_to_device(
self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")
)
# Tensor cannot be really empty, even if its size is meaningless # Tensor cannot be really empty, even if its size is meaningless
dummy_sync_tensor = torch.tensor([1], device=self._device) dummy_sync_tensor = torch.tensor([1], device=self._default_device)
for rank in range(self.world_size): for rank in range(self.world_size):
if rank == self.rank: if rank == self.rank:
...@@ -431,17 +443,20 @@ class OSS(Optimizer): ...@@ -431,17 +443,20 @@ class OSS(Optimizer):
) )
# legacy compatibility for old torch versions # legacy compatibility for old torch versions
broadcast_object( broadcast_object(
self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device self.local_state_dict(),
src_rank=self.global_rank,
group=self.group,
dist_device=self._default_device,
) )
else: else:
global_rank = self.get_global_rank(self.group, rank) global_rank = self.get_global_rank(self.group, rank)
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
broadcast_object( broadcast_object(
torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device), torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._default_device),
src_rank=global_rank, src_rank=global_rank,
group=self.group, group=self.group,
dist_device=self._device, dist_device=self._default_device,
) )
def _collect_sharded_states(self) -> List[Dict[str, Any]]: def _collect_sharded_states(self) -> List[Dict[str, Any]]:
...@@ -457,19 +472,19 @@ class OSS(Optimizer): ...@@ -457,19 +472,19 @@ class OSS(Optimizer):
# Sync with other replicas # Sync with other replicas
broadcast_object( broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device), torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=self.global_rank, src_rank=self.global_rank,
group=self.group, group=self.group,
dist_device=self._device, dist_device=self._default_device,
) )
else: else:
# Fetch the optim state from the other replicas # Fetch the optim state from the other replicas
global_rank = self.get_global_rank(self.group, rank) global_rank = self.get_global_rank(self.group, rank)
replica_state = broadcast_object( replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device), torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=global_rank, src_rank=global_rank,
group=self.group, group=self.group,
dist_device=self._device, dist_device=self._default_device,
) )
all_states.append( all_states.append(
...@@ -546,23 +561,6 @@ class OSS(Optimizer): ...@@ -546,23 +561,6 @@ class OSS(Optimizer):
if last_work_handle: if last_work_handle:
last_work_handle.wait() last_work_handle.wait()
def _consume_work_handles(self) -> None:
"""Consume all the futures which are tied to this optimizer's buckets.
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
while len(self.work_handles) > 0:
work_handle = self.work_handles.popleft()
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
def _try_consume_work_handle(self) -> None:
"""Try to consume the oldest future. This is non blocking, if not ready we'll pass"""
while len(self.work_handles) > 0 and self.work_handles[0].handle.is_completed():
work_handle = self.work_handles.popleft()
if work_handle.callback is not None:
work_handle.callback()
def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
@@ -570,19 +568,35 @@
"""
for device, per_rank_params in self.per_device_params.items():
# Only create the bucket list for this device if it does not exist yet
# (it could be that this is called twice, when trainability changes)
if device not in self.buckets.keys():
self.buckets[device] = []
# Make parameters a view of the bucket
for dst_rank, params in enumerate(per_rank_params): for dst_rank, params in enumerate(per_rank_params):
if len(params) > 0: if len(params) > 0:
# Clone the non-trainable params; if they stayed views of a bucket their data would get destroyed
for param in filter(lambda x: not x.requires_grad, params):
param.data = param.data.detach().clone()
# Merge all the trainable params in a single bucket
trainable_params = list(filter(lambda x: x.requires_grad, params)) trainable_params = list(filter(lambda x: x.requires_grad, params))
buffer_size = sum(map(lambda x: x.numel(), trainable_params)) buffer_size = sum(map(lambda x: x.numel(), trainable_params))
self.buckets[device].append(torch.empty(buffer_size, dtype=params[0].dtype, device=device)) bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
offset = 0 offset = 0
for param in trainable_params: for param in trainable_params:
# This parameter becomes a view of the bucket
offset_next = offset + param.numel() offset_next = offset + param.numel()
bucket[offset:offset_next].copy_(param.data.flatten())
self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten()) param.data = bucket[offset:offset_next].view_as(param.data)
param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
offset = offset_next offset = offset_next
# Either replace the existing bucket, or create it
if len(self.buckets[device]) == dst_rank:
self.buckets[device].append(bucket)
else:
self.buckets[device][dst_rank] = bucket
else:
self.buckets[device].append(torch.zeros(1, device=device))
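Aside (not part of the diff): `_setup_flat_buffers` relies on turning every trainable parameter into a view of one contiguous per-rank buffer, so that a single collective moves all of them at once. A minimal sketch of the technique on made-up tensors:

```python
import torch

# Stand-ins for the trainable params owned by one rank on one device
params = [torch.randn(3, 4), torch.randn(5)]

buffer_size = sum(p.numel() for p in params)
bucket = torch.empty(buffer_size, dtype=params[0].dtype)

offset = 0
for p in params:
    offset_next = offset + p.numel()
    bucket[offset:offset_next].copy_(p.data.flatten())   # keep the current values
    p.data = bucket[offset:offset_next].view_as(p.data)  # the param now aliases the bucket
    offset = offset_next

# Broadcasting `bucket` now updates every parameter in place
```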
@@ -3,11 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import collections
import io import io
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
import torch import torch
from torch._six import container_abcs
import torch.distributed as dist import torch.distributed as dist
...@@ -38,7 +38,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic ...@@ -38,7 +38,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return values if isinstance(value, list) else tuple(values) return values if isinstance(value, list) else tuple(values)
if isinstance(value, container_abcs.Mapping): if isinstance(value, collections.abc.Mapping):
device_val: Dict[str, Any] = {} device_val: Dict[str, Any] = {}
for key, val in value.items(): for key, val in value.items():
device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device) device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
@@ -89,6 +89,7 @@ class Bucket:
self.max_size = buffer.numel()
# Current status for this buffer
self.fill = 0
self.params_checked_in = 0
self.max_params_checked_in = 0 # attribute present for convenience purposes
self.destination = -1
...
...@@ -406,3 +406,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool: ...@@ -406,3 +406,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
return False return False
else: else:
return a == b return a == b
def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{p_a} {p_b}\n" + message
for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
...@@ -18,12 +18,13 @@ class DistributedDataParallel(Module[T_co]): ...@@ -18,12 +18,13 @@ class DistributedDataParallel(Module[T_co]):
check_reduction: bool = ... check_reduction: bool = ...
broadcast_bucket_size: float = ... broadcast_bucket_size: float = ...
bucket_bytes_cap: float = ... bucket_bytes_cap: float = ...
find_unused_parameters: bool = ...
# TODO type process_group once `distributed` module is stubbed # TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ..., def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ...,
output_device: Optional[_device_t] = ..., dim: int = ..., output_device: Optional[_device_t] = ..., dim: int = ...,
broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ..., broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
check_reduction: bool = ...) -> None: ... check_reduction: bool = ..., find_unused_parameters: bool = ...) -> None: ...
def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ... def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...
...
...@@ -23,7 +23,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -23,7 +23,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu from fairscale.utils.testing import GPT2, check_same_model_params, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
def run_one_step(rank, world_size, backend, device, temp_file_name): def run_one_step(rank, world_size, backend, device, temp_file_name):
...@@ -133,67 +133,66 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -133,67 +133,66 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
torch.manual_seed(rank) torch.manual_seed(rank)
np.random.seed(rank) np.random.seed(rank)
NUMBER_BATCHS = 5
INPUTS = 2
BATCH_SIZE = 32
def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
# The API should be the exact same in between the sharded and non-sharded variants, generic closure
def closure(model, scaler, input_tensor, should_accumulate):
accumulate_steps = 3 if should_accumulate else 1
model.zero_grad()
def step():
if scaler is not None:
with torch.cuda.amp.autocast():
loss = model(input_tensor).abs().sum()
scaler.scale(loss).backward()
else:
loss = model(input_tensor).abs().sum()
loss.backward()
with model.no_sync() if should_accumulate else suppress():
for _ in range(accumulate_steps - 1):
step()
step()
def check_parity(amp: bool):
# Any model works. Add one different buffer per rank # Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3)) model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model.register_buffer("test_buffer", torch.ones((1)) * rank) model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device) model.to(device)
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99) # Make sure that the model starts with non-trainable, so that we check for the buckets to be
# properly reassigned when/if this changes
next(model.parameters()).requires_grad = False
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
sharded_ddp_model = ShardedDataParallel( sharded_ddp_model = ShardedDataParallel(
module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
) )
ddp_model_single = copy.deepcopy(model) ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99) ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True) ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
ddp_scaler = TorchGradScaler() if amp else None ddp_scaler = TorchGradScaler() if amp else None
sharded_ddp_scaler = ShardedGradScaler() if amp else None sharded_ddp_scaler = ShardedGradScaler() if amp else None
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
assert torch.allclose(
p, ddp_p, atol=1e-3
), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(
b, ddp_b, atol=1e-3
), f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"
# The model should be synchronized in between the ranks at construction time, check that # The model should be synchronized in between the ranks at construction time, check that
check_same_model_params() check_same_model_params(sharded_ddp_model, ddp_model)
# The models should stay the same in between the ranks # Typical training loop, check that we get the exact same results as DDP
for i in range(10): for i in range(NUMBER_BATCHS):
input_tensor = torch.rand((64, 2)).to(device) input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)
def closure_ddp(input_tensor=input_tensor): def closure_ddp(input_tensor=input_tensor):
ddp_optimizer.zero_grad() return closure(ddp_model, ddp_scaler, input_tensor, accumulate)
if ddp_scaler is not None:
with torch.cuda.amp.autocast():
ddp_loss = ddp_model(input_tensor).abs().sum()
ddp_scaler.scale(ddp_loss).backward()
else:
ddp_loss = ddp_model(input_tensor).abs().sum()
ddp_loss.backward()
return ddp_loss
def closure_sharded(input_tensor=input_tensor): def closure_sharded(input_tensor=input_tensor):
sharded_optimizer.zero_grad() return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)
if sharded_ddp_scaler is not None:
with torch.cuda.amp.autocast():
sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
sharded_ddp_scaler.scale(sharded_loss).backward()
else:
sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
sharded_loss.backward()
return sharded_loss
# Step/scale both # Step/scale both
if ddp_scaler is not None: if ddp_scaler is not None:
...@@ -210,13 +209,28 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -210,13 +209,28 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
else: else:
sharded_optimizer.step(closure=closure_sharded) sharded_optimizer.step(closure=closure_sharded)
check_same_model_params() check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
check_parity(amp=False) # Flip the trainability of the first parameter back and forth
if i == 0 and change_train_graph:
next(sharded_ddp_model.parameters()).requires_grad = not next(
sharded_ddp_model.parameters()
).requires_grad
next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
# Catch a version of pytorch which would not support AMP # Test all combinations: AMP, Accumulate, Change train graph
amp_tests = [False]
if hasattr(torch.cuda.amp, "autocast"): if hasattr(torch.cuda.amp, "autocast"):
check_parity(amp=True) amp_tests.append(True)
for accumulate in [False, True]:
for change_train_graph in [False, True]:
for amp in amp_tests:
print(
f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp}"
)
check_parity(amp=amp, accumulate=accumulate, change_train_graph=change_train_graph)
dist.destroy_process_group() dist.destroy_process_group()
...@@ -417,6 +431,8 @@ def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_na ...@@ -417,6 +431,8 @@ def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_na
model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device) model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.to(device) # in pytorch 1.5 syncBN switches to the default device/cpu
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99) optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer) ddp_model = ShardedDataParallel(model, optimizer)
...
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
import copy import copy
from math import inf from math import inf
import tempfile import tempfile
from typing import Any, Type, cast from typing import Any, Dict, Type, cast
import unittest import unittest
import numpy as np import numpy as np
...@@ -22,7 +22,7 @@ import torch.multiprocessing as mp ...@@ -22,7 +22,7 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim import fairscale.optim as optim
from fairscale.utils.testing import skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu") DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
...@@ -688,10 +688,6 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): ...@@ -688,10 +688,6 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
model.zero_grad() model.zero_grad()
outputs = head(model(inputs)) outputs = head(model(inputs))
def check_equal_models(message: str):
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), message
# pull the current state, broadcast it to all ranks # pull the current state, broadcast it to all ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {} state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
...@@ -701,12 +697,16 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): ...@@ -701,12 +697,16 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001) sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()}) sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
sharded_optimizer2.load_state_dict(state_dict2) sharded_optimizer2.load_state_dict(state_dict2)
check_equal_models("parameters of the two identical models have diverged (before any steps)") check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (before any steps)"
)
# now take a step and check that parameters are equal # now take a step and check that parameters are equal
run_grad_step(model_oss1, head_oss1, sharded_optimizer1) run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2) run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after stepping)") check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after stepping)"
)
# save the state dict for one model only, then distribute to the other ranks # save the state dict for one model only, then distribute to the other ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
...@@ -722,7 +722,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): ...@@ -722,7 +722,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# take a step # take a step
run_grad_step(model_oss1, head_oss1, sharded_optimizer1) run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2) run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after consolidating)") check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after consolidating)"
)
# save again for one rank, then distribute to the others # save again for one rank, then distribute to the others
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
...@@ -737,7 +739,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name): ...@@ -737,7 +739,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# take a step # take a step
run_grad_step(model_oss1, head_oss1, sharded_optimizer1) run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2) run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after reloading)") check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after reloading)"
)
dist.destroy_process_group() dist.destroy_process_group()
...@@ -768,7 +772,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -768,7 +772,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
out_channels = 3 out_channels = 3
batch = 64 batch = 64
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]): def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer], change_train_graph: bool = False):
# Any model works. Add one different buffer per rank # Any model works. Add one different buffer per rank
trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden)) trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
trunk.register_buffer("test_buffer", torch.ones((1)) * rank) trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
...@@ -777,14 +781,14 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -777,14 +781,14 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
head = torch.nn.Linear(hidden, out_channels).to(device) head = torch.nn.Linear(hidden, out_channels).to(device)
# Define a model to be trained by OSS # Define a model to be trained by OSS
oss_model = torch.nn.Sequential(trunk, head) oss_module = torch.nn.Sequential(trunk, head)
oss_trainable_params = [ oss_trainable_params = [
{"params": trunk.parameters(), "lr": 1e-5}, {"params": trunk.parameters(), "lr": 1e-5},
{"params": head.parameters(), "lr": 1e-4}, {"params": head.parameters(), "lr": 1e-4},
] ]
optimizer_settings = {} optimizer_settings: Dict[Any, Any] = {}
if isinstance(optim, torch.optim.SGD): if isinstance(optimizer, torch.optim.SGD):
optimizer_settings["momentum"] = 0.9 optimizer_settings["momentum"] = 0.9
sharded_optimizer = optim.OSS( sharded_optimizer = optim.OSS(
...@@ -795,7 +799,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -795,7 +799,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
**optimizer_settings, **optimizer_settings,
) )
oss_ddp_model = DDP(module=oss_model, device_ids=[rank], broadcast_buffers=True) oss_ddp_model = DDP(module=oss_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
# Define a model to be trained by normal pytorch + DDP # Define a model to be trained by normal pytorch + DDP
ddp_trunk = copy.deepcopy(trunk) ddp_trunk = copy.deepcopy(trunk)
...@@ -807,19 +811,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -807,19 +811,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
{"params": ddp_head.parameters(), "lr": 1e-4}, {"params": ddp_head.parameters(), "lr": 1e-4},
] ]
ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True) ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
assert torch.allclose(
p, ddp_p, atol=1e-3
), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}"
for b, ddp_b in zip(oss_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(
b, ddp_b
), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}"
def check_step(): def check_step():
input_tensor = torch.rand((batch, in_channels)).to(device) input_tensor = torch.rand((batch, in_channels)).to(device)
...@@ -843,13 +835,21 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -843,13 +835,21 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
loss_ddp, loss_sharded_optim loss_ddp, loss_sharded_optim
), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}" ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
check_same_model_params(oss_ddp_model, ddp_model)
# The model should be synchronized in between the ranks at construction time, check that # The model should be synchronized in between the ranks at construction time, check that
check_same_model_params() check_same_model_params(oss_ddp_model, ddp_model)
# The models should stay the same in between ddp and sharded optimizer # The models should stay the same in between ddp and sharded optimizer
for i in range(5): for i in range(5):
check_step() check_step()
check_same_model_params()
# Check that altering the trainable parameters does not cause DDP and OSS to diverge
if change_train_graph:
# Flip the first parameter from trainable to non-trainable and vice-versa
next(ddp_module.parameters()).requires_grad = not next(ddp_module.parameters()).requires_grad
next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
# sharded_optimizer.refresh_trainable()
# Check that the checkpoints are compatible # Check that the checkpoints are compatible
# - get states # - get states
...@@ -864,10 +864,10 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name): ...@@ -864,10 +864,10 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
# - run one step and check that the models are still the same # - run one step and check that the models are still the same
check_step() check_step()
check_same_model_params()
for opt in [torch.optim.SGD, torch.optim.Adam]: for opt in [torch.optim.Adam, torch.optim.SGD]:
check_optimizer_equivalence(opt) check_optimizer_equivalence(opt, change_train_graph=False)
check_optimizer_equivalence(opt, change_train_graph=True)
dist.destroy_process_group() dist.destroy_process_group()
...