Unverified Commit 13445c55 authored by Benjamin Lefaudeux, committed by GitHub

[feature-fix-refactor][ShardedDDP] Make it possible to change trainability graph on the fly (#369)

* Better unit testing
* Make it possible to refresh the DDP assumptions when the model has changed. Make it optional so that you save some time
* Enabling accumulation tests
parent 1a636557
......@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [next rel] - TBD
### Fixed
- ShardedDDP and OSS handle model trainability changes during training ([#369](https://github.com/facebookresearch/fairscale/issues/369))
## [0.1.6] - 2021-02-10
### Added
......
......@@ -8,10 +8,12 @@ A nn.Module wrapper to go with a Sharded Optimizer in order to handle targeted g
reduction automatically.
"""
from collections import deque
import contextlib
import functools
from itertools import chain
import logging
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -22,6 +24,10 @@ from fairscale.optim import OSS
from fairscale.optim.utils import Bucket, Workhandle
def _trainable(param: torch.Tensor) -> bool:
return param.requires_grad
class ShardedDataParallel(nn.Module):
""" Wrap the model, and reduce the gradients to the right rank during the backward pass.
......@@ -45,9 +51,13 @@ class ShardedDataParallel(nn.Module):
Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
reduce_buffer_size (int):
the max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
The max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
This will impact the long-term memory consumption, because these buckets correspond to parameters which will not be sharded.
Set to 0 to remove all bucketing.
auto_refresh_trainable (bool):
(default: True) Check whether the parameters' trainability (`requires_grad`) has changed and update both ShardedDDP
and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called any time
a parameter is frozen or unfrozen.
.. warning:
......@@ -71,6 +81,14 @@ class ShardedDataParallel(nn.Module):
handled. In that case ShardedDDP will raise an exception and suggest to either remove the unused parameters from your model
(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=unused_parameters is helpful)
or set `reduce_buffer_size` to 0
.. warning:
If `auto_refresh_trainable` is set to `True` (this is the default) then any trainability change in the model graph will be handled
automatically.
If `auto_refresh_trainable` is set to `False`, ShardedDDP will not refresh its assumptions with respect to trainable parameters
for every forward pass, in the hope of saving some time. If some parameters are frozen or unfrozen over time, please refresh
ShardedDDP assumptions by calling `refresh_trainable()` just after said change (before the next forward pass).
"""
def __init__(
......@@ -81,12 +99,14 @@ class ShardedDataParallel(nn.Module):
broadcast_buffers: bool = True,
sync_models_at_startup: bool = True,
reduce_buffer_size: int = 2 ** 23,
auto_refresh_trainable: bool = True,
):
super().__init__()
self.module = module
self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
self.enable_broadcast_buffers = broadcast_buffers
self.auto_refresh_trainable = auto_refresh_trainable
# Handle a no_sync() context which prevents the gradient synchronization,
# accumulate in place
......@@ -117,14 +137,19 @@ class ShardedDataParallel(nn.Module):
# several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks
# - we build an iterator which goes through all the parameters involved globally
all_param_iterator = chain(
*[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
self._all_params = list(
chain(
*[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
)
)
self._grad_to_be_reduced = [True for _ in filter(lambda x: x.requires_grad, all_param_iterator)]
self._trainable_params: List[torch.Tensor] = []
self._grad_to_be_reduced: List[bool] = []
self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
self._reference_trainable_mask = list(map(_trainable, self._all_params))
# - keep track of the grads which have already been reduced
self._reduced_grads: Dict[OSS, int] = {}
self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
self._reduced_grads = 0
self._reduced_grads_max = 0
# - setup buckets and tensor views
model_size = sum([p.numel() for p in self.module.parameters()])
......@@ -140,14 +165,12 @@ class ShardedDataParallel(nn.Module):
)
self.use_buckets = self.buffer_max_size > 0
self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
self.buckets: Dict[torch.device, List[Bucket]] = {}
self._should_bucket_grad: List[bool] = []
self._bucket_list: Optional[List[Bucket]] = None
self._setup_bucket_strategy()
# - setup backward hooks which will be called by Torch's autograd in due time
self._grad_accs: List[Callable] = []
self._setup_backward_hooks()
# passing a handle to torch.nn.SyncBatchNorm layer
self._passing_sync_batchnorm_handle(self.module)
......@@ -156,13 +179,25 @@ class ShardedDataParallel(nn.Module):
if sync_models_at_startup:
self._sync_params_and_buffers()
self._clear_counters()
self._work_handles: Deque[Workhandle] = deque()
self.refresh_trainable()
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
"""
Module forward pass, handles any DDP-specific work in the background. Primes the
backward pass for gradient reduction to the proper ranks.
"""
# Optionally check whether the trainable parameters have changed
if self.auto_refresh_trainable:
trainable_mask = list(map(_trainable, self._all_params))
if trainable_mask != self._reference_trainable_mask:
logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
self.refresh_trainable()
self._reference_trainable_mask = trainable_mask
if self.enable_broadcast_buffers:
# NCCL communications are on a different stream, needs to be blocking
# for the subsequent FW to be correct
......@@ -205,13 +240,38 @@ class ShardedDataParallel(nn.Module):
"""
for optimizer in self.buckets.keys():
for device in self.buckets[optimizer].keys():
for bucket in self.buckets[optimizer][device]:
bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
for device in self.buckets.keys():
for bucket in self.buckets[device]:
bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.module.to(device)
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
# Make sure that this is not done while gradients are waiting to be reduced (if in a no_sync context for instance)
assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), "Grads waiting to be reduced"
self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
self._trainable_params.sort(key=lambda x: x.numel())
self._grad_to_be_reduced = [True for _ in self._trainable_params]
self._trainable_param_to_rank = {}
for optim in self.sharded_optimizers:
# OSS may need to change the communication pattern
optim.refresh_trainable()
# Update ShardedDDP given the new partitions
for (
device_per_rank_params
) in optim.per_device_params.values(): # all the params on this device (inc all ranks)
for device_params in device_per_rank_params:
for param in filter(lambda x: x.requires_grad, device_params):
self._trainable_param_to_rank[param] = optim.param_to_rank[param]
self._setup_bucket_strategy()
self._setup_backward_hooks()
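The automatic detection used in `forward` boils down to comparing a boolean mask of the `requires_grad` flags against a reference taken at the previous refresh. A stand-alone sketch of that check (the parameter list is illustrative):

```python
import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
reference_mask = [p.requires_grad for p in params]

# ... some training happens, then a parameter gets frozen ...
params[1].requires_grad = False

current_mask = [p.requires_grad for p in params]
if current_mask != reference_mask:
    # this is the point where ShardedDDP would call refresh_trainable()
    reference_mask = current_mask
    print("trainability changed, the partitions need to be refreshed")
```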
def reduce(self) -> None:
""".. deprecated:: 0.0.4
......@@ -223,6 +283,9 @@ class ShardedDataParallel(nn.Module):
def sync_buffers(self, blocking: bool = False) -> None:
"""
Sync all the param buffers in between ranks (including for instance batch norm statistics).
Arguments:
blocking (bool): wait for the operation to conclude.
"""
last_work_handle = None
......@@ -236,6 +299,21 @@ class ShardedDataParallel(nn.Module):
# Only wait for the last comms, they're inlined on the same CUDA stream
last_work_handle.wait()
def zero_grad(self, set_to_none: bool = False) -> None:
r"""Sets gradients of all model parameters to zero. See similar function
under :class:`torch.optim.Optimizer` for more context.
Arguments:
set_to_none (bool): instead of setting to zero, set the grads to None.
See :meth:`torch.optim.Optimizer.zero_grad` for details.
"""
for index, trainable_param in enumerate(self._trainable_params):
if set_to_none and not self._should_bucket_grad[index]:
trainable_param.grad = None
elif trainable_param.grad is not None:
trainable_param.grad.zero_()
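The `set_to_none` path is skipped for bucketed parameters because their `.grad` is a view into a shared flat buffer, and dropping the reference would detach the gradient from the bucket for good. A small illustration of why the in-place `zero_()` is used instead (the sizes are arbitrary):

```python
import torch

bucket = torch.ones(6)                  # stands in for a Bucket buffer
p = torch.nn.Parameter(torch.randn(2, 3))
p.grad = bucket[:6].view_as(p)          # the gradient is a view of the bucket

p.grad.zero_()                          # zeroing in place also zeroes the bucket slice
assert bucket.abs().sum().item() == 0

p.grad = None                           # this would sever the view: the bucket would no longer
                                        # receive this gradient, hence the in-place path above
```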
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
try:
......@@ -254,21 +332,20 @@ class ShardedDataParallel(nn.Module):
@torch.no_grad()
def _clear_counters(self) -> None:
"""Reset all the grad reduce and call counters"""
if not self.should_accumulate_grads:
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
self._reduced_grads = 0
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
if self.use_buckets:
assert self._bucket_list is not None
# Do not reset the buckets
if self.use_buckets:
assert self._bucket_list is not None
for bucket in self._bucket_list:
assert bucket.sent, (
"A bucket failed to be sent, probably unused parameters."
+ "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
)
for bucket in self._bucket_list:
assert self.should_accumulate_grads or bucket.sent, (
"A bucket failed to be sent, probably unused parameters."
+ "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
)
bucket.reset()
bucket.reset()
def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
""" Look up where this parameter belongs to """
......@@ -279,7 +356,7 @@ class ShardedDataParallel(nn.Module):
assert False, "This parameter is not present in an optimizer, this should not happen"
return (None, -1)
def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int, optimizer: OSS) -> Callable:
def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
"""
Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
or contribute to a bucket and reduce when the bucket is full.
......@@ -287,16 +364,16 @@ class ShardedDataParallel(nn.Module):
Either way a delayed action is necessary and is passed as a callback.
"""
@torch.no_grad()
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
if not self.use_buckets or not self._should_bucket_grad[index]:
# Direct reduction
@torch.no_grad()
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
if not self.use_buckets or not self._should_bucket_grad[index]:
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
param.grad.mul_(self.world_size_scaling)
# Future work includes clearing up the buffer if possible
......@@ -305,7 +382,7 @@ class ShardedDataParallel(nn.Module):
param.grad = None
# Async reduce for this buffer, log the future
optimizer.work_handles.append(
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
......@@ -313,9 +390,28 @@ class ShardedDataParallel(nn.Module):
callback=cleanup,
)
)
self._reduced_grads[optimizer] += 1
else:
bucket = self.buckets[optimizer][param.device][dst_rank]
self._reduced_grads += 1
# Opportunistically try to empty the queue
self._try_consume_work_handle()
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if self._reduced_grads == self._reduced_grads_max:
self._consume_work_handles()
else:
@torch.no_grad()
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
bucket = self.buckets[param.device][dst_rank]
bucket.params_checked_in += 1
if bucket.full():
......@@ -324,7 +420,7 @@ class ShardedDataParallel(nn.Module):
# Reduce the bucket
bucket.sent = True
optimizer.work_handles.append(
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
......@@ -332,16 +428,16 @@ class ShardedDataParallel(nn.Module):
callback=None,
)
)
self._reduced_grads[optimizer] += 1
self._reduced_grads += 1
# Opportunistically try to empty the queue
optimizer._try_consume_work_handle()
# Opportunistically try to empty the queue
self._try_consume_work_handle()
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if self._reduced_grads[optimizer] == self._reduced_grads_max[optimizer]:
optimizer._consume_work_handles()
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if self._reduced_grads == self._reduced_grads_max:
self._consume_work_handles()
return reduce
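In the bucketed branch above, each gradient hook checks its parameter into the destination bucket, and a single asynchronous reduce is launched once the bucket is full. A toy sketch of that bookkeeping (ToyBucket is a stand-in, not the actual fairscale Bucket class):

```python
import torch

class ToyBucket:
    """Stand-in for the bucket bookkeeping, not the actual fairscale Bucket."""

    def __init__(self, numel: int, expected_params: int) -> None:
        self.buffer = torch.zeros(numel)
        self.max_params_checked_in = expected_params
        self.params_checked_in = 0
        self.sent = False

    def full(self) -> bool:
        return self.params_checked_in == self.max_params_checked_in

bucket = ToyBucket(numel=8, expected_params=3)

for _ in range(3):                # one check-in per bucketed gradient hook
    bucket.params_checked_in += 1
    if bucket.full():
        bucket.sent = True        # the real code launches one async dist.reduce here
        bucket.params_checked_in = 0

assert bucket.sent
```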
......@@ -352,33 +448,27 @@ class ShardedDataParallel(nn.Module):
"""
# Go through the parameters, attach the hook
for sharded_optimizer in self.sharded_optimizers:
for (
device_per_rank_params
) in sharded_optimizer.per_device_params.values(): # all the params on this device (inc all ranks)
for device_params in device_per_rank_params:
for param in filter(lambda x: x.requires_grad, device_params):
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
self._grad_accs = []
for index, param in enumerate(self._trainable_params):
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
# Register the hook to the next function in line,
# so that the hook is fired when this grad has properly been computed
p_tmp = param.expand_as(param)
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = sharded_optimizer.param_to_rank[param]
index = len(self._grad_accs)
# Register the hook to the next function in line,
# so that the hook is fired when this grad has properly been computed
p_tmp = param.expand_as(param)
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = self._trainable_param_to_rank[param]
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
self._grad_accs.append(grad_acc) # keep this function in scope
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
self._grad_accs.append(grad_acc) # keep this function in scope
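The hook registration above relies on a common autograd trick: `expand_as` creates a non-leaf alias of the parameter, whose `grad_fn.next_functions` exposes the `AccumulateGrad` node, and a hook on that node fires once the gradient for this parameter has been written. A stand-alone sketch:

```python
import torch

param = torch.nn.Parameter(torch.randn(3))
fired = []

p_tmp = param.expand_as(param)                 # non-leaf alias, so grad_fn is populated
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # AccumulateGrad node for `param`
grad_acc.register_hook(lambda *unused: fired.append(True))

(param * 2).sum().backward()

assert param.grad is not None and fired        # the hook ran after the grad was accumulated
```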
# Add a hook on the module to flush the buckets, if needed
if self.use_buckets:
def bucket_flush(*unused: Any) -> None:
handle = None
def bucket_flush(*_: Any) -> None:
assert self._bucket_list is not None
handle = None
for bucket in self._bucket_list:
if not bucket.sent:
......@@ -425,62 +515,80 @@ class ShardedDataParallel(nn.Module):
layer._specify_ddp_gpu_num(1) # type: ignore
def _setup_bucket_strategy(self) -> None:
"""Devise a bucketing strategy on a per-rank ownership level. These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
"""Devise a bucketing strategy on a per-rank ownership level.
These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
This method can be slow for big models, but it is not typically called often (not for every forward pass, for instance)
"""
# A priori, one reduce call per param
self._reduced_grads_max = len(self._trainable_params)
if not self.use_buckets:
return
# Devise the bucketing strategy
for sharded_optimizer in self.sharded_optimizers:
for device, per_rank_params in sharded_optimizer.per_device_params.items():
self.buckets[sharded_optimizer][device] = []
for dst_rank, params in enumerate(per_rank_params):
offset = 0
self.buckets[sharded_optimizer][device].append(
Bucket(
buffer=torch.zeros(self.buffer_max_size, dtype=per_rank_params[0][0].dtype, device=device)
)
)
bucket = self.buckets[sharded_optimizer][device][dst_rank]
bucket.destination = dst_rank
for param in filter(lambda x: x.requires_grad is True, params):
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
if (offset + param.numel()) < self.buffer_max_size:
self._should_bucket_grad.append(True)
# This parameter gradients becomes a view of the bucket
offset_next = offset + param.numel()
# Devise the bucketing strategy. Parameters are already sorted, in that:
# - these are only the trainable parameters, so they should produce grads
# - they are sorted by increasing size
self.buckets = {}
if param.grad is None:
# will be overwritten just below, see next line
param.grad = torch.zeros_like(param)
for param in self._trainable_params:
device = param.device
dst_rank = self._trainable_param_to_rank[param]
param.grad.data = bucket.buffer[offset:offset_next].view_as(param.data)
offset = offset_next
# Update the bucket
self._reduced_grads_max[sharded_optimizer] -= 1 # one less reduce call per bucketed grad
self.buckets[sharded_optimizer][device][dst_rank].max_params_checked_in += 1
else:
self._should_bucket_grad.append(False)
# Resize the bucket to remove lost space in the end
bucket.buffer.resize_(offset)
if bucket.max_params_checked_in > 0:
self._reduced_grads_max[sharded_optimizer] += 1 # one reduce call per bucket
self._bucket_list = list(
chain(
*[
self.buckets[sharded_optimizer][device]
for sharded_optimizer in self.sharded_optimizers
for device in sharded_optimizer.per_device_params.keys()
if param.device not in self.buckets.keys():
self.buckets[param.device] = [
Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=param.dtype, device=device))
for _ in range(dist.get_world_size(self.process_group))
]
)
)
bucket = self.buckets[device][dst_rank]
bucket.destination = dst_rank
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
if (bucket.fill + param.numel()) < self.buffer_max_size:
self._should_bucket_grad.append(True)
# This parameter's gradient becomes a view of the bucket
fill_next = bucket.fill + param.numel()
if param.grad is None:
# will be overwritten just below, see next line
param.grad = torch.zeros_like(param)
param.grad.data = bucket.buffer[bucket.fill : fill_next].view_as(param.data)
bucket.fill = fill_next
# Update the bucket
self._reduced_grads_max -= 1 # one less reduce call per bucketed grad
self.buckets[device][dst_rank].max_params_checked_in += 1
else:
self._should_bucket_grad.append(False)
self._bucket_list = list(chain(*[self.buckets[device] for device in self.buckets.keys()]))
# Resize the buckets to remove the unused space at the end
for bucket in self._bucket_list:
bucket.buffer.resize_(bucket.fill)
bucket.sent = True
if bucket.max_params_checked_in > 0:
self._reduced_grads_max += 1 # one reduce call per bucket
def _consume_work_handles(self) -> None:
"""Consume all the futures which are tied to this optimizer's buckets.
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
while len(self._work_handles) > 0:
work_handle = self._work_handles.popleft()
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
def _try_consume_work_handle(self) -> None:
"""Try to consume the oldest future. This is non blocking, if not ready we'll pass"""
while len(self._work_handles) > 0 and self._work_handles[0].handle.is_completed():
work_handle = self._work_handles.popleft()
if work_handle.callback is not None:
work_handle.callback()
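The futures returned by the asynchronous `dist.reduce` calls are queued and drained opportunistically, with a final blocking drain once every reduce of the step has been launched. A stand-alone sketch of the pattern, using a single-process gloo group so that real work handles are produced (the port is assumed free, callbacks omitted for brevity):

```python
from collections import deque
import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

work_handles: deque = deque()
grads = [torch.ones(16) for _ in range(4)]

for grad in grads:
    work_handles.append(dist.reduce(grad, dst=0, async_op=True))

    # Opportunistic, non-blocking drain of whatever has already completed
    while work_handles and work_handles[0].is_completed():
        work_handles.popleft()

# Every reduce has been launched: block until the remaining ones are done
while work_handles:
    work_handles.popleft().wait()

dist.destroy_process_group()
```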
......@@ -3,19 +3,19 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict, deque
from collections import OrderedDict
import copy
from itertools import chain
import logging
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from .utils import Workhandle, broadcast_object, recursive_copy_to_device
from .utils import broadcast_object, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -52,6 +52,14 @@ class OSS(Optimizer):
torch.distributed group (default: group.WORLD)
broadcast_buffer_size (int):
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
.. warning: the communication patterns that OSS uses depend on the "trainability" graph,
meaning that the parameters which have `requires_grad` set are handled differently from the frozen ones. This is
not re-evaluated at every step; please use `refresh_trainable()` if your model changed
(a freeze or unfreeze, for instance).
If used with :class:`fairscale.nn.ShardedDDP` then an automatic change detection is possible,
via the `auto_refresh_trainable` parameter.
"""
#: The optimizer used for a given shard
......@@ -81,27 +89,22 @@ class OSS(Optimizer):
self._param_to_index: Dict[int, int] = {}
self._local_params: Optional[List[torch.Tensor]] = None
# Build the wrapped optimizer, responsible for a shard of the params
# Default empty values + immutables
self._optim_defaults = default
self._optim_constructor = optim
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.global_rank = self.get_global_rank(self.group, self.rank)
self.optim = optim(self.partition_parameters()[self.rank], **default)
# - Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for key, value in local_group.items():
if key != "params":
global_group[key] = value
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
self._all_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state
self._default_device = torch.device("cpu")
# Current default device is set by the parameters allocated to this rank
self._device = list(self.per_device_params.keys())[0]
self.work_handles: Deque[Workhandle] = deque()
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
self._setup_flat_buffers()
# Setup everything which is related to the parameters to be trained
# (partition and optimizer for the shard)
self.refresh_trainable()
# Partition helpers
def partition_parameters(self) -> List[List[dict]]:
......@@ -277,12 +280,12 @@ class OSS(Optimizer):
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)
total_norm = max(p.grad.detach().abs().max().to(self._default_device) for p in local_params)
# all reduce over data parallel and model parallel workers
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
else:
local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]), # type: ignore
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._default_device) for p in local_params]), # type: ignore
p=norm_type,
)
......@@ -412,16 +415,25 @@ class OSS(Optimizer):
OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
def refresh_trainable(self) -> None:
""" Updates the partitioning and communication patterns if the trainability (`requires_grad`)
of some parameters changed
"""
# Create the optim which will work on the param shard
if not hasattr(self, "optim"):
self._clear_cache()
self._default_device = list(self.per_device_params.keys())[0]
self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
self._setup_flat_buffers()
def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others"""
# Default to CPU space to gain some memory headroom
local_cpu_state = recursive_copy_to_device(
self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")
)
# Tensor cannot be really empty, even if its size is meaningless
dummy_sync_tensor = torch.tensor([1], device=self._device)
dummy_sync_tensor = torch.tensor([1], device=self._default_device)
for rank in range(self.world_size):
if rank == self.rank:
......@@ -431,17 +443,20 @@ class OSS(Optimizer):
)
# legacy compatibility for old torch versions
broadcast_object(
self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
self.local_state_dict(),
src_rank=self.global_rank,
group=self.group,
dist_device=self._default_device,
)
else:
global_rank = self.get_global_rank(self.group, rank)
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
broadcast_object(
torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._default_device),
src_rank=global_rank,
group=self.group,
dist_device=self._device,
dist_device=self._default_device,
)
def _collect_sharded_states(self) -> List[Dict[str, Any]]:
......@@ -457,19 +472,19 @@ class OSS(Optimizer):
# Sync with other replicas
broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device),
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=self.global_rank,
group=self.group,
dist_device=self._device,
dist_device=self._default_device,
)
else:
# Fetch the optim state from the other replicas
global_rank = self.get_global_rank(self.group, rank)
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._device),
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
src_rank=global_rank,
group=self.group,
dist_device=self._device,
dist_device=self._default_device,
)
all_states.append(
......@@ -546,23 +561,6 @@ class OSS(Optimizer):
if last_work_handle:
last_work_handle.wait()
def _consume_work_handles(self) -> None:
"""Consume all the futures which are tied to this optimizer's buckets.
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
while len(self.work_handles) > 0:
work_handle = self.work_handles.popleft()
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
def _try_consume_work_handle(self) -> None:
"""Try to consume the oldest future. This is non blocking, if not ready we'll pass"""
while len(self.work_handles) > 0 and self.work_handles[0].handle.is_completed():
work_handle = self.work_handles.popleft()
if work_handle.callback is not None:
work_handle.callback()
def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
......@@ -570,19 +568,35 @@ class OSS(Optimizer):
"""
for device, per_rank_params in self.per_device_params.items():
self.buckets[device] = []
# Only create the bucket list for this device if it does not exist yet
# (could be that this is called twice, when trainability changes)
if device not in self.buckets.keys():
self.buckets[device] = []
# Make parameters a view of the bucket
for dst_rank, params in enumerate(per_rank_params):
if len(params) > 0:
# Clone the non-trainable params; if they stay as views of a bucket their data would get destroyed
for param in filter(lambda x: not x.requires_grad, params):
param.data = param.data.detach().clone()
# Merge all the trainable params in a single bucket
trainable_params = list(filter(lambda x: x.requires_grad, params))
buffer_size = sum(map(lambda x: x.numel(), trainable_params))
self.buckets[device].append(torch.empty(buffer_size, dtype=params[0].dtype, device=device))
bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
offset = 0
for param in trainable_params:
# This parameter becomes a view of the bucket
offset_next = offset + param.numel()
self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
bucket[offset:offset_next].copy_(param.data.flatten())
param.data = bucket[offset:offset_next].view_as(param.data)
offset = offset_next
# Either replace the existing bucket, or create it
if len(self.buckets[device]) == dst_rank:
self.buckets[device].append(bucket)
else:
self.buckets[device][dst_rank] = bucket
else:
self.buckets[device].append(torch.zeros(1, device=device))
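The flat buffers above follow a simple pattern: copy the current values of the trainable parameters into one contiguous tensor per rank and device, then turn each parameter into a view of its slice, so that a single broadcast per bucket keeps every rank in sync. A stand-alone sketch of that mechanism (shapes are arbitrary):

```python
import torch

params = [torch.randn(3, requires_grad=True), torch.randn(2, 2, requires_grad=True)]
bucket = torch.empty(sum(p.numel() for p in params), dtype=params[0].dtype)

offset = 0
for p in params:
    end = offset + p.numel()
    bucket[offset:end].copy_(p.data.flatten())   # copy the current values in
    p.data = bucket[offset:end].view_as(p.data)  # the param now aliases the flat buffer
    offset = end

# A single in-place op on the bucket is visible through every parameter view
bucket.zero_()
assert all(p.data.abs().sum().item() == 0 for p in params)
```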
......@@ -3,11 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import collections
import io
from typing import Any, Callable, Dict, Optional
import torch
from torch._six import container_abcs
import torch.distributed as dist
......@@ -38,7 +38,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return values if isinstance(value, list) else tuple(values)
if isinstance(value, container_abcs.Mapping):
if isinstance(value, collections.abc.Mapping):
device_val: Dict[str, Any] = {}
for key, val in value.items():
device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
......@@ -89,6 +89,7 @@ class Bucket:
self.max_size = buffer.numel()
# Current status for this buffer
self.fill = 0
self.params_checked_in = 0
self.max_params_checked_in = 0  # attribute present for convenience purposes
self.destination = -1
......
......@@ -406,3 +406,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
return False
else:
return a == b
def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{p_a} {p_b}\n" + message
for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
......@@ -18,12 +18,13 @@ class DistributedDataParallel(Module[T_co]):
check_reduction: bool = ...
broadcast_bucket_size: float = ...
bucket_bytes_cap: float = ...
find_unused_parameters: bool = ...
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ...,
output_device: Optional[_device_t] = ..., dim: int = ...,
broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
check_reduction: bool = ...) -> None: ...
check_reduction: bool = ..., find_unused_parameters: bool = ...) -> None: ...
def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...
......
......@@ -23,7 +23,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
from fairscale.utils.testing import GPT2, check_same_model_params, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
def run_one_step(rank, world_size, backend, device, temp_file_name):
......@@ -133,67 +133,66 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
NUMBER_BATCHS = 5
INPUTS = 2
BATCH_SIZE = 32
def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
# The API should be exactly the same between the sharded and non-sharded variants, hence a generic closure
def closure(model, scaler, input_tensor, should_accumulate):
accumulate_steps = 3 if should_accumulate else 1
model.zero_grad()
def step():
if scaler is not None:
with torch.cuda.amp.autocast():
loss = model(input_tensor).abs().sum()
scaler.scale(loss).backward()
else:
loss = model(input_tensor).abs().sum()
loss.backward()
with model.no_sync() if should_accumulate else suppress():
for _ in range(accumulate_steps - 1):
step()
step()
def check_parity(amp: bool):
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
# Make sure that the model starts with a non-trainable parameter, so that we check that the buckets are
# properly reassigned when/if this changes
next(model.parameters()).requires_grad = False
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
sharded_ddp_model = ShardedDataParallel(
module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
ddp_scaler = TorchGradScaler() if amp else None
sharded_ddp_scaler = ShardedGradScaler() if amp else None
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
assert torch.allclose(
p, ddp_p, atol=1e-3
), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(
b, ddp_b, atol=1e-3
), f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params()
check_same_model_params(sharded_ddp_model, ddp_model)
# The models should stay the same in between the ranks
for i in range(10):
input_tensor = torch.rand((64, 2)).to(device)
# Typical training loop, check that we get the exact same results as DDP
for i in range(NUMBER_BATCHS):
input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)
def closure_ddp(input_tensor=input_tensor):
ddp_optimizer.zero_grad()
if ddp_scaler is not None:
with torch.cuda.amp.autocast():
ddp_loss = ddp_model(input_tensor).abs().sum()
ddp_scaler.scale(ddp_loss).backward()
else:
ddp_loss = ddp_model(input_tensor).abs().sum()
ddp_loss.backward()
return ddp_loss
return closure(ddp_model, ddp_scaler, input_tensor, accumulate)
def closure_sharded(input_tensor=input_tensor):
sharded_optimizer.zero_grad()
if sharded_ddp_scaler is not None:
with torch.cuda.amp.autocast():
sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
sharded_ddp_scaler.scale(sharded_loss).backward()
else:
sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
sharded_loss.backward()
return sharded_loss
return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)
# Step/scale both
if ddp_scaler is not None:
......@@ -210,13 +209,28 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
else:
sharded_optimizer.step(closure=closure_sharded)
check_same_model_params()
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
check_parity(amp=False)
# Flip the trainability of the first parameter back and forth
if i == 0 and change_train_graph:
next(sharded_ddp_model.parameters()).requires_grad = not next(
sharded_ddp_model.parameters()
).requires_grad
next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
# Catch a version of pytorch which would not support AMP
# Test all combinations: AMP, Accumulate, Change train graph
amp_tests = [False]
if hasattr(torch.cuda.amp, "autocast"):
check_parity(amp=True)
amp_tests.append(True)
for accumulate in [False, True]:
for change_train_graph in [False, True]:
for amp in amp_tests:
print(
f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp}"
)
check_parity(amp=amp, accumulate=accumulate, change_train_graph=change_train_graph)
dist.destroy_process_group()
......@@ -417,6 +431,8 @@ def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_na
model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.to(device) # in pytorch 1.5 syncBN switches to the default device/cpu
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
......
......@@ -11,7 +11,7 @@
import copy
from math import inf
import tempfile
from typing import Any, Type, cast
from typing import Any, Dict, Type, cast
import unittest
import numpy as np
......@@ -22,7 +22,7 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim
from fairscale.utils.testing import skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_py39_no_cuda, skip_if_single_gpu
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
......@@ -688,10 +688,6 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
model.zero_grad()
outputs = head(model(inputs))
def check_equal_models(message: str):
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), message
# pull the current state, broadcast it to all ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
......@@ -701,12 +697,16 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=1e6, momentum=0.0001)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
sharded_optimizer2.load_state_dict(state_dict2)
check_equal_models("parameters of the two identical models have diverged (before any steps)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (before any steps)"
)
# now take a step and check that parameters are equal
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after stepping)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after stepping)"
)
# save the state dict for one model only, then distribute to the other ranks
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
......@@ -722,7 +722,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# take a step
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after consolidating)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after consolidating)"
)
# save again for one rank, then distribute to the others
sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK) # all ranks
......@@ -737,7 +739,9 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# take a step
run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
check_equal_models("parameters of the two identical models have diverged (after reloading)")
check_same_model_params(
model_oss1, model_oss2, "parameters of the two identical models have diverged (after reloading)"
)
dist.destroy_process_group()
......@@ -768,7 +772,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
out_channels = 3
batch = 64
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer], change_train_graph: bool = False):
# Any model works. Add one different buffer per rank
trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
......@@ -777,14 +781,14 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
head = torch.nn.Linear(hidden, out_channels).to(device)
# Define a model to be trained by OSS
oss_model = torch.nn.Sequential(trunk, head)
oss_module = torch.nn.Sequential(trunk, head)
oss_trainable_params = [
{"params": trunk.parameters(), "lr": 1e-5},
{"params": head.parameters(), "lr": 1e-4},
]
optimizer_settings = {}
if isinstance(optim, torch.optim.SGD):
optimizer_settings: Dict[Any, Any] = {}
if isinstance(optimizer, torch.optim.SGD):
optimizer_settings["momentum"] = 0.9
sharded_optimizer = optim.OSS(
......@@ -795,7 +799,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
**optimizer_settings,
)
oss_ddp_model = DDP(module=oss_model, device_ids=[rank], broadcast_buffers=True)
oss_ddp_model = DDP(module=oss_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
# Define a model to be trained by normal pytorch + DDP
ddp_trunk = copy.deepcopy(trunk)
......@@ -807,19 +811,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
{"params": ddp_head.parameters(), "lr": 1e-4},
]
ddp_optimizer = optimizer(ddp_trainable_params, **optimizer_settings) # type: ignore
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True)
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
assert torch.allclose(
p, ddp_p, atol=1e-3
), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}"
for b, ddp_b in zip(oss_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(
b, ddp_b
), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}"
ddp_model = DDP(module=ddp_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
def check_step():
input_tensor = torch.rand((batch, in_channels)).to(device)
......@@ -843,13 +835,21 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
loss_ddp, loss_sharded_optim
), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
check_same_model_params(oss_ddp_model, ddp_model)
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params()
check_same_model_params(oss_ddp_model, ddp_model)
# The models should stay the same in between ddp and sharded optimizer
for i in range(5):
check_step()
check_same_model_params()
# Check that altering the trainable parameters does not cause DDP and OSS to diverge
if change_train_graph:
# Flip the first parameter from trainable to non-trainable and vice-versa
next(ddp_module.parameters()).requires_grad = not next(ddp_module.parameters()).requires_grad
next(oss_module.parameters()).requires_grad = not next(oss_module.parameters()).requires_grad
# sharded_optimizer.refresh_trainable()
# Check that the checkpoints are compatible
# - get states
......@@ -864,10 +864,10 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
# - run one step and check that the models are still the same
check_step()
check_same_model_params()
for opt in [torch.optim.SGD, torch.optim.Adam]:
check_optimizer_equivalence(opt)
for opt in [torch.optim.Adam, torch.optim.SGD]:
check_optimizer_equivalence(opt, change_train_graph=False)
check_optimizer_equivalence(opt, change_train_graph=True)
dist.destroy_process_group()
......