Unverified Commit f3359550 authored by Sean Naren, committed by GitHub

[feat] Add context manager to FSDP for easier child module wrapping (#446)

This adds a context manager that assists in wrapping child modules with the same FSDP defaults.
Usage:
```
from fairscale.nn import enable_wrap, wrap

with enable_wrap(**handful_of_important_params):
    layer_1 = wrap(torch.nn.Linear(5, 5))
    layer_2 = wrap(torch.nn.Linear(5, 5), flatten_parameters=True)  # Override parameters if you'd like

# Outside the context manager, wrap is a no-op and returns the plain Linear layer
layer_1 = wrap(torch.nn.Linear(5, 5))
```
If not within the `enable_wrap` context, `wrap` is a no-op, so layers can be annotated once without having to duplicate configuration changes at every call site.
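The diff also adds `auto_wrap`, which recursively wraps child modules whose parameter count crosses a threshold. A rough sketch of its use under the same context (the `TransformerBlock` and `handful_of_important_params` names are placeholders echoing the docstrings, not a fixed API):
```
from fairscale.nn import auto_wrap, enable_wrap

with enable_wrap(**handful_of_important_params):
    # Children with at least min_num_params parameters get their own FSDP wrapper;
    # smaller children stay inside the parent's wrapper.
    block = auto_wrap(TransformerBlock(), min_num_params=1e8)
```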
parent 5eb6b8c7
@@ -7,6 +7,7 @@ from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
from .misc import FlattenParamsWrapper
from .moe import MOELayer, Top2Gate
from .pipe import Pipe, PipeRPCWrapper
from .wrap import auto_wrap, enable_wrap, wrap

__all__ = [
    "FlattenParamsWrapper",
@@ -16,4 +17,7 @@ __all__ = [
    "PipeRPCWrapper",
    "ShardedDataParallel",
    "Top2Gate",
    "auto_wrap",
    "enable_wrap",
    "wrap",
]
@@ -793,6 +793,7 @@ class FullyShardedDataParallel(nn.Module):
            self._prep_grads_for_backward()

        def _register_hook(t: torch.Tensor) -> torch.Tensor:
            if t.requires_grad:
                t.register_hook(_pre_backward_hook)
            return t
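For context (this is not part of the diff): `Tensor.register_hook` runs its callback once the gradient for that tensor has been computed, which is how `_pre_backward_hook` above gets triggered at the start of the backward pass. A minimal standalone sketch:
```
import torch

t = torch.randn(3, requires_grad=True)
# The hook fires once t's gradient has been computed during backward().
t.register_hook(lambda grad: print("grad ready:", grad.shape))
(t * 2).sum().backward()
```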
fairscale/nn/wrap/__init__.py
from .auto_wrap import auto_wrap, enable_wrap, wrap

fairscale/nn/wrap/auto_wrap.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import torch.nn as nn
from fairscale.nn.data_parallel.fully_sharded_data_parallel import FullyShardedDataParallel
from fairscale.nn.misc import checkpoint_wrapper
# Modules that are never wrapped themselves, although their children may be.
FSDP_MODULE_EXCLUDE_WRAP = {nn.ModuleList, nn.ModuleDict}
# Modules that we do not recurse into: neither they nor their children get wrapped.
FSDP_MODULE_BLOCKLIST = {nn.MultiheadAttention}
@contextlib.contextmanager
def enable_wrap(module_blocklist: Optional[List] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
    """
    Context manager to wrap modules in FullyShardedDataParallel.

    Useful when you'd like to apply the same parameters to all child modules
    that you wrap. A particularly important use case is wrapping large layers so
    that they get sharded (in-place) during initialization, to avoid running out of
    system memory. Large layers can indicate that they should be sharded via
    the ``wrap`` annotation and this context manager can provide the
    exact configuration for these nested instances.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            # Wraps children modules by default based on min_num_params
            self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8)

    Args:
        module_blocklist: list of additional Module classes not to wrap when
            using :func:`auto_wrap`. This is useful to exclude unsupported
            modules when wrapping recursively.
        **wrapper_kwargs: configuration settings that will be passed to all ``wrap``
            instances inside the context
    """
    with ConfigAutoWrap(module_blocklist, **wrapper_kwargs):
        yield

def wrap(
    module: nn.Module,
    cls: Callable = FullyShardedDataParallel,
    activation_checkpoint: bool = False,
    **wrap_overrides: Any
) -> nn.Module:
    """
    Annotate that a module should be wrapped. Annotated modules will only be
    wrapped if inside of an :func:`enable_wrap` context manager. An important
    use case is annotating large layers that should be sharded (in-place) during
    initialization, to avoid running out of system memory.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        cls (Callable): class wrapper to wrap the model with if in context
            (default: :class:`FullyShardedDataParallel`)
        activation_checkpoint (bool): use activation checkpointing wrapper
            (default: False)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context
    """
    if ConfigAutoWrap.in_autowrap_context:
        wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
        if activation_checkpoint:
            module = checkpoint_wrapper(module)
        return cls(module, **wrap_overrides)
    return module

def auto_wrap(
    module: nn.Module,
    min_num_params: int = int(1e8),
    cls: Callable = FullyShardedDataParallel,
    activation_checkpoint: bool = False,
    **kwargs: Any
) -> nn.Module:
    """
    Annotate that a module should be wrapped with *cls* and recursively wrap
    children modules that meet the given criteria. This is useful for wrapping
    large complex layers.

    .. warning:: It is not recommended to use :func:`auto_wrap` with
        :class:`FullyShardedDataParallel` on modules that have shared
        parameters, as the parameter sharing may be broken (i.e. end up not
        shared) if the shared parameters are not (auto-)wrapped under the same
        FSDP wrapper instance.

    Usage::

        with enable_wrap(**params):
            # Wraps children modules by default based on min_num_params
            self.l1 = auto_wrap(TransformerBlock(), min_num_params=1e8)

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        cls (Callable): class wrapper to wrap the model with if in context
            (default: :class:`FullyShardedDataParallel`)
        min_num_params (int, Optional): min number of parameters for a child
            Module to be wrapped
        activation_checkpoint (bool): use activation checkpointing wrapper
            (default: False)
    """
    if ConfigAutoWrap.in_autowrap_context:
        wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(
            module, cls=cls, activation_checkpoint=activation_checkpoint, min_num_params=min_num_params, **kwargs
        )
        return wrapped_module
    return module

class ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.
    """

    module_blocklist: List = []
    in_autowrap_context: bool = False
    kwargs: Dict[str, Any] = {}

    def __init__(self, module_blocklist: Optional[List] = None, **kwargs: Dict[str, Any]):
        if module_blocklist:
            self.module_blocklist += module_blocklist
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(kwargs: Any) -> None:
        if ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not support nested autowrap."
            )
        ConfigAutoWrap.in_autowrap_context = True
        ConfigAutoWrap.kwargs = kwargs
        ConfigAutoWrap.module_blocklist += FSDP_MODULE_BLOCKLIST

    @staticmethod
    def disable_autowrap_context() -> None:
        ConfigAutoWrap.in_autowrap_context = False
        ConfigAutoWrap.kwargs = {}
        ConfigAutoWrap.module_blocklist = []

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.disable_autowrap_context()

    @staticmethod
    def recursive_wrap(module: nn.Module, min_num_params: int, **kwargs: Any) -> Tuple[nn.Module, int]:
        """
        Automatically wrap child modules of *module* that meet the given
        criteria with :func:`auto_wrap`.

        Args:
            module (nn.Module): module to recursively wrap
            min_num_params (int): min number of parameters for a child Module to
                be wrapped
        """
        if isinstance(module, tuple(ConfigAutoWrap.module_blocklist)):
            # If the module has been blocklisted from wrapping, return it unchanged
            return module, 0
        num_params = sum([p.numel() for p in module.parameters()])

        if num_params >= min_num_params:
            total_wrapped_params = 0
            # Iterate through the children, recursively wrap if necessary
            for name, child in module.named_children():
                wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
                    module=child, min_num_params=min_num_params, **kwargs
                )
                setattr(module, name, wrapped_child)
                # Keep track of how many parameters have been wrapped
                total_wrapped_params += num_wrapped_params
            # Decide whether to wrap the current module as well,
            # i.e. when the leftover parameters exceed min_num_params
            remainder = num_params - total_wrapped_params
            if remainder >= min_num_params and not isinstance(module, tuple(FSDP_MODULE_EXCLUDE_WRAP)):
                return wrap(module, **kwargs), num_params
            else:
                return module, total_wrapped_params
        return module, 0
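To make the precedence logic in `wrap` above concrete, here is a small illustration (not part of the committed file) of how the dict unpacking lets per-call overrides win over the defaults set by `enable_wrap`:
```
# Context kwargs are applied first, then per-call overrides win, mirroring
# wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides} in wrap().
context_kwargs = {"flatten_parameters": False, "mixed_precision": True}
call_overrides = {"flatten_parameters": True}
merged = {**context_kwargs, **call_overrides}
assert merged == {"flatten_parameters": True, "mixed_precision": True}
```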
@@ -2,7 +2,7 @@
from .module import Module as Module
from .activation import CELU as CELU, ELU as ELU, GLU as GLU, GELU as GELU, Hardshrink as Hardshrink, \
-   Hardtanh as Hardtanh, LeakyReLU as LeakyReLU, LogSigmoid as LogSigmoid, LogSoftmax as LogSoftmax, PReLU as PReLU, \
+   Hardtanh as Hardtanh, LeakyReLU as LeakyReLU, LogSigmoid as LogSigmoid, LogSoftmax as LogSoftmax, MultiheadAttention as MultiheadAttention, PReLU as PReLU, \
    RReLU as RReLU, ReLU as ReLU, ReLU6 as ReLU6, SELU as SELU, Sigmoid as Sigmoid, Softmax as Softmax, \
    Softmax2d as Softmax2d, Softmin as Softmin, Softplus as Softplus, Softshrink as Softshrink, Softsign as Softsign, \
    Tanh as Tanh, Tanhshrink as Tanhshrink, Threshold as Threshold
@@ -208,3 +208,7 @@ class LogSoftmax(Module):
    def forward(self, input: Tensor) -> Tensor: ... # type: ignore
    def __call__(self, input: Tensor) -> Tensor: ... # type: ignore

class MultiheadAttention(Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float, bias: bool, add_bias_kv: bool, add_zero_attn: bool, kdim: Optional[int], vdim: Optional[int]) -> None: ...
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/wrap/test_wrap.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from unittest import mock
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup
class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        version = torch.__version__.split(".")[:2]
        major, minor = int(version[0]), int(version[1])
        if major < 1 or (major == 1 and minor < 6):
            raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters is False

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, nn.Linear)

    def test_wrap_override_defaults(self):
        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters

    def test_auto_wrap(self):
        """
        Test to ensure that with auto wrap, child modules are wrapped correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed the threshold, but the layers combined do.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.Linear)
        assert isinstance(model.module[2], FSDP)
        assert isinstance(model.module[2].module[0], nn.Linear)
        assert isinstance(model.module[2].module[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size exceeds
        min_num_params.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], nn.Linear)
        assert isinstance(model[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if their param size exceeds
        min_num_params.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(10, 10)])
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], FSDP)

    def test_auto_wrap_preset_blocklist(self):
        """
        Test to ensure blocklisted modules are not wrapped, and neither are their children.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model.module[0], FSDP)
        # Assert children of multihead attention are not wrapped
        assert isinstance(model.module[1], nn.MultiheadAttention)
        assert isinstance(model.module[1].out_proj, nn.Linear)

    def test_auto_wrap_preset_blocklist_custom(self):
        """
        Test to ensure custom blocklisted modules are not wrapped.
        """
        with enable_wrap(module_blocklist=[nn.Linear], process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
            model = auto_wrap(sequential, min_num_params=40)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.ModuleList)

    # todo: currently complains that address is in use, not sure why since I clear the proc group.
    # def test_auto_wrap_smoke(self):
    #     self._auto_wrap_smoke_test(enable_mixed_precision=False)

    def test_auto_wrap_smoke_autocast(self):
        """
        Ensure we can do a forward/backward through an auto-wrapped model.
        """
        self._auto_wrap_smoke_test(enable_mixed_precision=True)

    @mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"}, clear=True)
    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        from torch.cuda.amp import autocast

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        with enable_wrap(mixed_precision=enable_mixed_precision):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            model = auto_wrap(sequential, min_num_params=40)
        model.to(device)
        input = torch.rand((1, 5), dtype=torch.float).to(device)

        with autocast(enabled=enable_mixed_precision):
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        torch.distributed.destroy_process_group()
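Assuming a standard pytest setup (pytest itself is not mentioned in the diff), the non-distributed tests above can be run locally with something like:
```
python -m pytest tests/nn/wrap/test_wrap.py
```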