Unverified Commit 15512d9e authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

Add FullyShardedDataParallel (FSDP) (#413)

Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and [Google](https://arxiv.org/abs/2004.13336

) has shown that data parallel training can be made significantly more efficient by sharding the model parameters and optimizer state across data parallel workers. These ideas are encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper, which is a drop-in replacement for PyTorch's `DistributedDataParallel` (DDP) wrapper.

Compared to PyTorch DDP:
* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
* FSDP with `reshard_after_forward=False` has the same communication cost as PyTorch DDP and is similar to ZeRO-2
* FSDP with `reshard_after_forward=True` increases total communication by 50% and is similar to ZeRO-3:
    * all-gather parameters at start of forward pass and start of backward pass
    * reduce-scatter grads at end of backward pass
Co-authored-by: default avatarMin Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: default avatarSam Shleifer <sshleifer@gmail.com>
parent 279b8024
...@@ -3,4 +3,5 @@ ...@@ -3,4 +3,5 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .fully_sharded_data_parallel import FullyShardedDataParallel
from .sharded_ddp import ShardedDataParallel from .sharded_ddp import ShardedDataParallel
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
from enum import Enum, auto
import functools
from math import inf
from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.optim.utils import calc_grad_norm
from fairscale.utils.containers import (
apply_to_tensors,
pack_kwargs,
split_non_tensors,
unpack_kwargs,
unpack_non_tensors,
)
from fairscale.utils.parallel import chunk_and_pad, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
class TrainingState(Enum):
"""
Simple enum to indicate what state FSDP is in. Used for asserting
to make sure APIs are called in the correct state.
TODO (Min): It would be nice to capture the stepping state as well.
Maybe we can use the model.zero_grad() call, but not sure if it
is called if optim.zero_grad() is used instead.
It would be nice to have clear state transition be explicit like:
zero_grad -> fwd -> bwd -> optionally accum grad by repeating
fwd/bwd -> stepping -> loop back to zero_grad
"""
IDLE = auto()
FORWARD = auto()
BACKWARD = auto()
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
.. _DeepSpeed: https://www.deepspeed.ai/
Usage::
sharded_module = FullyShardedDataParallel(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
x = sharded_module(x, y=3, z=torch.Tensor([1]))
loss = x.sum()
loss.backward()
optim.step()
It is also possible to shard individual layers separately and have an outer
wrapper handle any leftover parameters. This can be helpful to further
reduce memory usage and to improve training speed by distributing the
unsharding (all-gather) across the forward pass. For example::
sharded_model = FullyShardedDataParallel(
nn.Sequential(
nn.Linear(5, 100),
FullyShardedDataParallel(nn.Linear(100, 100)),
FullyShardedDataParallel(nn.Linear(100, 100)),
nn.Linear(100, 5),
)
)
Args:
module (nn.Module): module to checkpoint
process_group (Optional): process group for sharding
reshard_after_forward (bool, Optional): if ``True``, reshard parameters
after the forward pass. This saves memory but slows training. This
is only relevant when resharding individual layers.
mixed_precision (bool, Optional): if ``True``, inputs, activations and
gradients will be kept in FP16; computation and communication will
occur in FP16; and a (sharded) master copy of the model weights will
be maintained in FP32.
fp32_reduce_scatter (bool, Optional): if ``True``, then reduce-scatter
gradients in FP32. This is only relevant when *``mixed_precision``*
is ``True``.
flatten_parameters (bool, Optional): if ``True``, flatten parameters
into a single contiguous tensor, which improves training speed.
cpu_offload (bool, Optional): if ``True``, offload FP32 params to CPU.
This is only relevant when *``mixed_precision``* is ``True``.
compute_dtype (torch.dtype, Optional): dtype for full parameters for
computation. This defaults to ``torch.float32`` unless
*``mixed_precision``* is set, in which case it defaults to
``torch.float16``.
move_grads_to_cpu (bool, Optional): move gradient shard to CPU after
reduction. This is useful when combined with CPU-based optimizers.
It defaults to the value of *``cpu_offload``*.
bucket_cap_mb (int, Optional): FSDP will bucket parameters so that
gradient reduction can potentially overlap with backward
computation. bucket_cap_mb controls the bucket size in MegaBytes
(MB). Buckets are sub-divided based on world_size, so the max shard
size is roughly ``bucket_cap_mb / world_size``. Values <= 0 disable
bucketing. Default: 25.
"""
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
flatten_parameters: bool = True,
cpu_offload: bool = False,
compute_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
):
super().__init__()
self.process_group = process_group or dist.new_group()
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
self.reshard_after_forward = reshard_after_forward
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
self.cpu_offload = cpu_offload
self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")
compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device
validate_process_group(compute_device, self.process_group)
# Only handle params which are not already sharded. This enables
# sharding individual layers of a Module, with an outer wrapper to
# shard any leftover parameters.
params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
if self.flatten_parameters and len(params) > 0:
self.module: nn.Module = FlattenParamsWrapper(module, param_list=params)
del module # free original module in case it helps garbage collection
self.params = [self.module.flat_param]
else:
self.module = module
self.params = params
# Shard module parameters in place
self._shard_parameters_()
# Make sure all parameters are sharded.
for n, p in self.named_parameters():
assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}"
self._reset_lazy_init()
# Flag to indicate if we require gradient reduction in the backward
# pass. This will be False when inside the no_sync context manager.
self.require_backward_grad_sync: bool = True
self.training_state = TrainingState.IDLE
@torch.no_grad()
def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
"""Move all buffers to the specified device and dtype, recursively."""
cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
self.apply(cast_fn)
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None] """
return [p for p in self.parameters() if p.grad is not None]
@torch.no_grad()
def clip_grad_norm_(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
# filter_params_fn: Callable[[Any], Any] = None,
) -> torch.Tensor:
"""
Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Arguments:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
.. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters
.. warning: This needs to be called on all ranks, since synchronization primitives will be used
"""
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
assert self.training_state == TrainingState.IDLE
max_norm = float(max_norm)
norm_type = float(norm_type)
params_with_grad = self.params_with_grad
if not self.children_share_process_group:
raise NotImplementedError(
"clip_grad_norm requires that all params share one process group. clip_grad_by_value_ should work"
)
# Computes the max norm for this shard's gradients and sync's across workers
local_norm = calc_grad_norm(params_with_grad, norm_type).cuda()
if norm_type == inf:
total_norm = local_norm
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
else:
total_norm = local_norm ** norm_type
dist.all_reduce(total_norm, group=self.process_group)
total_norm = total_norm ** (1.0 / norm_type)
if self.move_grads_to_cpu:
total_norm = total_norm.cpu()
# Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq)
clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
if clip_coef < 1:
# multiply by clip_coef
for p in params_with_grad:
p.grad.detach().mul_(clip_coef.to(p.grad.device)) # type: ignore
return total_norm
@torch.no_grad()
def _shard_parameters_(self) -> None:
"""
At initialization we wrap a module with full parameters and shard the
parameters in-place. Sharding is implemented by viewing each parameter
as a 1D Tensor and retaining only a single slice, where the slice size
is determined by the number of data parallel workers.
Wrapping modules with many small parameters (or with a very large data
parallel world size) will result in many small parameter shards and slow
performance. In this case it's better to set *``flatten_parameters``* to
``True``, so that all of the small parameters in the module are combined
into a single contiguous Tensor and sharded once.
After this initial sharding is complete, the user can initialize a
``torch.optim.Optimizer`` in the usual way, i.e.::
.. code-block:: python
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
The optimizer will see only a single slice of parameters and will thus
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
"""
for p in self.params:
assert not hasattr(p, "_is_sharded")
assert p.is_floating_point()
if self.mixed_precision:
assert p.dtype == torch.float32
# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1
p._orig_size = p.data.size()
if not p._is_sharded:
continue
p._is_sharded = True
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(p.data).chunk(self.world_size))
while len(chunks) < self.world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
assert num_to_pad >= 0, num_to_pad
# Replace p.data with the relevant shard.
orig_data = p.data
p.data = chunks[self.rank].clone() # clone since we free storage below
if num_to_pad > 0:
p.data = F.pad(p.data, [0, num_to_pad])
free_storage_(orig_data)
def extra_repr(self) -> str:
return (
f"rank={self.rank}, world_size={self.world_size}, "
f"reshard_after_forward={self.reshard_after_forward}, "
f"mixed_precision={self.mixed_precision}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"flatten_parameters={self.flatten_parameters}, "
f"cpu_offload={self.cpu_offload}, "
f"compute_dtype={self.compute_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}"
)
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.module, name)
def __getstate__(self) -> Dict[str, str]:
"""Serialize the state of the current FullyShardedDataParallel instance.
Some properties are not serializable (e.g., process groups, streams), so
we remove them and try to reconstruct them in :func:`__setstate__`.
"""
state = copy.copy(self.__dict__)
state["is_sharded"] = [p._is_sharded for p in self.params]
state["orig_sizes"] = [p._orig_size for p in self.params]
if state["process_group"] is not None:
state["process_group"] = "MISSING" # process_group isn't pickleable
self._reset_lazy_init()
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Intercept state setting and perform needed changes on params."""
super().__setstate__(state)
def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
assert isinstance(p, Parameter)
p.data = p.data.clone() # move tensors out of shared memory
p._is_sharded = is_sharded
p._orig_size = size
return p
self.params = [
fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
]
del self.is_sharded
del self.orig_sizes
self._reset_lazy_init()
# TODO (Min): figuring out how to do typing for this overloaded function.
def state_dict(self, *args, **kwargs): # type: ignore
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32
"""
torch.cuda.synchronize()
self._lazy_init()
self._rebuild_full_params()
self._all_buffers_to(dtype=torch.float32) # Buffers dtype stays consistent with parameters.
state_dict = self.module.state_dict(*args, **kwargs)
# We don't free the params after generating the state dict, since
# freeing is done in-place (via the Storage) and would corrupt the
# returned state dict. However, we need to maintain the invariant that
# p.data corresponds to the FP32 param shard, so we do that here.
self._use_fp32_param_shard()
self._all_buffers_to(dtype=self.compute_dtype)
return state_dict
# TODO (Min): figuring out how to do typing for this overloaded function.
def local_state_dict(self, *args, **kwargs): # type: ignore
"""
Returns the local (sharded) state of the module. Parameters are sharded,
so the resulting state_dict can only be loaded after the Module has been
wrapped with FullyShardedDataParallel.
"""
torch.cuda.synchronize()
self._lazy_init()
if self.flatten_parameters:
return self.module.flat_state_dict(*args, **kwargs) # type: ignore
else:
return self.module.state_dict(*args, **kwargs)
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""Load a whole (unsharded) state_dict."""
torch.cuda.synchronize()
self._lazy_init()
self._rebuild_full_params()
output = self.module.load_state_dict(state_dict, strict)
self._free_full_params()
return output
def load_local_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""Load a local (sharded) state_dict."""
torch.cuda.synchronize()
return self.module.load_state_dict(state_dict, strict)
@contextlib.contextmanager
def no_sync(self) -> Generator:
"""
A context manager to disable gradient synchronizations across DDP
processes. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass exiting the context.
"""
self._lazy_init()
assert self._is_root, "no_sync on inner FSDP is not supported"
self.assert_state(TrainingState.IDLE)
# This instance may wrap other FullyShardedDataParallel instances and we
# need to set all of them to accumulate gradients.
old_flags = []
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
old_flags.append((m, m.require_backward_grad_sync))
m.require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
m.require_backward_grad_sync = old_flag
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
before the first forward pass."""
# Initialize param attributes lazily, in case the param's dtype or
# device changes after __init__.
for p in self.params:
self._init_param_attributes(p)
# Initialize _is_root and setup streams. These steps would ideally
# happen in __init__, but _is_root can only be determined after the
# entire model hierarchy is setup, thus we run it lazily.
if self._is_root is None:
self._set_is_root()
self._setup_streams()
if self.cpu_offload: # Buffers stay on GPU, and don't get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype)
else:
self._all_buffers_to(dtype=self.compute_dtype)
if self._is_root:
# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
self.reshard_after_forward = False
# Due to the use of streams, we need to make sure the previous
# ``optim.step()`` is done before we all-gather parameters.
self._wait_for_previous_optim_step()
@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:
"""
We manage several attributes on each Parameter instance. The first two
are set by :func:`_shard_parameters_`:
``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
if the Parameter is intentionally not sharded (in which case we
will all-reduce grads for this param).
``_orig_size``: the size of the original Parameter (before sharding)
The remaining attributes are set here:
``_fp32_shard``: a single shard of the parameters in full precision
(typically FP32, but this is dependent on the dtype of the model
as it's passed in by the user). This can be on CPU or GPU
depending on the value of *``cpu_offload``*.
``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be
a single shard of the parameters in FP16, used for all-gather.
``_full_param_padded``: the full weight (padded to be evenly
divisible by ``world_size``), used for computation in the
forward and backward pass. This will be resized in place and
only materialized (via all-gather) as needed.
"""
assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size")
if hasattr(p, "_fp32_shard"):
return
# Compute device defaults to CUDA when *cpu_offload* is enabled, or the
# param's current device otherwise (could be CPU).
compute_device = torch.device("cuda") if self.cpu_offload else p.device
# A single shard of the parameters in full precision.
p._fp32_shard = p.data
if self.mixed_precision:
assert p._fp32_shard.dtype == torch.float32
if self.cpu_offload:
assert p._fp32_shard.device == torch.device("cpu")
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard
# In mixed precision mode, we maintain a reduced precision
# (typically FP16) parameter shard on compute_device for performing
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype)
free_storage_(p._fp16_shard)
else:
p._fp16_shard = None # use _fp32_shard
# We also maintain a full-sized parameter of type self.compute_dtype
# (FP16 for mixed_precision or FP32 otherwise). We resize the
# storage to size 0 at init (here) and only materialize as needed. The
# storage may contain padding elements so that it is evenly divisible by
# world_size, although these padding elements will be removed before the
# relevant computation.
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype
)
free_storage_(p._full_param_padded)
if self.move_grads_to_cpu:
# We can optionally move the grad shard to CPU during the backward
# pass. In this case, it's important to pre-allocate the CPU grad
# shard in pinned memory so that we can do a non-blocking transfer.
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
def _set_is_root(self) -> None:
"""If ``True``, implies that no other :class:`FullyShardedDataParallel`
instance wraps this one. Called once by :func:`_lazy_init`.
Also sets self.children_share_process_group = True if all child instances share the same process group.
If some child instances use a different process group, self.clip_grad_norm_ will raise an error.
"""
if self._is_root is not None:
return
# No FullyShardedDataParallel instance wraps this, else _is_root would be set to False
self._is_root = True
# As the root, we now set all children instances to False.
self.children_share_process_group = True
for n, m in self.named_modules():
if n != "" and isinstance(m, FullyShardedDataParallel):
assert m._is_root is None
m._is_root = False
if m.process_group != self.process_group:
self.children_share_process_group = False
def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
if len(self._streams) > 0 or not self._is_root:
return
# Stream to move main FP32 params (may be on CPU) to FP16 for forward.
self._streams["fp32_to_fp16"] = torch.cuda.Stream()
# Stream for all-gathering parameters.
self._streams["all_gather"] = torch.cuda.Stream()
# Stream for overlapping grad reduction with the backward pass.
self._streams["post_backward"] = torch.cuda.Stream()
# Helper for bucketing reduce-scatter ops. This is also shared with
# children instances to improve bucket utilization.
self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
# We share streams with all children instances, which allows them to
# overlap transfers across the forward pass without synchronizing with
# the default stream.
for n, m in self.named_modules():
if n != "" and isinstance(m, FullyShardedDataParallel):
m._streams = self._streams
m._reducer = self._reducer
def _wait_for_previous_optim_step(self) -> None:
"""
The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
instance) needs to synchronize with the default stream to ensure the
previous optimizer step is done.
"""
if self.mixed_precision:
self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
else:
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._lazy_init()
# Start of a forward pass.
self.training_state = TrainingState.FORWARD
if self.mixed_precision:
args, kwargs = cast_inputs_to_fp16(*args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self._register_post_backward_hooks()
outputs = self.module(*args, **kwargs)
if self.reshard_after_forward:
self._free_full_params()
# Switch to main FP32 param shard. We maintain this invariant throughout
# the code, i.e., ``p.data == p._fp32_shard`` after each function. This
# also ensures that after the first forward, the optimizer state will be
# initialized with the correct dtype and (sharded) size, since optimizer
# state is typically initialized lazily in ``optim.step()``.
self._use_fp32_param_shard()
# Register pre-backward hooks to all-gather the params for the backward
# pass (if needed).
outputs = self._register_pre_backward_hooks(outputs)
# Done with a forward pass.
self.training_state = TrainingState.IDLE
return outputs
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward."""
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled
pre_backward_hook_has_run = [False]
def _pre_backward_hook(*unused: Any) -> None:
if pre_backward_hook_has_run[0]:
return # only run once
pre_backward_hook_has_run[0] = True
# Start of a backward pass.
self.training_state = TrainingState.BACKWARD
# All-gather full parameters.
if self.reshard_after_forward:
self._rebuild_full_params()
else:
self._use_full_params()
# Make sure p.grad has the correct size/device (or set it to None).
self._prep_grads_for_backward()
def _register_hook(t: torch.Tensor) -> torch.Tensor:
t.register_hook(_pre_backward_hook)
return t
# Attach hooks to Tensor outputs.
outputs = apply_to_tensors(_register_hook, outputs)
return outputs
def _register_post_backward_hooks(self) -> None:
"""Register backward hooks to reshard params and reduce-scatter grads."""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
self._post_backward_callback_queued = False
for p in self.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
p._shard_bwd_hook[1].remove() # remove existing handle
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
p._shard_bwd_hook = (grad_acc, handle)
@torch.no_grad()
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
"""
At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will replace
``param.grad`` with a single shard of the summed gradient across all
GPUs. This shard will align with the current GPU rank. For example::
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
param.grad (GPU #1): [5, 6, 7, 8]
after reduce_scatter:
param.grad (GPU #0): [6, 8] # 1+5, 2+6
param.grad (GPU #1): [10, 12] # 3+7, 4+8
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
alignment is created by :func:`_shard_parameters_`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
self.assert_state(TrainingState.BACKWARD)
if param.grad is None:
return
if param.grad.requires_grad:
raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
# Free full params and switch to FP32 shard after backward.
self._free_full_params([param])
self._use_fp32_param_shard([param])
if self.mixed_precision:
# This is a no-op if reshard_after_forward is True, since we already
# free the param shard when rebuilding the full params in the
# pre_backward_hook.
self._free_fp16_param_shard([param])
# Enqueue a callback at the end of the backward pass to ensure that all
# post-backward work has finished. We only need one callback and it only
# needs to be called from the outer-most (root) instance.
if self._is_root and not self._post_backward_callback_queued:
self._post_backward_callback_queued = True
Variable._execution_engine.queue_callback(self._wait_for_post_backward)
if not self.require_backward_grad_sync:
return
# Wait for all work in the current stream to finish, then start the
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data
if self.mixed_precision and self.fp32_reduce_scatter:
# Cast grad to FP32.
param.grad.data = param.grad.data.to(param.dtype)
if self.world_size > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.world_size)
callback_fn = functools.partial(self._post_reduction_hook, param)
if param._is_sharded:
assert param._is_sharded
assert self._reducer is not None
grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
else:
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
assert self.world_size == 1
callback_fn(param.grad.data)
# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
# further reuse by the main stream while the div/reduce_scatter/copy
# are underway in the post_backward stream. See:
# github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
orig_grad_data.record_stream(self._streams["post_backward"])
def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
"""Hook to call on each param after the reduce-scatter."""
assert torch.cuda.current_stream() == self._streams["post_backward"]
assert param.grad is not None
self.assert_state(TrainingState.BACKWARD)
param.grad.data = reduced_grad
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision:
param.grad.data = param.grad.data.to(dtype=param.data.dtype)
# Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU.
if self.move_grads_to_cpu:
param._cpu_grad.copy_(param.grad.data, non_blocking=True)
param.grad.data = param._cpu_grad
# Don't let this memory get reused until after the transfers.
reduced_grad.record_stream(torch.cuda.current_stream())
@torch.no_grad()
def _wait_for_post_backward(self) -> None:
"""Wait for post-backward work to finish. Only called on root instance."""
assert self._is_root
self.assert_state(TrainingState.BACKWARD)
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self._streams["post_backward"]):
assert self._reducer is not None
self._reducer.flush()
torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
if self.move_grads_to_cpu:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
# A backward pass is done.
self.training_state = TrainingState.IDLE
@torch.no_grad()
def _rebuild_full_params(self) -> None:
"""Gather all shards of params."""
with torch.cuda.stream(self._streams["all_gather"]):
if self.mixed_precision:
self._cast_fp32_param_shards_to_fp16()
for p in self.params:
if not p._is_sharded:
if self.mixed_precision:
p.data = p._fp16_shard
continue
p_size = p._full_param_padded.size()
if p._full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size)
assert p_size.numel() % self.world_size == 0
if p._is_sharded:
# Fill p._full_param_padded with (p.data for each shard in self.world_size)
chunks = list(p._full_param_padded.chunk(self.world_size))
dist.all_gather(chunks, p.data, group=self.process_group)
else:
p._full_param_padded.copy_(torch.flatten(p.data), non_blocking=True)
p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)
if self.mixed_precision:
self._free_fp16_param_shard([p])
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
@torch.no_grad()
def _use_full_params(self) -> None:
for p in self.params:
if not p._is_sharded:
if self.mixed_precision:
assert p._fp16_shard.storage().size() != 0
p.data = p._fp16_shard
else:
assert p._full_param_padded.storage().size() != 0
p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)
@torch.no_grad()
def _prep_grads_for_backward(self) -> None:
"""Make sure p.grad has the correct size/device, otherwise set it to None."""
for p in self.params:
if p.grad is not None and (p.grad.size() != p._orig_size or p.grad.device != p.data.device):
p.grad = None
@torch.no_grad()
def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
"""Free up storage for full parameters."""
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(self._streams["all_gather"]):
for p in params:
if not p._is_sharded:
if self.mixed_precision:
self._free_fp16_param_shard([p])
continue
# There may be external references to the Tensor Storage that we
# can't modify, such as references that are created by
# ctx.save_for_backward in the forward pass. Thus when we
# unshard parameters, we should reuse the original Tensor
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
p._full_param_padded.record_stream(current_stream)
free_storage_(p._full_param_padded)
@torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
p.data = p._fp32_shard
@torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
"""Cast FP32 param shard to FP16 for a list of params."""
if params is None:
params = self.params
with torch.cuda.stream(self._streams["fp32_to_fp16"]):
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_(
# If cpu_offload is True, this will be non-blocking because
# _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
@torch.no_grad()
def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Free storage for FP16 shards for a list of params."""
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
for p in params:
if p._fp16_shard is not None:
# _fp16_shard is allocated in _fp32_to_fp16_stream, so we can't
# free it until the work in the current stream completes.
p._fp16_shard.record_stream(current_stream)
free_storage_(p._fp16_shard)
def assert_state(self, state: TrainingState) -> None:
"""Assert we are in the given state."""
assert (
self.training_state == state
), f"expected to be in state {state} but current state is {self.training_state}"
@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
"""
Cast any Tensors in *args or **kwargs to FP16.
Doesn't currently support Tensors nested inside containers (e.g., dict).
"""
kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs)
flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
return args, kwargs
def cast_buffers_(
module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
"""Cast all of module.named_buffers to device, dtype."""
# if buffers are already on the right device and/or dtype this is just python loop cost
for key, buf in module.named_buffers(recurse=False):
if buf is not None:
setattr(module, key, buf.to(dtype=dtype, device=device))
def free_storage_(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
# is the sole occupant of the Storage.
assert data.storage_offset() == 0
assert data.storage().size() == data.numel()
data.storage().resize_(0)
@torch.no_grad()
def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
"""Allocate storage for a tensor."""
if data.storage().size() == size.numel(): # no need to reallocate
return
assert data.storage().size() == 0
data.storage().resize_(size.numel())
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
import functools import functools
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -73,6 +74,23 @@ def set_rng_state(state: Dict[str, Any]) -> None: ...@@ -73,6 +74,23 @@ def set_rng_state(state: Dict[str, Any]) -> None:
torch.cuda.set_rng_state(state["cuda_rng_state"]) torch.cuda.set_rng_state(state["cuda_rng_state"])
def is_autocast_enabled() -> bool:
"""Similar to torch.is_autocast_enabled, but compatible with torch 1.5.1"""
if hasattr(torch, "is_autocast_enabled"):
return torch.is_autocast_enabled()
return False
@contextmanager
def autocast(enabled: bool) -> Generator:
"""Similar to torch.cuda.amp.autocast, but compatible with torch 1.5.1"""
if enabled:
with torch.cuda.amp.autocast(enabled):
yield
else:
yield
class CheckpointFunction(torch.autograd.Function): class CheckpointFunction(torch.autograd.Function):
"""Similar to the torch version, but support non-Tensor outputs. """Similar to the torch version, but support non-Tensor outputs.
...@@ -96,13 +114,13 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -96,13 +114,13 @@ class CheckpointFunction(torch.autograd.Function):
ctx.run_function = run_function ctx.run_function = run_function
ctx.kwarg_keys = kwarg_keys ctx.kwarg_keys = kwarg_keys
ctx.fwd_rng_state = get_rng_state() ctx.fwd_rng_state = get_rng_state()
ctx.had_autocast_in_fwd = is_autocast_enabled()
tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
if parent_ctx_dict["offload"]: if parent_ctx_dict["offload"]:
ctx.fwd_device = tuple(x.device for x in tensor_inputs) ctx.fwd_device = tuple(x.device for x in tensor_inputs)
ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
tensor_inputs = tuple(x.cpu() for x in tensor_inputs) tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
else: else:
ctx.fwd_device, ctx.grad_requirements = None, None ctx.fwd_device, ctx.grad_requirements = None, None
...@@ -142,10 +160,11 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -142,10 +160,11 @@ class CheckpointFunction(torch.autograd.Function):
# Set the states to what it used to be before the forward pass. # Set the states to what it used to be before the forward pass.
set_rng_state(ctx.fwd_rng_state) set_rng_state(ctx.fwd_rng_state)
with torch.enable_grad(): with torch.enable_grad(), autocast(ctx.had_autocast_in_fwd):
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs) tensor_outputs, _ = split_non_tensors(outputs)
# Set the states back to what it was at the start of this function. # Set the states back to what it was at the start of this function.
set_rng_state(bwd_rng_state) set_rng_state(bwd_rng_state)
......
...@@ -2,12 +2,15 @@ ...@@ -2,12 +2,15 @@
# Licensed under the MIT License. # Licensed under the MIT License.
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
import torch.nn as nn import torch.nn as nn
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
class FlattenParamsWrapper(nn.Module): class FlattenParamsWrapper(nn.Module):
""" """
...@@ -127,21 +130,23 @@ class FlattenParamsWrapper(nn.Module): ...@@ -127,21 +130,23 @@ class FlattenParamsWrapper(nn.Module):
except AttributeError: except AttributeError:
return getattr(self.module, name) # fallback to wrapped module return getattr(self.module, name) # fallback to wrapped module
def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]": # type: ignore def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]": # type: ignore
"""Return an unflattened state_dict.""" """Return an unflattened state_dict."""
with self.unflatten_params(): with self.unflatten_params():
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) return self.module.state_dict(*args, **kwargs)
def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Return the flattened state_dict.""" """Return the flattened state_dict."""
return super().state_dict(*args, **kwargs) return super().state_dict(*args, **kwargs)
def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None: def load_state_dict(
self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True
) -> NamedTuple:
if "flat_param" in state_dict: if "flat_param" in state_dict:
super().load_state_dict(state_dict, strict=True) return super().load_state_dict(state_dict, strict=strict)
else: else:
with self.unflatten_params(): with self.unflatten_params():
return self.module.load_state_dict(state_dict, *args, **kwargs) return self.module.load_state_dict(state_dict, strict)
def forward(self, *inputs: Any, **kwinputs: Any) -> Any: def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
self._unflatten_params_as_views() self._unflatten_params_as_views()
......
...@@ -15,7 +15,7 @@ import torch.distributed as dist ...@@ -15,7 +15,7 @@ import torch.distributed as dist
from torch.nn import Parameter from torch.nn import Parameter
from torch.optim import SGD, Optimizer from torch.optim import SGD, Optimizer
from .utils import broadcast_object, recursive_copy_to_device from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
__all__ = ["OSS"] __all__ = ["OSS"]
...@@ -284,18 +284,14 @@ class OSS(Optimizer): ...@@ -284,18 +284,14 @@ class OSS(Optimizer):
# https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54 # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params
local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
# Compute the norm on this grad set, # Compute the norm on this grad set,
# then sync all the norms from all ranks # then sync all the norms from all ranks
if norm_type == inf: if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(self._default_device) for p in local_params) total_norm = local_norm
# all reduce over data parallel and model parallel workers # all reduce over data parallel and model parallel workers
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD) dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
else: else:
local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._default_device) for p in local_params]), # type: ignore
p=norm_type,
)
# local norm result can be accumulated with the remote ones if put to the right power # local norm result can be accumulated with the remote ones if put to the right power
# n_i = sum_rank(a^p)^1/p # n_i = sum_rank(a^p)^1/p
# -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
import collections import collections
import io import io
from typing import Any, Callable, Dict, Optional from math import inf
from typing import Any, Callable, Dict, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -102,3 +103,22 @@ class Bucket: ...@@ -102,3 +103,22 @@ class Bucket:
def full(self) -> bool: def full(self) -> bool:
""" is the bucket full ? """ """ is the bucket full ? """
return self.max_params_checked_in == self.params_checked_in return self.max_params_checked_in == self.params_checked_in
def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
r"""Calculate gradient norm of an iterable of parameters.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda par: par.grad is not None, parameters))
if len(parameters) == 0:
return torch.tensor(0.0)
p = float(p)
if p == inf:
local_norm = max(par.grad.detach().abs().max() for par in parameters) # type: ignore
else:
local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p) for par in parameters]), p) # type: ignore
return local_norm
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Useful functions for parallel training."""
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn.functional as F
def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
"""Chunk a given Tensor into num_chunks parts and add any necessary padding."""
chunks = list(torch.flatten(tensor).chunk(num_chunks))
# torch.chunk may return fewer than num_chunks chunks, pad accordingly.
num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
if num_pad_for_partial_chunk > 0:
chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
if len(chunks) < num_chunks:
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
return chunks
def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
"""Do a quick test in case user called FSDP without calling torch.cuda.set_device()
correctly. This can easily happen in cpu_offload case where the model resides on
the CPU.
"""
if not hasattr(process_group, "allgather"):
# Likely a dummy pg for unit test, skip checking.
return
world_size = process_group.size()
if "cuda" in str(device):
input_tensor = torch.ones(1).to(device)
output = list(torch.zeros(world_size).to(device).chunk(world_size))
dist.all_gather(output, input_tensor, group=process_group)
assert torch.cat(output).sum() == float(world_size), (
f"found {torch.cat(output).sum()} devices in process group but "
f"world_size={world_size}. Check torch.cuda.set_device is called properly"
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Bucket:
def __init__(self, data: Tensor, group: ProcessGroup):
self.data = data
self.group = group
self.offset = 0
self.callbacks: List[Callable] = []
self.output_shard = torch.zeros_like(data[0])
def flush(self) -> None:
if self.offset == 0:
assert len(self.callbacks) == 0
return
# reduce-scatter bucket
dist.reduce_scatter(
self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reuse input bucket but allocate a fresh output shard
self.data[:, : self.offset].zero_()
self.offset = 0
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.data[0])
class ReduceScatterBucketer:
"""
Helper for bucketing multiple reduce-scatter operations on small tensors
into larger reduce-scatter ops to improve communication efficiency.
Usage::
bucketer = ReduceScatterBucketer()
bucketer.reduce_scatter_async(
small_tensors, callback_fn=lambda result: print("small")
)
bucketer.reduce_scatter_async(
big_tensors, callback_fn=lambda result: print("big")
)
bucketer.reduce_scatter_async(
more_small_tensors, callback_fn=lambda result: print("small2")
)
bucketer.flush() # callbacks only guaranteed to be called after flush()
# Example output (note that it is out of order, due to bucketing):
# big
# small
# small2
Args:
bucket_cap_mb (int, Optional): bucket size for communicating. Buckets
are sub-divided based on world_size. Values <= 0 disable bucketing.
"""
def __init__(self, bucket_cap_mb: int = 25):
self.bucket_cap_mb = bucket_cap_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
@torch.no_grad()
def reduce_scatter_async(
self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
) -> None:
"""
Reduce-scatter a list of tensors asynchronously, so smaller reductions
can be bucketed together. The given callback (``callback_fn``) will be
called with the reduced result at some later time. Call ``flush()`` to
force all queued ops and callbacks to be executed.
Note that large inputs will be reduced immediately, and this function
may also flush the relevant bucket to make room for ``input_list``.
Args:
input_list (List[Tensor]): list of tensors to reduce-scatter. List
should contain ``group.size()`` tensors and each tensor should
have identical shape, dtype and device.
group (ProcessGroup): process group for reduction
callback_fn (Callable, Optional): callback function to call after
the reduction executes. Function will be called with a single
argument corresponding to the reduced result.
"""
world_size = group.size()
assert (
len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
first_input = input_list[0]
first_input_size = first_input.numel()
bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
if first_input_size > bucket_shard_size:
# input is too big to fit in the bucket, reduce-scatter directly
output = torch.zeros_like(input_list[0])
dist.reduce_scatter(output, input_list, group=group)
if callback_fn is not None:
callback_fn(output)
return
bucket = self._get_bucket(first_input, group)
if first_input_size > bucket.data.size(1) - bucket.offset:
# not enough space remaining in bucket, flush it now
bucket.flush()
# copy data from input_list into bucket
stacked_input = torch.stack(input_list).view(world_size, first_input_size)
offset = bucket.offset
bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
bucket.offset += first_input_size
# callback will be given the reduced result
if callback_fn is not None:
result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
bucket.callbacks.append(functools.partial(callback_fn, result_view))
@torch.no_grad()
def flush(self) -> None:
"""Reduce-scatter any partial buckets."""
for bucket in self.buckets.values():
bucket.flush()
@functools.lru_cache()
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
if self.bucket_cap_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_cap_mb * MB / element_size
return int(bucket_size // num_shards)
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
world_size = group.size()
shard_size = self._get_shard_size(tensor.element_size(), world_size)
data = tensor.new_zeros((world_size, shard_size))
self.buckets[key] = Bucket(data, group)
return self.buckets[key]
...@@ -33,11 +33,12 @@ import os ...@@ -33,11 +33,12 @@ import os
import random import random
import sys import sys
import tempfile import tempfile
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import numpy import numpy
import pytest import pytest
import torch import torch
from torch import Tensor
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import rpc from torch.distributed import rpc
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -46,6 +47,11 @@ import torch.nn as nn ...@@ -46,6 +47,11 @@ import torch.nn as nn
from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
if TYPE_CHECKING:
Base = nn.Module[Tensor]
else:
Base = nn.Module
skip_if_no_cuda = pytest.mark.skipif( skip_if_no_cuda = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required" not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required"
) )
...@@ -75,12 +81,12 @@ if torch.cuda.is_available(): ...@@ -75,12 +81,12 @@ if torch.cuda.is_available():
_, filename_mpi = tempfile.mkstemp() _, filename_mpi = tempfile.mkstemp()
class IdentityLayer(torch.nn.Module): class IdentityLayer(Base):
def __init__(self, size: int, scale: float = 1.0) -> None: def __init__(self, size: int, scale: float = 1.0) -> None:
super(IdentityLayer, self).__init__() super(IdentityLayer, self).__init__()
self.weight = torch.nn.Parameter(scale * torch.randn(size)) self.weight = torch.nn.Parameter(scale * torch.randn(size))
def forward(self, *_: Any, **__: Any) -> Any: def forward(self, *_: Any, **__: Any) -> Tensor:
return self.weight return self.weight
...@@ -103,7 +109,7 @@ def torch_version() -> Tuple[int, ...]: ...@@ -103,7 +109,7 @@ def torch_version() -> Tuple[int, ...]:
# Assuming that we're interested in the second usecase more than the first, # Assuming that we're interested in the second usecase more than the first,
# return the pre-release or dev numbering # return the pre-release or dev numbering
logging.warning(f"Pytorch pre-relase version {torch.__version__} - assuming intent to test it") logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it")
numbering[2] = "0" numbering[2] = "0"
return tuple(int(n) for n in numbering) return tuple(int(n) for n in numbering)
...@@ -301,7 +307,7 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable: ...@@ -301,7 +307,7 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
return prepare_test return prepare_test
class _Block(nn.Module): class _Block(Base):
def __init__(self, embed_dim: int, num_heads: int) -> None: def __init__(self, embed_dim: int, num_heads: int) -> None:
super().__init__() super().__init__()
self.ln_1 = nn.LayerNorm(embed_dim) self.ln_1 = nn.LayerNorm(embed_dim)
...@@ -309,7 +315,7 @@ class _Block(nn.Module): ...@@ -309,7 +315,7 @@ class _Block(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore
self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),) self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
def forward(self, *inputs: Any, **kwargs: Any) -> Any: def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
x = inputs[0] x = inputs[0]
attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype) attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
attn_mask = torch.triu(attn_mask, diagonal=1) attn_mask = torch.triu(attn_mask, diagonal=1)
...@@ -322,7 +328,7 @@ class _Block(nn.Module): ...@@ -322,7 +328,7 @@ class _Block(nn.Module):
return x return x
class GPT2(nn.Module): class GPT2(Base):
""" """
GPT2 pytorch implementation, for testing purposes in the image-GPT context GPT2 pytorch implementation, for testing purposes in the image-GPT context
Credits: https://github.com/teddykoker/image-gpt""" Credits: https://github.com/teddykoker/image-gpt"""
...@@ -349,7 +355,7 @@ class GPT2(nn.Module): ...@@ -349,7 +355,7 @@ class GPT2(nn.Module):
self.head = nn.Linear(embed_dim, num_vocab, bias=False) self.head = nn.Linear(embed_dim, num_vocab, bias=False)
self.clf_head = nn.Linear(embed_dim, num_classes) self.clf_head = nn.Linear(embed_dim, num_classes)
def forward(self, x: torch.Tensor, classify=False) -> Any: # type: ignore def forward(self, x: Tensor, classify: bool = False) -> Any: # type: ignore
""" """
Expect input as shape [sequence len, batch] Expect input as shape [sequence len, batch]
If classify, return classification logits If classify, return classification logits
...@@ -451,3 +457,89 @@ def check_same_models_across_ranks( ...@@ -451,3 +457,89 @@ def check_same_models_across_ranks(
assert not params_should_be_equal or torch.all( assert not params_should_be_equal or torch.all(
torch.eq(receptacle[0], sync_b) torch.eq(receptacle[0], sync_b)
), "Models differ in between ranks" ), "Models differ in between ranks"
class DeviceAndTypeCheckModule(Base):
"""A simple module for checking Tensor devices and dtypes."""
def __init__(
self,
expected_input_dtype: Optional[torch.dtype] = None,
expected_input_device: Optional[torch.device] = None,
expected_param_dtype: Optional[torch.dtype] = None,
expected_param_device: Optional[torch.device] = None,
expected_loss_dtype: Optional[torch.dtype] = None,
expected_loss_device: Optional[torch.device] = None,
):
super().__init__()
self.expected_input_dtype = expected_input_dtype
self.expected_input_device = expected_input_device
self.expected_param_dtype = expected_param_dtype
self.expected_param_device = expected_param_device
self.expected_loss_dtype = expected_loss_dtype
self.expected_loss_device = expected_loss_device
self.linear = nn.Linear(5, 5)
def _check(
self,
key: str,
x: Union[torch.device, torch.dtype],
expected: Union[Optional[torch.device], Optional[torch.dtype]],
) -> None:
assert expected in {None, x}, f"{key} ({x}) != expected ({expected})"
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
x = input[0]
self._check("input.dtype", x.dtype, self.expected_input_dtype)
self._check("input.device", x.device, self.expected_input_device)
param = self.linear.weight
self._check("param.dtype", param.dtype, self.expected_param_dtype)
self._check("param.device", param.device, self.expected_param_device)
loss = self.linear(x).sum()
self._check("loss.dtype", loss.dtype, self.expected_loss_dtype)
self._check("loss.device", loss.device, self.expected_loss_device)
return loss
@functools.lru_cache()
def get_cycles_per_ms() -> float:
"""Approximate number of cycles per millisecond for torch.cuda._sleep
Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
..note::
This doesn't seems to return consistent cycles on desktop GPUs likely
due to frequency scaling.
>>> get_cycles_per_ms()
227.6441091140009
# new python process
>>> get_cycles_per_ms()
564.652154766248
# new python process
>>> get_cycles_per_ms()
245.56459442962856
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
return cycles_per_ms
class DummyProcessGroup:
def __init__(self, rank: int, size: int):
self._rank = rank
self._size = size
def rank(self) -> int:
return self._rank
def size(self) -> int:
return self._size
...@@ -13,3 +13,5 @@ pytest-cov == 2.10.0 ...@@ -13,3 +13,5 @@ pytest-cov == 2.10.0
pytest-mpi == 0.4 pytest-mpi == 0.4
pytest-timeout == 1.4.2 pytest-timeout == 1.4.2
mpi4py == 3.0.3 mpi4py == 3.0.3
remote-pdb >= 2.1.0
parameterized >= 0.8.1
...@@ -84,6 +84,7 @@ class Size(tuple): ...@@ -84,6 +84,7 @@ class Size(tuple):
class Storage: class Storage:
def size(self) -> _int: ... def size(self) -> _int: ...
def element_size(self) -> _int: ... def element_size(self) -> _int: ...
def resize_(self, int) -> None: ...
#END #END
# See https://github.com/python/mypy/issues/4146 for why these workarounds # See https://github.com/python/mypy/issues/4146 for why these workarounds
...@@ -1913,6 +1914,7 @@ def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API ...@@ -1913,6 +1914,7 @@ def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def set_default_dtype(d : _dtype) -> None: ... def set_default_dtype(d : _dtype) -> None: ...
def manager_path() -> str: ... def manager_path() -> str: ...
def compiled_with_cxx11_abi() -> _bool: ... def compiled_with_cxx11_abi() -> _bool: ...
def is_autocast_enabled() -> _bool: ...
# The return value of this function depends on the value of `as_tuple`, # The return value of this function depends on the value of `as_tuple`,
# (similar to `unique`, `lu`, etc.); as such, it is not # (similar to `unique`, `lu`, etc.); as such, it is not
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, Generator
from .grad_scaler import GradScaler as GradScaler from .grad_scaler import GradScaler as GradScaler
class autocast:
def __init__(self, enabled=True) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ...
...@@ -37,12 +37,15 @@ def broadcast_object_list(object_list: List[Any], src: int, group:Optional[Proce ...@@ -37,12 +37,15 @@ def broadcast_object_list(object_list: List[Any], src: int, group:Optional[Proce
def is_initialized() -> bool: ... def is_initialized() -> bool: ...
def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ... def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ... def new_group(ranks: Optional[List[int]] = None,
timeout: Optional[datetime.timedelta] = datetime.timedelta(0, 1800),
backend: Optional[Union[str, Backend]] = None): ...
def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional[List[int]] = None, input_split_size: Optional[List[int]] = None, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional[List[int]] = None, input_split_size: Optional[List[int]] = None, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def reduce_scatter(tensor: Tensor, input_list: List[Tensor], op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def destroy_process_group() -> None: ... def destroy_process_group() -> None: ...
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from ... import Tensor, device, dtype from ... import Tensor, device, dtype
from .. import Parameter from .. import Parameter
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic, NamedTuple
from collections import OrderedDict from collections import OrderedDict
from ...utils.hooks import RemovableHandle from ...utils.hooks import RemovableHandle
...@@ -65,9 +65,10 @@ class Module(Generic[T_co]): ...@@ -65,9 +65,10 @@ class Module(Generic[T_co]):
def __getattr__(self, name: str) -> Union[Tensor, 'Module']: ... def __getattr__(self, name: str) -> Union[Tensor, 'Module']: ...
# TODO double-check this
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: ... def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: ...
def __setstate__(self, state: Dict[str, Any]) -> None: ...
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrederedDict` is created and returned. # back that same object. But if they pass nothing, an `OrederedDict` is created and returned.
T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor]) T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor])
...@@ -78,7 +79,7 @@ class Module(Generic[T_co]): ...@@ -78,7 +79,7 @@ class Module(Generic[T_co]):
@overload @overload
def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ... def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ...
def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...): ... def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...) -> NamedTuple: ...
def parameters(self, recurse: bool = ...) -> Iterator[Parameter]: ... def parameters(self, recurse: bool = ...) -> Iterator[Parameter]: ...
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .. import Tensor from typing import Optional
from .. import Size, Tensor
from ..cuda import Stream
import builtins import builtins
class Parameter(Tensor): class Parameter(Tensor):
# These are dynamic attributes added by shard_params_data_parallel class.
# Added here for better type checking.
_is_sharded: bool
_orig_size: Size
_cpu_grad: Tensor
_full_param_padded: Tensor
_fp32_shard: Tensor
_fp16_shard: Optional[Tensor]
def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ... def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
... ...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
from math import inf
import pickle
import sys
from typing import Dict
import unittest
from unittest import mock
from parameterized import parameterized
import torch
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
from fairscale.utils.testing import (
DeviceAndTypeCheckModule,
DummyProcessGroup,
dist_init,
get_cycles_per_ms,
objects_are_equal,
spawn_for_all_world_sizes,
)
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod
_BUFFER_NAME = "vocab_bias"
class DistributedTest(unittest.TestCase):
def setUp(self):
major, minor = torch.__version__.split(".")[:2]
major, minor = int(major), int(minor)
if major < 1 or (major == 1 and minor < 6):
raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA not available, skipping test")
if sys.platform == "win32":
raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
if torch.cuda.device_count() < 2:
raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
@staticmethod
def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
model_device = next(model.parameters()).device
# use SGD with momentum instead of Adam, since Adam is scale invariant
# and this makes it bad for tests
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
# Inputs always cuda regardless of move_grads_cpu, or model.device
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
if norm_type is not None:
clip_norm = 0.3
if isinstance(model, FullyShardedDataParallel):
model.clip_grad_norm_(clip_norm, norm_type)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
optim.step()
if hasattr(model, "assert_idle"):
model.assert_idle()
return loss.detach()
@staticmethod
def get_wrapped_model(group, cuda_first=False, config={}, **model_kwargs) -> FullyShardedDataParallel:
if cuda_first:
model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs).cuda(), group, **config)
else:
model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda()
return model
class TestMixedPrecision(DistributedTest):
def test_all_fp32(self):
self._spawn_test_case(
{"mixed_precision": False},
False, # autocast enabled
torch.float32, # expected_input_dtype
torch.float32, # expected_param_dtype
torch.float32, # expected_loss_dtype
torch.float32, # expected_reduce_dtype
)
def test_mixed_precision(self):
self._spawn_test_case(
{"mixed_precision": True},
False, # autocast enabled
torch.float16, # expected_input_dtype
torch.float16, # expected_param_dtype
torch.float16, # expected_loss_dtype
torch.float16, # expected_reduce_dtype
)
def test_mixed_precision_autocast(self):
"""If autocast enabled, loss should be fp32."""
self._spawn_test_case(
{"mixed_precision": True},
True, # autocast enabled
torch.float16, # expected_input_dtype
torch.float16, # expected_param_dtype
torch.float32, # expected_loss_dtype
torch.float16, # expected_reduce_dtype
)
def test_mixed_precision_autocast_fp32_compute(self):
self._spawn_test_case(
{"mixed_precision": True, "compute_dtype": torch.float32},
True, # autocast enabled
torch.float16, # expected_input_dtype
torch.float32, # expected_param_dtype
torch.float32, # expected_loss_dtype
torch.float32, # expected_reduce_dtype
)
def test_fp32_reduce_scatter(self):
self._spawn_test_case(
{"mixed_precision": True, "fp32_reduce_scatter": True},
False, # autocast enabled
torch.float16, # expected_input_dtype
torch.float16, # expected_param_dtype
torch.float16, # expected_loss_dtype
torch.float32, # expected_reduce_dtype
)
def test_fp32_reduce_scatter_autocast(self):
self._spawn_test_case(
{"mixed_precision": True, "fp32_reduce_scatter": True},
True, # autocast enabled
torch.float16, # expected_input_dtype
torch.float16, # expected_param_dtype
torch.float32, # expected_loss_dtype
torch.float32, # expected_reduce_dtype
)
def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2):
"""Call test_dtypes inside of torch.multiprocessing.spawn"""
fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype)
spawn_and_init(fn, world_sizes=[world_size])
@staticmethod
def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group):
# Patch torch.distributed.reduce_scatter to check the dtype of the reduction
orig_reduce_scatter = torch.distributed.reduce_scatter
model = DeviceAndTypeCheckModule(
expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype,
)
def _reduce_scatter(output, input_list, **kwargs):
for tensor in input_list:
model._check("reduce_scatter.dtype", tensor.dtype, expected=reduce_dtype)
return orig_reduce_scatter(output, input_list, **kwargs)
with mock.patch("torch.distributed.reduce_scatter", new=_reduce_scatter):
model = FullyShardedDataParallel(model, group, **cfg).cuda()
device = next(model.parameters()).device
x = torch.rand(2, 5).to(device)
with torch.cuda.amp.autocast(enabled=autocast):
loss = model(x)
loss.backward()
keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"]
CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
def rename_test(testcase_func, param_num, param):
return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)
class TestComparisonToPyTorchDDP(DistributedTest):
"""
Compare losses and parameter values after several updates when using
PyTorch DDP vs. FullyShardedDataParallel.
"""
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
# Test every combination of these options:
spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config))
def test_cpu_offload_and_cpu_grads(self):
# We don't test the False condition because that requires the optimizer to internally do
# the device transfer and PyTorch optimizers don't support this.
config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
)
spawn_and_init(test_fn)
def test_cpu_offload_and_cuda_grads_breaks(self):
# If grads are on gpu, but model and optimizer are on cpu, backward breaks.
config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False}
with self.assertRaises(Exception): # RuntimeError inside spawn
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False
)
spawn_and_init(test_fn)
def test_delayed_optim_step(self):
# We use a model with a long CUDA delay right before the optimizer step.
# This tests our streams logic, and that we don't start the FP32 -> FP16
# transfer until after the optimization step completes.
config = {"mixed_precision": True}
model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_after_loss_ms=250)
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
def test_delayed_reduce_scatter(self):
# We insert a delay in the torch.distributed.reduce_scatter op, so that
# the post_backward_stream takes much longer than the backward pass.
# This tests that we properly block at the end of the backward pass for
# the reductions to finish.
config = {"mixed_precision": True}
model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_before_reduction_ms=250)
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
@parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test)
def test_mixture_of_experts(self, moe_config):
fsdp_config = {"mixed_precision": True}
test_fn = functools.partial(
self._test_identical_outputs,
functools.partial(MixtureOfExperts, **moe_config),
fsdp_config,
# MixtureOfExperts implements custom reduce logic, so the reference
# behavior should use that logic instead of PyTorch DDP.
ref_ddp_fn=self._dummy_ddp_fn,
norm_type=None,
)
spawn_and_init(test_fn)
def test_mixture_of_experts_grad_clip_breaks(self):
config = {"mixed_precision": True}
test_fn = functools.partial(
self._test_identical_outputs, MixtureOfExperts, config, ref_ddp_fn=self._dummy_ddp_fn, norm_type=2,
)
with self.assertRaises(Exception):
spawn_and_init(test_fn)
@classmethod
def _dummy_ddp_fn(self, model, group):
return DummyDDP(model)
@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config["mixed_precision"]:
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False
# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()
try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
@parameterized.expand([[1], [inf]], name_func=rename_test)
def test_clip_norm_transformer(self, norm_type):
config = {"mixed_precision": True}
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, norm_type=norm_type,
)
spawn_and_init(test_fn)
class TestParamInit(DistributedTest):
def test_param_change_after_init(self):
test_fn = functools.partial(self._test_param_change_after_init, config={"mixed_precision": True})
spawn_and_init(test_fn)
@classmethod
def _test_param_change_after_init(self, rank, group, config):
# Establish reference behavior.
model = self.get_wrapped_model(group, cuda_first=False, config=config)
model.eval() # no dropout for this test
input = model.module.get_input(torch.device("cuda"))
ref_output = model(*input)
# Change the weights in place.
model = self.get_wrapped_model(group, cuda_first=False, config=config)
model.eval() # no dropout for this test
first_param = next(model.parameters())
nn.init.normal_(first_param.data)
new_output = model(*input)
assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init"
class TestSerialization(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_pickle(self, mixed_precision, cpu_offload):
"""Ensure that wrapped modules can be pickled/unpickled."""
config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload}
test_fn = functools.partial(self._test_pickle, config=config)
spawn_and_init(test_fn, world_sizes=[2])
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_multiprocessing(self, mixed_precision, cpu_offload):
"""Ensure that wrapped modules can be sent via multiprocessing."""
config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload}
test_fn = functools.partial(self._test_multiprocessing, config=config)
spawn_and_init(test_fn, world_sizes=[2])
@classmethod
def _test_pickle(self, rank, group, config):
model = self._get_model(group, config)
model = pickle.loads(pickle.dumps(model))
if not config["cpu_offload"]:
model = model.cuda()
self._one_step(model, group)
@classmethod
def _test_multiprocessing(self, rank, group, config):
mp = torch.multiprocessing.Pool(1)
dummy_group = DummyProcessGroup(rank=group.rank(), size=group.size())
model = mp.apply(self._get_model, (dummy_group, config))
if not config["cpu_offload"]:
model = model.cuda()
self._one_step(model, group)
@classmethod
def _get_model(self, group, config):
with torch.no_grad(): # required for multiprocessing
model = NestedWrappedModule(group, wrapper_config=config)
return FullyShardedDataParallel(model, group, **config)
@classmethod
def _one_step(self, model, group):
# reset the process group (required after unpickling)
for m in model.modules():
if isinstance(m, FullyShardedDataParallel):
m.process_group = group
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output)
model.module.run_backward(loss)
optim.step()
class TestLocalStateDict(DistributedTest):
@parameterized.expand([[True, True], [False, False]], name_func=rename_test)
def test_load_local_state_dict(self, flatten_params, mixed_precision):
test_fn = functools.partial(
self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = self.get_wrapped_model(group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model)
state_1 = model.local_state_dict()
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
assert len(state_1) > 0
model.load_local_state_dict(state_1)
weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight"
state_1_weight = state_1[weight_key]
assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
if not model.flatten_parameters:
# The weight will be sharded since we access module.state_dict directly
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 1, model.mixed_precision)
state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
model.load_local_state_dict(state_2)
assert state_1.keys() == state_2.keys()
# Assert that parameters were updated since before training
unchanged = []
for k in state_1:
if (state_before_training[k] == state_after_training[k]).all() and (_BUFFER_NAME not in k):
unchanged.append(k)
if unchanged:
raise AssertionError(f"params {unchanged} not changed after training")
class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
test_fn = functools.partial(
self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs):
ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs)
autocast = ddp_model.mixed_precision
self._train_for_several_steps(ddp_model, 1, autocast)
ddp_model.state_dict()
ddp_model.state_dict() # second call
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_state_dict_after_forward_mixed_precision(self, mixed_precision):
test_fn = functools.partial(
self._test_module_state_dict, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_state_dict_before_forward(self, mixed_precision):
test_fn = functools.partial(
self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_state_dict_before_forward(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
sd = ddp_model.state_dict()
expected_dtype = torch.float16 if ddp_model.mixed_precision else torch.float32
wt = sd["embed_tokens.weight"]
assert wt.dtype == expected_dtype, f"got dtype {wt.dtype} expected {expected_dtype}"
cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)
@classmethod
def _test_module_state_dict(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
autocast = ddp_model.mixed_precision
cls._train_for_several_steps(ddp_model, 2, autocast)
state_1 = ddp_model.state_dict()
# You must make a new FullyShardedDataParallel instance to use module.load_state_dict
unwrapped_model = TransformerWithSharedParams(group)
unwrapped_model.load_state_dict(state_1)
new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
cls._train_for_several_steps(new_ddp_model, 2, autocast)
try:
ddp_model.load_state_dict(new_ddp_model.state_dict())
assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
except Exception:
pass
class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used
@parameterized.expand([[True], [False]])
def test_output_backward_hooks(self, cuda_first):
fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first)
spawn_and_init(fn)
def test_backward_hooks_after_save(self):
fn = functools.partial(self._test_backward_hooks_after_save, cuda_first=False)
spawn_and_init(fn)
@classmethod
def _test_backward_hooks_after_save(self, rank, group, cuda_first=False):
model = self.get_wrapped_model(group, cuda_first=cuda_first)
self._train_for_several_steps(model, 2, model.mixed_precision)
state_1 = model.local_state_dict()
model.load_local_state_dict(state_1)
self._test_output_backward_hooks(rank, group, cuda_first=cuda_first, model=model)
@classmethod
def _test_output_backward_hooks(self, rank, group, cuda_first=False, model=None):
if model is None:
model = self.get_wrapped_model(group, cuda_first=cuda_first)
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim.zero_grad()
# Inputs always cuda regardless of move_grads_cpu, or model.device
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
assert len(output._backward_hooks) == 1 # this is pre-bwd hook
loss = model.module.get_loss(input, output).cuda()
loss.backward()
assert len(output._backward_hooks) == 1 # It doesn't get removed
optim.step()
assert len(output._backward_hooks) == 1
@parameterized.expand([[True], [False]])
def test_register_functions_called(self, cuda_first):
fn = functools.partial(self._test_register_functions_called, cuda_first=cuda_first)
spawn_and_init(fn)
@classmethod
def _test_register_functions_called(self, rank, group, cuda_first=False):
"""Tests that _register_{pre|post}_backward_hooks called during forward."""
model = self.get_wrapped_model(group, cuda_first=cuda_first)
input = model.module.get_input(torch.device("cuda"))
model._register_post_backward_hooks = mock.MagicMock(return_value=None)
model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
assert not model._register_post_backward_hooks.called
assert not model._register_pre_backward_hooks.called
model(*input)
assert model._register_post_backward_hooks.called
assert model._register_pre_backward_hooks.called
class TestNoGrad(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
test_fn = functools.partial(self._test_transformer, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_transformer(self, rank, group, config):
autocast = config["mixed_precision"]
# Train model for a step
model = self.get_wrapped_model(group, cuda_first=False, config=config)
self._train_for_several_steps(model, 1, autocast)
model.eval() # no dropout for this test
# Eval in standard mode (i.e., without no_grad)
input = model.module.get_input(torch.device("cuda"))
ref_output = model(*input)
# Eval with no_grad and compare
with torch.no_grad():
no_grad_output = model(*input)
assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output"
class TestNoSync(DistributedTest):
def test_transformer(self):
fn = functools.partial(self._test_transformer, config={})
spawn_and_init(fn)
def test_transformer_no_flat_params(self):
config = {"flatten_parameters": False}
fn = functools.partial(self._test_transformer, config=config)
spawn_and_init(fn)
def test_nested_wrapper(self):
fn = functools.partial(self._test_nested_wrapper, config={})
spawn_and_init(fn)
def test_no_sync_before_first_forward(self):
group = DummyProcessGroup(rank=0, size=1)
model = self.get_wrapped_model(group, config={})
batch = model.module.get_input(torch.device("cuda"))
with model.no_sync():
output = model(*batch)
loss = model.module.get_loss(batch, output)
loss.backward()
output = model(*batch)
loss = model.module.get_loss(batch, output)
loss.backward()
@classmethod
def _test_transformer(self, rank, group, config):
model = self.get_wrapped_model(group, config=config)
model.eval() # turn off dropout for the test
self._test_no_sync(model, batch_dim=1)
@classmethod
def _test_nested_wrapper(self, rank, group, config):
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
self._test_no_sync(model, batch_dim=0)
@classmethod
def _test_no_sync(self, model, batch_dim):
# Generate two input batches. We'll test that we get the same grads if
# we train on them sequentially while accumulating grads (with no_sync)
# vs. concatenating the batches and training in one go.
batch1 = model.module.get_input(torch.device("cuda"))
assert isinstance(batch1, tuple)
batch2 = tuple(
# This randomly permutes the values in a multi-dim tensor.
x.view(-1)[torch.randperm(x.numel())].view_as(x)
for x in batch1
)
for x, y in zip(batch1, batch2):
assert not torch.all(x == y)
# Concat the batches along batch dimension.
concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))
# Establish reference behavior on the concat batch.
model.zero_grad()
output = model(*concat_batch)
ref_loss = model.module.get_loss(concat_batch, output)
ref_loss.backward()
ref_grads = [p.grad.detach().clone() for p in model.parameters()]
# Test that we get the same results by accumulating grads.
model.zero_grad()
with model.no_sync(): # accumulate gradients from the first batch
output = model(*batch1)
loss1 = model.module.get_loss(batch1, output)
loss1.backward()
output = model(*batch2)
loss2 = model.module.get_loss(batch2, output)
loss2.backward()
accumulated_loss = loss1 + loss2
accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]
torch.testing.assert_allclose(ref_loss, accumulated_loss)
assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)
class TransformerWithSharedParams(nn.Module):
def __init__(self, group, *unused_args, d_vocab=23, d_model=16, **unused_kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
torch.manual_seed(0) # keep everything deterministic
assert d_vocab >= 12 # we use torch.arange(12) as input
self.embed_tokens = nn.Embedding(d_vocab, d_model)
self.transformer = nn.Transformer(
d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
)
self.output_proj = nn.Linear(d_model, d_vocab)
# share the embedding and output projection weights
self.output_proj.weight = self.embed_tokens.weight
self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,)))
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
src = torch.arange(12, device=device).view(6, 2) # T x B
tgt = torch.arange(8, device=device).view(4, 2) # T x B
return (src, tgt)
def forward(self, src_ids, tgt_ids):
src = self.embed_tokens(src_ids)
src = src + self.vocab_bias
tgt = self.embed_tokens(tgt_ids)
x = self.transformer(src, tgt)
return self.output_proj(x)
def get_loss(self, input, output):
_, tgt = input
return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum")
def run_backward(self, loss):
loss.backward()
class NestedWrappedModule(nn.Module):
def __init__(self, group, wrapper_config):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
self.wrapper_config = wrapper_config
def _maybe_wrap(layer):
if wrapper_config is not None:
return FullyShardedDataParallel(layer, group, **wrapper_config)
return layer
torch.manual_seed(0) # keep everything deterministic
self.module = nn.Sequential(
nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8),
)
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
return (torch.rand(4, 8, device=device),)
def forward(self, x):
return self.module(x)
def get_loss(self, input, output):
loss = output.sum()
return loss
def run_backward(self, loss):
loss.backward()
class DummyDDP(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
class MixtureOfExperts(NestedWrappedModule):
def __init__(self, group, wrapper_config, checkpoint_act=False):
super().__init__(group, wrapper_config)
self.group = group
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
expert = nn.Linear(16, 4)
for p in expert.parameters():
p.expert = True
# everything else is shared
torch.manual_seed(0)
shared = nn.Linear(4, 16)
if checkpoint_act:
expert = checkpoint_wrapper(expert)
shared = checkpoint_wrapper(shared)
if wrapper_config is not None:
# we create a process group of size 1 for the expert params
expert_group = torch.distributed.new_group([group.rank()])
expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
shared = FullyShardedDataParallel(shared, group, **wrapper_config)
self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))
def run_backward(self, loss):
loss.backward()
# manually reduce gradients if not wrapped in FullyShardedDataParallel
if self.wrapper_config is None:
with torch.no_grad():
for p in self.parameters():
if hasattr(p, "expert"):
continue # these params don't need grad reduction
p.grad.data.div_(self.world_size)
torch.distributed.all_reduce(p.grad.data, group=self.group)
class ModuleWithDelay(nn.Module):
def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0):
super().__init__()
self.delay_after_loss_ms = delay_after_loss_ms
self.delay_before_reduction_ms = delay_before_reduction_ms
self.module = module
def get_input(self, device):
return self.module.get_input(device)
def forward(self, x):
return self.module(x)
def get_loss(self, input, output):
loss = self.module.get_loss(input, output)
if self.delay_after_loss_ms > 0:
torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
return loss
def run_backward(self, loss):
orig_reduce_scatter = torch.distributed.reduce_scatter
def _delayed_reduce_scatter(*args, **kwargs):
if self.delay_before_reduction_ms > 0:
torch.cuda._sleep(int(self.delay_before_reduction_ms * get_cycles_per_ms()))
return orig_reduce_scatter(*args, **kwargs)
with mock.patch("torch.distributed.reduce_scatter", _delayed_reduce_scatter):
self.module.run_backward(loss)
class NestedWrappedModuleWithDelay(ModuleWithDelay):
def __init__(self, group, wrapper_config, **kwargs):
super().__init__(NestedWrappedModule(group, wrapper_config), **kwargs)
def spawn_and_init(fn, args=None, **spawn_kwargs):
if args is None:
args = ()
run_fn = functools.partial(init_and_run, fn, args)
spawn_for_all_world_sizes(run_fn, **spawn_kwargs)
def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
dist_init(rank, world_size, filename, filename_rpc)
group = torch.distributed.new_group()
fn(rank, group, *args)
if __name__ == "__main__":
unittest.main()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with uneven parameter shards. """
import tempfile
import pytest
import torch
from torch import Tensor
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
if test_case["assert_ref_out"]:
with torch.no_grad():
weight = model.weight.T.clone().cuda()
v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
ref_out = torch.matmul(v, weight)
model.to("cuda")
assert isinstance(fsdp_config, dict), str(fsdp_config)
model = FSDP(model, **fsdp_config)
optim = SGD(model.parameters(), lr=0.1)
inputs = test_case["inputs"]
assert len(inputs) == 1 or not test_case["assert_ref_out"]
assert len(inputs[0]) >= world_size
for in_data in inputs:
in_data = Tensor(in_data[rank]).cuda()
out = model(in_data)
out.sum().backward()
optim.step()
optim.zero_grad()
if test_case["assert_ref_out"]:
torch.testing.assert_allclose(ref_out, out)
model.assert_state(TrainingState.IDLE)
teardown()
@skip_if_single_gpu
@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}])
@pytest.mark.parametrize(
"fsdp_config", [{}, {"flatten_parameters": False}],
)
@pytest.mark.parametrize("world_size", list(range(2, 9)))
def test_one_iteration(world_size, test_case, fsdp_config):
"""Test FSDP with uneven divide of parameter shards."""
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs.")
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
# TODO (Min): we may want to extend this to a simple 2 layer model so that it covers
# more cases in FSDP. Also, assert_ref_out can be extended to multiple
# iterations. This could be a good bootcamp task. I should file a github
# issue once we merge.
model = Linear(3, 3, bias=False)
mp.spawn(
_test_func,
args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
nprocs=world_size,
join=True,
)
@skip_if_single_gpu
@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3), torch.rand(8, 3)], "assert_ref_out": False}])
@pytest.mark.parametrize("fsdp_config", [{}, {"flatten_parameters": False}])
@pytest.mark.parametrize("world_size", list(range(2, 9)))
def test_smaller_than_world_size(world_size, test_case, fsdp_config):
"""Test FSDP with uneven divide of parameter shards."""
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs.")
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
model = Sequential(
Linear(3, 3, bias=False),
Linear(3, 4, bias=False),
Linear(4, 5, bias=False),
Linear(5, 4, bias=False),
Linear(4, 3, bias=False),
Linear(3, 1, bias=False),
Linear(1, 1, bias=False), # param here is smaller than world_size if unflattened.
)
mp.spawn(
_test_func,
args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
nprocs=world_size,
join=True,
)
...@@ -631,6 +631,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name): ...@@ -631,6 +631,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
loss_oss = loss_fn(outputs_oss, target) loss_oss = loss_fn(outputs_oss, target)
loss_oss.backward() loss_oss.backward()
torch.testing.assert_allclose(loss_oss, loss)
# Check the equivalence with the non-sharded optim # Check the equivalence with the non-sharded optim
oss_total_norm = sharded_optimizer.clip_grad_norm(CLIP_NORM, norm_type=norm) oss_total_norm = sharded_optimizer.clip_grad_norm(CLIP_NORM, norm_type=norm)
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test utility classes from fairscale.utils.parallel """
from parameterized import parameterized
import torch
from fairscale.utils.parallel import chunk_and_pad
@parameterized.expand([[num_chunks] for num_chunks in range(1, 33)])
def test_chunk_and_pad(num_chunks):
max_tensor_size = 256
tensor = torch.zeros(max_tensor_size)
for tensor_size in range(1, max_tensor_size + 1):
tensor_i = tensor[:tensor_size]
chunks = chunk_and_pad(tensor_i, num_chunks)
assert len(chunks) == num_chunks
assert all(len(chunks[0]) == len(chunk) for chunk in chunks)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
import sys
import unittest
from unittest import mock
from parameterized import parameterized
import torch
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def rename_test(testcase_func, param_num, param):
return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)
CONFIG_OPTIONS = [
[dict(zip(["bucket_cap_mb", "shard_size"], config))] for config in itertools.product([0, 0.25], [1, 262144])
]
class TestReduceScatterBucketer(unittest.TestCase):
# TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`.
def setUp(self):
major, minor = torch.__version__.split(".")[:2]
major, minor = int(major), int(minor)
if major < 1 or (major == 1 and minor < 6):
raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter")
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA not available, skipping test")
if sys.platform == "win32":
raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
if torch.cuda.device_count() < 2:
raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_reduce_scatter(self, config):
spawn_and_init(functools.partial(self._test_reduce_scatter, **config))
@staticmethod
def _test_reduce_scatter(rank, group, bucket_cap_mb=None, shard_size=None):
bucketer = ReduceScatterBucketer(bucket_cap_mb=bucket_cap_mb)
world_size = group.size()
tensors = [torch.ones(shard_size).cuda() for _ in range(world_size)]
tensors[rank].fill_(0)
input_bytes = shard_size * world_size * 4
bucket_bytes = bucket_cap_mb * 1024 * 1024
callback = mock.MagicMock()
bucketer.reduce_scatter_async(tensors, group, callback_fn=callback)
if bucket_cap_mb > 0 and input_bytes < bucket_bytes:
assert callback.call_count == 0
bucketer.flush()
assert callback.call_count == 1
result = callback.call_args[0][0] # get first positional arg
assert torch.is_tensor(result), result
assert torch.all(result == (world_size - 1))
def test_out_of_order_reduction(self):
spawn_and_init(self._test_out_of_order_reduction)
@staticmethod
def _test_out_of_order_reduction(rank, group):
bucketer = ReduceScatterBucketer(bucket_cap_mb=0.25)
world_size = group.size()
small_tensors = [torch.ones(1).cuda() for _ in range(world_size)]
big_tensors = [torch.ones(262144).cuda() for _ in range(world_size)]
more_small_tensors = [torch.ones(2).cuda() for _ in range(world_size)]
callback1 = mock.MagicMock()
callback2 = mock.MagicMock()
callback3 = mock.MagicMock()
bucketer.reduce_scatter_async(small_tensors, group, callback_fn=callback1)
assert callback1.call_count == 0
bucketer.reduce_scatter_async(big_tensors, group, callback_fn=callback2)
assert callback1.call_count == 0
assert callback2.call_count == 1
bucketer.reduce_scatter_async(more_small_tensors, group, callback_fn=callback3)
assert callback1.call_count == 0
assert callback2.call_count == 1
assert callback3.call_count == 0
bucketer.flush()
assert callback1.call_count == 1
assert callback2.call_count == 1
assert callback3.call_count == 1
def spawn_and_init(fn, args=None, **spawn_kwargs):
if args is None:
args = ()
run_fn = functools.partial(init_and_run, fn, args)
spawn_for_all_world_sizes(run_fn, **spawn_kwargs)
def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
dist_init(rank, world_size, filename, filename_rpc)
group = torch.distributed.new_group()
fn(rank, group, *args)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment