Unverified commit f3359550, authored by Sean Naren, committed by GitHub

[feat] Add context manager to FSDP for easier child module wrapping (#446)

This adds a context manager that assists in wrapping child modules with a shared set of FSDP defaults.
Usage:
```
from fairscale.nn import enable_wrap, wrap

with enable_wrap(**handful_of_important_params):
    layer_1 = wrap(torch.nn.Linear(5, 5))
    layer_2 = wrap(torch.nn.Linear(5, 5), flatten_parameters=True)  # Override the context defaults if you'd like

# Without the context manager, wrap is a no-op and simply returns the Linear layer
layer_1 = wrap(torch.nn.Linear(5, 5))
```
Outside of the `enable_wrap` context, `wrap` is a no-op, so layers can be annotated once without having to copy parameter changes into every call site.
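The commit also exports `auto_wrap`, which recursively wraps children of a module once their parameter count crosses a threshold. Below is a minimal sketch of how it composes with the context manager, mirroring the new unit tests; the `DummyProcessGroup` helper and the `min_num_params=40` threshold are illustrative, not required usage:
```
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, enable_wrap
from fairscale.utils.testing import DummyProcessGroup

# Each nn.Linear(5, 5) holds 30 parameters (25 weights + 5 biases), so with
# min_num_params=40 a single Linear stays unwrapped while the nested
# Sequential (60 parameters) becomes its own FSDP instance.
pg = DummyProcessGroup(rank=0, size=1)  # single-process group used by the unit tests
with enable_wrap(process_group=pg, flatten_parameters=False):
    sequential = nn.Sequential(
        nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
    )
    model = auto_wrap(sequential, min_num_params=40)

assert isinstance(model, FSDP)            # remainder of the root still exceeds the threshold
assert isinstance(model.module[2], FSDP)  # the nested Sequential was wrapped on its own
```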
parent 5eb6b8c7
@@ -7,6 +7,7 @@ from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
 from .misc import FlattenParamsWrapper
 from .moe import MOELayer, Top2Gate
 from .pipe import Pipe, PipeRPCWrapper
+from .wrap import auto_wrap, enable_wrap, wrap

 __all__ = [
     "FlattenParamsWrapper",
@@ -16,4 +17,7 @@ __all__ = [
     "PipeRPCWrapper",
     "ShardedDataParallel",
     "Top2Gate",
+    "auto_wrap",
+    "enable_wrap",
+    "wrap",
 ]
@@ -793,7 +793,8 @@ class FullyShardedDataParallel(nn.Module):
         self._prep_grads_for_backward()

         def _register_hook(t: torch.Tensor) -> torch.Tensor:
-            t.register_hook(_pre_backward_hook)
+            if t.requires_grad:
+                t.register_hook(_pre_backward_hook)
             return t

         # Attach hooks to Tensor outputs.
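The new `requires_grad` guard matters because `Tensor.register_hook` raises a `RuntimeError` when the tensor does not require grad, so registering the pre-backward hook on a non-differentiable output would fail. A minimal sketch of the failure mode the check avoids (the identity hook is purely illustrative):
```
import torch

t = torch.zeros(3)  # requires_grad defaults to False
try:
    t.register_hook(lambda grad: grad)  # identity hook, for illustration only
except RuntimeError as err:
    # PyTorch refuses hooks on tensors that do not require gradients
    print(f"register_hook failed: {err}")
```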
from .auto_wrap import auto_wrap, enable_wrap, wrap
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import torch.nn as nn
from fairscale.nn.data_parallel.fully_sharded_data_parallel import FullyShardedDataParallel
from fairscale.nn.misc import checkpoint_wrapper
# Modules that don't wrap.
FSDP_MODULE_EXCLUDE_WRAP = {nn.ModuleList, nn.ModuleDict}
# Modules that we don't recurse down to their children.
FSDP_MODULE_BLOCKLIST = {nn.MultiheadAttention}
@contextlib.contextmanager
def enable_wrap(module_blocklist: Optional[List] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
    """
    Context manager to wrap modules in FullyShardedDataParallel.

    Useful for when you'd like to apply the same parameters to all child modules
    that you wrap. A particularly important use case is wrapping large layers so
    that they get sharded (in-place) during initialization, to avoid running out of
    system memory. Large layers can indicate that they should be sharded via
    the ``wrap`` annotation and this context manager can provide the
    exact configuration for these nested instances.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            # Wraps children modules by default based on min_num_params
            self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8)

    Args:
        module_blocklist: List of additional Module Classes to not wrap when
            using :func:`auto_wrap`. This is useful to exclude unsupported
            modules when wrapping recursively.
        **wrapper_kwargs: Configuration settings that will be passed to all ``wrap``
            instances inside the context
    """
    with ConfigAutoWrap(module_blocklist, **wrapper_kwargs):
        yield

def wrap(
    module: nn.Module,
    cls: Callable = FullyShardedDataParallel,
    activation_checkpoint: bool = False,
    **wrap_overrides: Any
) -> nn.Module:
    """
    Annotate that a module should be wrapped. Annotated modules will only be
    wrapped if inside of an :func:`enable_wrap` context manager. An important
    use case is annotating large layers that should be sharded (in-place) during
    initialization, to avoid running out of system memory.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        cls (Callable): class wrapper to wrap the model with if in context
            (default: :class:`FullyShardedDataParallel`)
        activation_checkpoint (bool): use activation checkpointing wrapper
            (default: False)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context
    """
    if ConfigAutoWrap.in_autowrap_context:
        wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
        if activation_checkpoint:
            module = checkpoint_wrapper(module)
        return cls(module, **wrap_overrides)
    return module

def auto_wrap(
    module: nn.Module,
    min_num_params: int = int(1e8),
    cls: Callable = FullyShardedDataParallel,
    activation_checkpoint: bool = False,
    **kwargs: Any
) -> nn.Module:
    """
    Annotate that a module should be wrapped with *cls* and recursively wrap
    children modules that meet the given criteria. This is useful for wrapping
    large complex layers.

    .. warning:: It is not recommended to use :func:`auto_wrap` with
        :class:`FullyShardedDataParallel` on modules that have shared
        parameters, as the parameter sharing may be broken (i.e. end up not
        shared) if the shared parameters are not (auto-)wrapped under the same
        FSDP wrapper instance.

    Usage::

        with enable_wrap(**params):
            # Wraps children modules by default based on min_num_params
            self.l1 = auto_wrap(TransformerBlock(), min_num_params=1e8)

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        cls (Callable): class wrapper to wrap the model with if in context
            (default: :class:`FullyShardedDataParallel`)
        min_num_params (int, Optional): min number of parameters for a child
            Module to be wrapped
        activation_checkpoint (bool): use activation checkpointing wrapper
            (default: False)
    """
    if ConfigAutoWrap.in_autowrap_context:
        wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(
            module, cls=cls, activation_checkpoint=activation_checkpoint, min_num_params=min_num_params, **kwargs
        )
        return wrapped_module
    return module

class ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.
    """

    # Module classes that :func:`auto_wrap` skips entirely (children included)
    module_blocklist: List = []
    # True while inside an :func:`enable_wrap` context
    in_autowrap_context: bool = False
    # Default kwargs forwarded to every ``wrap`` call inside the context
    kwargs: Dict[str, Any] = {}

    def __init__(self, module_blocklist: Optional[List] = None, **kwargs: Dict[str, Any]):
        if module_blocklist:
            self.module_blocklist += module_blocklist
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(kwargs: Any) -> None:
        if ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not support nested autowrap."
            )
        ConfigAutoWrap.in_autowrap_context = True
        ConfigAutoWrap.kwargs = kwargs
        ConfigAutoWrap.module_blocklist += FSDP_MODULE_BLOCKLIST

    @staticmethod
    def disable_autowrap_context() -> None:
        ConfigAutoWrap.in_autowrap_context = False
        ConfigAutoWrap.kwargs = {}
        ConfigAutoWrap.module_blocklist = []

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.disable_autowrap_context()
    @staticmethod
    def recursive_wrap(module: nn.Module, min_num_params: int, **kwargs: Any) -> Tuple[nn.Module, int]:
        """
        Automatically wrap child modules of *module* that meet the given
        criteria with :func:`auto_wrap`.

        Args:
            module (nn.Module): module to recursively wrap
            min_num_params (int): min number of parameters for a child Module to
                be wrapped
        """
        if isinstance(module, tuple(ConfigAutoWrap.module_blocklist)):
            # If the module has been blocklisted from wrapping, we return
            return module, 0

        num_params = sum([p.numel() for p in module.parameters()])

        if num_params >= min_num_params:
            total_wrapped_params = 0
            # Iterate through the children, recursively wrap if necessary
            for name, child in module.named_children():
                wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
                    module=child, min_num_params=min_num_params, **kwargs
                )
                setattr(module, name, wrapped_child)
                # Keep track of how many parameters have been wrapped
                total_wrapped_params += num_wrapped_params
            # decide if we need to wrap the current module,
            # since the left over parameters exceed the number of params to wrap
            remainder = num_params - total_wrapped_params
            if remainder >= min_num_params and not isinstance(module, tuple(FSDP_MODULE_EXCLUDE_WRAP)):
                return wrap(module, **kwargs), num_params
            else:
                return module, total_wrapped_params
        return module, 0
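Because per-call `wrap_overrides` are merged on top of `ConfigAutoWrap.kwargs`, arguments passed directly to `wrap` take priority over the context-level defaults. A minimal sketch of that precedence, reusing the `DummyProcessGroup` helper from the new unit tests (the `flatten_parameters` values are illustrative):
```
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup

pg = DummyProcessGroup(rank=0, size=1)
# The context default is flatten_parameters=False ...
with enable_wrap(process_group=pg, flatten_parameters=False):
    # ... but the per-call override wins for this layer.
    layer = wrap(nn.Linear(5, 5), flatten_parameters=True)

assert isinstance(layer, FSDP)
assert layer.flatten_parameters
```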
@@ -2,7 +2,7 @@
 from .module import Module as Module
 from .activation import CELU as CELU, ELU as ELU, GLU as GLU, GELU as GELU, Hardshrink as Hardshrink, \
-    Hardtanh as Hardtanh, LeakyReLU as LeakyReLU, LogSigmoid as LogSigmoid, LogSoftmax as LogSoftmax, PReLU as PReLU, \
+    Hardtanh as Hardtanh, LeakyReLU as LeakyReLU, LogSigmoid as LogSigmoid, LogSoftmax as LogSoftmax, MultiheadAttention as MultiheadAttention, PReLU as PReLU, \
     RReLU as RReLU, ReLU as ReLU, ReLU6 as ReLU6, SELU as SELU, Sigmoid as Sigmoid, Softmax as Softmax, \
     Softmax2d as Softmax2d, Softmin as Softmin, Softplus as Softplus, Softshrink as Softshrink, Softsign as Softsign, \
     Tanh as Tanh, Tanhshrink as Tanhshrink, Threshold as Threshold
@@ -208,3 +208,7 @@ class LogSoftmax(Module):
     def forward(self, input: Tensor) -> Tensor: ... # type: ignore
     def __call__(self, input: Tensor) -> Tensor: ... # type: ignore
+
+
+class MultiheadAttention(Module):
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float, bias: bool, add_bias_kv: bool, add_zero_attn: bool, kdim: Optional[int], vdim: Optional[int]) -> None: ...
 tests/nn/misc/test_flatten_params_wrapper.py
 tests/nn/misc/test_checkpoint_activations.py
 tests/nn/data_parallel/test_fsdp.py
+tests/nn/wrap/test_wrap.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from unittest import mock
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup
class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        version = torch.__version__.split(".")[:2]
        major, minor = int(version[0]), int(version[1])
        if major < 1 or (major == 1 and minor < 6):
            raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters is False

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        assert isinstance(layer, nn.Linear)

    def test_wrap_override_defaults(self):
        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
        assert isinstance(layer, FSDP)
        assert layer.flatten_parameters
    def test_auto_wrap(self):
        """
        Test to ensure that auto_wrap wraps child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` (30 parameters) does not exceed the threshold of 40, but
        two of them combined (60 parameters) do.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.Linear)
        assert isinstance(model.module[2], FSDP)
        assert isinstance(model.module[2].module[0], nn.Linear)
        assert isinstance(model.module[2].module[1], nn.Linear)
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param
        size is greater than min_num_params.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], nn.Linear)
        assert isinstance(model[1], nn.Linear)

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if param size
        is greater than min_num_params.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.ModuleList([nn.Linear(10, 10)])
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model, nn.ModuleList)
        assert isinstance(model[0], FSDP)
    def test_auto_wrap_preset_blocklist(self):
        """
        Test to ensure blocklisted modules are not wrapped, and children are not wrapped.
        """
        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
            model = auto_wrap(sequential, min_num_params=40)
        assert isinstance(model.module[0], FSDP)
        # Assert children of multihead attention are not wrapped
        assert isinstance(model.module[1], nn.MultiheadAttention)
        assert isinstance(model.module[1].out_proj, nn.Linear)

    def test_auto_wrap_preset_blocklist_custom(self):
        """
        Test to ensure blocklisted modules are not wrapped.
        """
        with enable_wrap(module_blocklist=[nn.Linear], process_group=self.process_group, flatten_parameters=False):
            sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
            model = auto_wrap(sequential, min_num_params=40)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        assert isinstance(model, FSDP)
        assert isinstance(model.module[0], nn.Linear)
        assert isinstance(model.module[1], nn.ModuleList)
    # todo: currently complains that address is in use, not sure why since I clear the proc group.
    # def test_auto_wrap_smoke(self):
    #     self._auto_wrap_smoke_test(enable_mixed_precision=False)

    def test_auto_wrap_smoke_autocast(self):
        """
        Ensure we can do a forward/backward through an auto-wrapped model.
        """
        self._auto_wrap_smoke_test(enable_mixed_precision=True)

    @mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"}, clear=True)
    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def _auto_wrap_smoke_test(self, enable_mixed_precision):
        from torch.cuda.amp import autocast

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        with enable_wrap(mixed_precision=enable_mixed_precision):
            sequential = nn.Sequential(
                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
            )
            model = auto_wrap(sequential, min_num_params=40)
        model.to(device)
        input = torch.rand((1, 5), dtype=torch.float).to(device)

        with autocast(enabled=enable_mixed_precision):
            output = model(input)
            loss = F.mse_loss(input, output)
        loss.backward()
        torch.distributed.destroy_process_group()