Unverified Commit 15512d9e authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

Add FullyShardedDataParallel (FSDP) (#413)

Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and [Google](https://arxiv.org/abs/2004.13336

) has shown that data parallel training can be made significantly more efficient by sharding the model parameters and optimizer state across data parallel workers. These ideas are encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper, which is a drop-in replacement for PyTorch's `DistributedDataParallel` (DDP) wrapper.

Compared to PyTorch DDP:
* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
* FSDP with `reshard_after_forward=False` has the same communication cost as PyTorch DDP and is similar to ZeRO-2
* FSDP with `reshard_after_forward=True` increases total communication by 50% and is similar to ZeRO-3:
    * all-gather parameters at start of forward pass and start of backward pass
    * reduce-scatter grads at end of backward pass
Co-authored-by: default avatarMin Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: default avatarSam Shleifer <sshleifer@gmail.com>
parent 279b8024
......@@ -3,4 +3,5 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .fully_sharded_data_parallel import FullyShardedDataParallel
from .sharded_ddp import ShardedDataParallel
This diff is collapsed.
......@@ -3,8 +3,9 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
import functools
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Generator, Optional, Tuple
import torch
from torch import Tensor
......@@ -73,6 +74,23 @@ def set_rng_state(state: Dict[str, Any]) -> None:
torch.cuda.set_rng_state(state["cuda_rng_state"])
def is_autocast_enabled() -> bool:
"""Similar to torch.is_autocast_enabled, but compatible with torch 1.5.1"""
if hasattr(torch, "is_autocast_enabled"):
return torch.is_autocast_enabled()
return False
@contextmanager
def autocast(enabled: bool) -> Generator:
"""Similar to torch.cuda.amp.autocast, but compatible with torch 1.5.1"""
if enabled:
with torch.cuda.amp.autocast(enabled):
yield
else:
yield
class CheckpointFunction(torch.autograd.Function):
"""Similar to the torch version, but support non-Tensor outputs.
......@@ -96,13 +114,13 @@ class CheckpointFunction(torch.autograd.Function):
ctx.run_function = run_function
ctx.kwarg_keys = kwarg_keys
ctx.fwd_rng_state = get_rng_state()
ctx.had_autocast_in_fwd = is_autocast_enabled()
tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
if parent_ctx_dict["offload"]:
ctx.fwd_device = tuple(x.device for x in tensor_inputs)
ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
else:
ctx.fwd_device, ctx.grad_requirements = None, None
......@@ -142,10 +160,11 @@ class CheckpointFunction(torch.autograd.Function):
# Set the states to what it used to be before the forward pass.
set_rng_state(ctx.fwd_rng_state)
with torch.enable_grad():
with torch.enable_grad(), autocast(ctx.had_autocast_in_fwd):
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs)
# Set the states back to what it was at the start of this function.
set_rng_state(bwd_rng_state)
......
......@@ -2,12 +2,15 @@
# Licensed under the MIT License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
class FlattenParamsWrapper(nn.Module):
"""
......@@ -127,21 +130,23 @@ class FlattenParamsWrapper(nn.Module):
except AttributeError:
return getattr(self.module, name) # fallback to wrapped module
def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]": # type: ignore
def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]": # type: ignore
"""Return an unflattened state_dict."""
with self.unflatten_params():
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
return self.module.state_dict(*args, **kwargs)
def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Return the flattened state_dict."""
return super().state_dict(*args, **kwargs)
def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
def load_state_dict(
self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True
) -> NamedTuple:
if "flat_param" in state_dict:
super().load_state_dict(state_dict, strict=True)
return super().load_state_dict(state_dict, strict=strict)
else:
with self.unflatten_params():
return self.module.load_state_dict(state_dict, *args, **kwargs)
return self.module.load_state_dict(state_dict, strict)
def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
self._unflatten_params_as_views()
......
......@@ -15,7 +15,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from .utils import broadcast_object, recursive_copy_to_device
from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -284,18 +284,14 @@ class OSS(Optimizer):
# https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params
local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(self._default_device) for p in local_params)
total_norm = local_norm
# all reduce over data parallel and model parallel workers
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
else:
local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._default_device) for p in local_params]), # type: ignore
p=norm_type,
)
# local norm result can be accumulated with the remote ones if put to the right power
# n_i = sum_rank(a^p)^1/p
# -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
......
......@@ -5,7 +5,8 @@
import collections
import io
from typing import Any, Callable, Dict, Optional
from math import inf
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
......@@ -102,3 +103,22 @@ class Bucket:
def full(self) -> bool:
""" is the bucket full ? """
return self.max_params_checked_in == self.params_checked_in
def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
r"""Calculate gradient norm of an iterable of parameters.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda par: par.grad is not None, parameters))
if len(parameters) == 0:
return torch.tensor(0.0)
p = float(p)
if p == inf:
local_norm = max(par.grad.detach().abs().max() for par in parameters) # type: ignore
else:
local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p) for par in parameters]), p) # type: ignore
return local_norm
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Useful functions for parallel training."""
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn.functional as F
def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
"""Chunk a given Tensor into num_chunks parts and add any necessary padding."""
chunks = list(torch.flatten(tensor).chunk(num_chunks))
# torch.chunk may return fewer than num_chunks chunks, pad accordingly.
num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
if num_pad_for_partial_chunk > 0:
chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
if len(chunks) < num_chunks:
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
return chunks
def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
"""Do a quick test in case user called FSDP without calling torch.cuda.set_device()
correctly. This can easily happen in cpu_offload case where the model resides on
the CPU.
"""
if not hasattr(process_group, "allgather"):
# Likely a dummy pg for unit test, skip checking.
return
world_size = process_group.size()
if "cuda" in str(device):
input_tensor = torch.ones(1).to(device)
output = list(torch.zeros(world_size).to(device).chunk(world_size))
dist.all_gather(output, input_tensor, group=process_group)
assert torch.cat(output).sum() == float(world_size), (
f"found {torch.cat(output).sum()} devices in process group but "
f"world_size={world_size}. Check torch.cuda.set_device is called properly"
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Bucket:
def __init__(self, data: Tensor, group: ProcessGroup):
self.data = data
self.group = group
self.offset = 0
self.callbacks: List[Callable] = []
self.output_shard = torch.zeros_like(data[0])
def flush(self) -> None:
if self.offset == 0:
assert len(self.callbacks) == 0
return
# reduce-scatter bucket
dist.reduce_scatter(
self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reuse input bucket but allocate a fresh output shard
self.data[:, : self.offset].zero_()
self.offset = 0
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.data[0])
class ReduceScatterBucketer:
"""
Helper for bucketing multiple reduce-scatter operations on small tensors
into larger reduce-scatter ops to improve communication efficiency.
Usage::
bucketer = ReduceScatterBucketer()
bucketer.reduce_scatter_async(
small_tensors, callback_fn=lambda result: print("small")
)
bucketer.reduce_scatter_async(
big_tensors, callback_fn=lambda result: print("big")
)
bucketer.reduce_scatter_async(
more_small_tensors, callback_fn=lambda result: print("small2")
)
bucketer.flush() # callbacks only guaranteed to be called after flush()
# Example output (note that it is out of order, due to bucketing):
# big
# small
# small2
Args:
bucket_cap_mb (int, Optional): bucket size for communicating. Buckets
are sub-divided based on world_size. Values <= 0 disable bucketing.
"""
def __init__(self, bucket_cap_mb: int = 25):
self.bucket_cap_mb = bucket_cap_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
@torch.no_grad()
def reduce_scatter_async(
self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
) -> None:
"""
Reduce-scatter a list of tensors asynchronously, so smaller reductions
can be bucketed together. The given callback (``callback_fn``) will be
called with the reduced result at some later time. Call ``flush()`` to
force all queued ops and callbacks to be executed.
Note that large inputs will be reduced immediately, and this function
may also flush the relevant bucket to make room for ``input_list``.
Args:
input_list (List[Tensor]): list of tensors to reduce-scatter. List
should contain ``group.size()`` tensors and each tensor should
have identical shape, dtype and device.
group (ProcessGroup): process group for reduction
callback_fn (Callable, Optional): callback function to call after
the reduction executes. Function will be called with a single
argument corresponding to the reduced result.
"""
world_size = group.size()
assert (
len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
first_input = input_list[0]
first_input_size = first_input.numel()
bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
if first_input_size > bucket_shard_size:
# input is too big to fit in the bucket, reduce-scatter directly
output = torch.zeros_like(input_list[0])
dist.reduce_scatter(output, input_list, group=group)
if callback_fn is not None:
callback_fn(output)
return
bucket = self._get_bucket(first_input, group)
if first_input_size > bucket.data.size(1) - bucket.offset:
# not enough space remaining in bucket, flush it now
bucket.flush()
# copy data from input_list into bucket
stacked_input = torch.stack(input_list).view(world_size, first_input_size)
offset = bucket.offset
bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
bucket.offset += first_input_size
# callback will be given the reduced result
if callback_fn is not None:
result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
bucket.callbacks.append(functools.partial(callback_fn, result_view))
@torch.no_grad()
def flush(self) -> None:
"""Reduce-scatter any partial buckets."""
for bucket in self.buckets.values():
bucket.flush()
@functools.lru_cache()
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
if self.bucket_cap_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_cap_mb * MB / element_size
return int(bucket_size // num_shards)
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
world_size = group.size()
shard_size = self._get_shard_size(tensor.element_size(), world_size)
data = tensor.new_zeros((world_size, shard_size))
self.buckets[key] = Bucket(data, group)
return self.buckets[key]
......@@ -33,11 +33,12 @@ import os
import random
import sys
import tempfile
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import numpy
import pytest
import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
......@@ -46,6 +47,11 @@ import torch.nn as nn
from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
if TYPE_CHECKING:
Base = nn.Module[Tensor]
else:
Base = nn.Module
skip_if_no_cuda = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required"
)
......@@ -75,12 +81,12 @@ if torch.cuda.is_available():
_, filename_mpi = tempfile.mkstemp()
class IdentityLayer(torch.nn.Module):
class IdentityLayer(Base):
def __init__(self, size: int, scale: float = 1.0) -> None:
super(IdentityLayer, self).__init__()
self.weight = torch.nn.Parameter(scale * torch.randn(size))
def forward(self, *_: Any, **__: Any) -> Any:
def forward(self, *_: Any, **__: Any) -> Tensor:
return self.weight
......@@ -103,7 +109,7 @@ def torch_version() -> Tuple[int, ...]:
# Assuming that we're interested in the second usecase more than the first,
# return the pre-release or dev numbering
logging.warning(f"Pytorch pre-relase version {torch.__version__} - assuming intent to test it")
logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it")
numbering[2] = "0"
return tuple(int(n) for n in numbering)
......@@ -301,7 +307,7 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
return prepare_test
class _Block(nn.Module):
class _Block(Base):
def __init__(self, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.ln_1 = nn.LayerNorm(embed_dim)
......@@ -309,7 +315,7 @@ class _Block(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore
self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
x = inputs[0]
attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
attn_mask = torch.triu(attn_mask, diagonal=1)
......@@ -322,7 +328,7 @@ class _Block(nn.Module):
return x
class GPT2(nn.Module):
class GPT2(Base):
"""
GPT2 pytorch implementation, for testing purposes in the image-GPT context
Credits: https://github.com/teddykoker/image-gpt"""
......@@ -349,7 +355,7 @@ class GPT2(nn.Module):
self.head = nn.Linear(embed_dim, num_vocab, bias=False)
self.clf_head = nn.Linear(embed_dim, num_classes)
def forward(self, x: torch.Tensor, classify=False) -> Any: # type: ignore
def forward(self, x: Tensor, classify: bool = False) -> Any: # type: ignore
"""
Expect input as shape [sequence len, batch]
If classify, return classification logits
......@@ -451,3 +457,89 @@ def check_same_models_across_ranks(
assert not params_should_be_equal or torch.all(
torch.eq(receptacle[0], sync_b)
), "Models differ in between ranks"
class DeviceAndTypeCheckModule(Base):
"""A simple module for checking Tensor devices and dtypes."""
def __init__(
self,
expected_input_dtype: Optional[torch.dtype] = None,
expected_input_device: Optional[torch.device] = None,
expected_param_dtype: Optional[torch.dtype] = None,
expected_param_device: Optional[torch.device] = None,
expected_loss_dtype: Optional[torch.dtype] = None,
expected_loss_device: Optional[torch.device] = None,
):
super().__init__()
self.expected_input_dtype = expected_input_dtype
self.expected_input_device = expected_input_device
self.expected_param_dtype = expected_param_dtype
self.expected_param_device = expected_param_device
self.expected_loss_dtype = expected_loss_dtype
self.expected_loss_device = expected_loss_device
self.linear = nn.Linear(5, 5)
def _check(
self,
key: str,
x: Union[torch.device, torch.dtype],
expected: Union[Optional[torch.device], Optional[torch.dtype]],
) -> None:
assert expected in {None, x}, f"{key} ({x}) != expected ({expected})"
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
x = input[0]
self._check("input.dtype", x.dtype, self.expected_input_dtype)
self._check("input.device", x.device, self.expected_input_device)
param = self.linear.weight
self._check("param.dtype", param.dtype, self.expected_param_dtype)
self._check("param.device", param.device, self.expected_param_device)
loss = self.linear(x).sum()
self._check("loss.dtype", loss.dtype, self.expected_loss_dtype)
self._check("loss.device", loss.device, self.expected_loss_device)
return loss
@functools.lru_cache()
def get_cycles_per_ms() -> float:
"""Approximate number of cycles per millisecond for torch.cuda._sleep
Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
..note::
This doesn't seems to return consistent cycles on desktop GPUs likely
due to frequency scaling.
>>> get_cycles_per_ms()
227.6441091140009
# new python process
>>> get_cycles_per_ms()
564.652154766248
# new python process
>>> get_cycles_per_ms()
245.56459442962856
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
return cycles_per_ms
class DummyProcessGroup:
def __init__(self, rank: int, size: int):
self._rank = rank
self._size = size
def rank(self) -> int:
return self._rank
def size(self) -> int:
return self._size
......@@ -13,3 +13,5 @@ pytest-cov == 2.10.0
pytest-mpi == 0.4
pytest-timeout == 1.4.2
mpi4py == 3.0.3
remote-pdb >= 2.1.0
parameterized >= 0.8.1
......@@ -84,6 +84,7 @@ class Size(tuple):
class Storage:
def size(self) -> _int: ...
def element_size(self) -> _int: ...
def resize_(self, int) -> None: ...
#END
# See https://github.com/python/mypy/issues/4146 for why these workarounds
......@@ -1913,6 +1914,7 @@ def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def set_default_dtype(d : _dtype) -> None: ...
def manager_path() -> str: ...
def compiled_with_cxx11_abi() -> _bool: ...
def is_autocast_enabled() -> _bool: ...
# The return value of this function depends on the value of `as_tuple`,
# (similar to `unique`, `lu`, etc.); as such, it is not
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, Generator
from .grad_scaler import GradScaler as GradScaler
class autocast:
def __init__(self, enabled=True) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ...
......@@ -37,12 +37,15 @@ def broadcast_object_list(object_list: List[Any], src: int, group:Optional[Proce
def is_initialized() -> bool: ...
def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
def new_group(ranks: Optional[List[int]] = None,
timeout: Optional[datetime.timedelta] = datetime.timedelta(0, 1800),
backend: Optional[Union[str, Backend]] = None): ...
def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional[List[int]] = None, input_split_size: Optional[List[int]] = None, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def reduce_scatter(tensor: Tensor, input_list: List[Tensor], op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def destroy_process_group() -> None: ...
......
......@@ -2,7 +2,7 @@
from ... import Tensor, device, dtype
from .. import Parameter
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic, NamedTuple
from collections import OrderedDict
from ...utils.hooks import RemovableHandle
......@@ -65,9 +65,10 @@ class Module(Generic[T_co]):
def __getattr__(self, name: str) -> Union[Tensor, 'Module']: ...
# TODO double-check this
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: ...
def __setstate__(self, state: Dict[str, Any]) -> None: ...
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrederedDict` is created and returned.
T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor])
......@@ -78,7 +79,7 @@ class Module(Generic[T_co]):
@overload
def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ...
def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...): ...
def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...) -> NamedTuple: ...
def parameters(self, recurse: bool = ...) -> Iterator[Parameter]: ...
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .. import Tensor
from typing import Optional
from .. import Size, Tensor
from ..cuda import Stream
import builtins
class Parameter(Tensor):
# These are dynamic attributes added by shard_params_data_parallel class.
# Added here for better type checking.
_is_sharded: bool
_orig_size: Size
_cpu_grad: Tensor
_full_param_padded: Tensor
_fp32_shard: Tensor
_fp16_shard: Optional[Tensor]
def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
...
This diff is collapsed.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with uneven parameter shards. """
import tempfile
import pytest
import torch
from torch import Tensor
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
if test_case["assert_ref_out"]:
with torch.no_grad():
weight = model.weight.T.clone().cuda()
v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
ref_out = torch.matmul(v, weight)
model.to("cuda")
assert isinstance(fsdp_config, dict), str(fsdp_config)
model = FSDP(model, **fsdp_config)
optim = SGD(model.parameters(), lr=0.1)
inputs = test_case["inputs"]
assert len(inputs) == 1 or not test_case["assert_ref_out"]
assert len(inputs[0]) >= world_size
for in_data in inputs:
in_data = Tensor(in_data[rank]).cuda()
out = model(in_data)
out.sum().backward()
optim.step()
optim.zero_grad()
if test_case["assert_ref_out"]:
torch.testing.assert_allclose(ref_out, out)
model.assert_state(TrainingState.IDLE)
teardown()
@skip_if_single_gpu
@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}])
@pytest.mark.parametrize(
"fsdp_config", [{}, {"flatten_parameters": False}],
)
@pytest.mark.parametrize("world_size", list(range(2, 9)))
def test_one_iteration(world_size, test_case, fsdp_config):
"""Test FSDP with uneven divide of parameter shards."""
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs.")
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
# TODO (Min): we may want to extend this to a simple 2 layer model so that it covers
# more cases in FSDP. Also, assert_ref_out can be extended to multiple
# iterations. This could be a good bootcamp task. I should file a github
# issue once we merge.
model = Linear(3, 3, bias=False)
mp.spawn(
_test_func,
args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
nprocs=world_size,
join=True,
)
@skip_if_single_gpu
@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3), torch.rand(8, 3)], "assert_ref_out": False}])
@pytest.mark.parametrize("fsdp_config", [{}, {"flatten_parameters": False}])
@pytest.mark.parametrize("world_size", list(range(2, 9)))
def test_smaller_than_world_size(world_size, test_case, fsdp_config):
"""Test FSDP with uneven divide of parameter shards."""
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs.")
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
model = Sequential(
Linear(3, 3, bias=False),
Linear(3, 4, bias=False),
Linear(4, 5, bias=False),
Linear(5, 4, bias=False),
Linear(4, 3, bias=False),
Linear(3, 1, bias=False),
Linear(1, 1, bias=False), # param here is smaller than world_size if unflattened.
)
mp.spawn(
_test_func,
args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
nprocs=world_size,
join=True,
)
......@@ -631,6 +631,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
loss_oss = loss_fn(outputs_oss, target)
loss_oss.backward()
torch.testing.assert_allclose(loss_oss, loss)
# Check the equivalence with the non-sharded optim
oss_total_norm = sharded_optimizer.clip_grad_norm(CLIP_NORM, norm_type=norm)
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test utility classes from fairscale.utils.parallel """
from parameterized import parameterized
import torch
from fairscale.utils.parallel import chunk_and_pad
@parameterized.expand([[num_chunks] for num_chunks in range(1, 33)])
def test_chunk_and_pad(num_chunks):
max_tensor_size = 256
tensor = torch.zeros(max_tensor_size)
for tensor_size in range(1, max_tensor_size + 1):
tensor_i = tensor[:tensor_size]
chunks = chunk_and_pad(tensor_i, num_chunks)
assert len(chunks) == num_chunks
assert all(len(chunks[0]) == len(chunk) for chunk in chunks)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
import sys
import unittest
from unittest import mock
from parameterized import parameterized
import torch
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def rename_test(testcase_func, param_num, param):
return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)
CONFIG_OPTIONS = [
[dict(zip(["bucket_cap_mb", "shard_size"], config))] for config in itertools.product([0, 0.25], [1, 262144])
]
class TestReduceScatterBucketer(unittest.TestCase):
# TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`.
def setUp(self):
major, minor = torch.__version__.split(".")[:2]
major, minor = int(major), int(minor)
if major < 1 or (major == 1 and minor < 6):
raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter")
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA not available, skipping test")
if sys.platform == "win32":
raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
if torch.cuda.device_count() < 2:
raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_reduce_scatter(self, config):
spawn_and_init(functools.partial(self._test_reduce_scatter, **config))
@staticmethod
def _test_reduce_scatter(rank, group, bucket_cap_mb=None, shard_size=None):
bucketer = ReduceScatterBucketer(bucket_cap_mb=bucket_cap_mb)
world_size = group.size()
tensors = [torch.ones(shard_size).cuda() for _ in range(world_size)]
tensors[rank].fill_(0)
input_bytes = shard_size * world_size * 4
bucket_bytes = bucket_cap_mb * 1024 * 1024
callback = mock.MagicMock()
bucketer.reduce_scatter_async(tensors, group, callback_fn=callback)
if bucket_cap_mb > 0 and input_bytes < bucket_bytes:
assert callback.call_count == 0
bucketer.flush()
assert callback.call_count == 1
result = callback.call_args[0][0] # get first positional arg
assert torch.is_tensor(result), result
assert torch.all(result == (world_size - 1))
def test_out_of_order_reduction(self):
spawn_and_init(self._test_out_of_order_reduction)
@staticmethod
def _test_out_of_order_reduction(rank, group):
bucketer = ReduceScatterBucketer(bucket_cap_mb=0.25)
world_size = group.size()
small_tensors = [torch.ones(1).cuda() for _ in range(world_size)]
big_tensors = [torch.ones(262144).cuda() for _ in range(world_size)]
more_small_tensors = [torch.ones(2).cuda() for _ in range(world_size)]
callback1 = mock.MagicMock()
callback2 = mock.MagicMock()
callback3 = mock.MagicMock()
bucketer.reduce_scatter_async(small_tensors, group, callback_fn=callback1)
assert callback1.call_count == 0
bucketer.reduce_scatter_async(big_tensors, group, callback_fn=callback2)
assert callback1.call_count == 0
assert callback2.call_count == 1
bucketer.reduce_scatter_async(more_small_tensors, group, callback_fn=callback3)
assert callback1.call_count == 0
assert callback2.call_count == 1
assert callback3.call_count == 0
bucketer.flush()
assert callback1.call_count == 1
assert callback2.call_count == 1
assert callback3.call_count == 1
def spawn_and_init(fn, args=None, **spawn_kwargs):
if args is None:
args = ()
run_fn = functools.partial(init_and_run, fn, args)
spawn_for_all_world_sizes(run_fn, **spawn_kwargs)
def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
dist_init(rank, world_size, filename, filename_rpc)
group = torch.distributed.new_group()
fn(rank, group, *args)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment