Unverified commit a05a79bc authored by Min Xu, committed by GitHub

[refactor] enhance wrap and auto_wrap (#467)



* [refactor] enhance wrap and auto_wrap

- Two things were done in this PR
  1. We don't need to import FSDP in wrap.py since the wrapper class
     type is stored in the context now.
  2. We can use an `auto_wrap_policy` function to customize the wrapping policy
     for auto_wrap, including a module-size threshold, a force-leaf blocklist, and
     an exclude list (see the usage sketch below)
- The auto_wrap function got simplified a bit as a minor side effect.
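For reference, a minimal sketch of the new API, adapted from the tests in this diff (the min_num_params=40 threshold is purely illustrative, and a default torch.distributed process group is assumed to be initialized already):

import functools
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap

# Illustrative policy: wrap any module whose unwrapped parameter count reaches 40.
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)

with enable_wrap(wrapper_cls=FSDP, auto_wrap_policy=my_auto_wrap_policy, flatten_parameters=False):
    # wrap() always wraps with the wrapper_cls stored in the context.
    lin = wrap(nn.Linear(5, 5))
    # auto_wrap() recursively wraps children according to the policy; here no
    # single child reaches 40 params, so the whole Sequential becomes one FSDP.
    seq = auto_wrap(nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))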

* Update fairscale/nn/wrap/auto_wrap.py
Co-authored-by: Sean Naren <sean@grid.ai>

* addressed comments

* addressed more comments
Co-authored-by: Sean Naren <sean@grid.ai>
parent 131a5356
......@@ -4,20 +4,22 @@
# LICENSE file in the root directory of this source tree.
from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
from .misc import FlattenParamsWrapper
from .misc import FlattenParamsWrapper, checkpoint_wrapper
from .moe import MOELayer, Top2Gate
from .pipe import Pipe, PipeRPCWrapper
from .wrap import auto_wrap, enable_wrap, wrap
from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
__all__ = [
"FlattenParamsWrapper",
"checkpoint_wrapper",
"FullyShardedDataParallel",
"LazyModule",
"ShardedDataParallel",
"Pipe",
"PipeRPCWrapper",
"ShardedDataParallel",
"MOELayer",
"Top2Gate",
"auto_wrap",
"default_auto_wrap_policy",
"enable_wrap",
"wrap",
]
from .auto_wrap import auto_wrap, enable_wrap, wrap
from .auto_wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
......@@ -4,23 +4,77 @@
# LICENSE file in the root directory of this source tree.
import contextlib
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, cast
import torch.nn as nn
from fairscale.nn.data_parallel.fully_sharded_data_parallel import FullyShardedDataParallel
from fairscale.nn.misc import checkpoint_wrapper
# Modules that don't wrap.
FSDP_MODULE_EXCLUDE_WRAP = {nn.ModuleList, nn.ModuleDict}
# Modules that we don't recurse down to their children.
FSDP_MODULE_BLOCKLIST = {nn.MultiheadAttention}
def default_auto_wrap_policy(
module: nn.Module,
recurse: bool,
unwrapped_params: int,
# These are customizable for this default policy function.
min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
"""Default policy function for :func:`auto_wrap`.
Return whether a module should be wrapped during :func:`auto_wrap`.
The first three parameters are passed in by :func:`auto_wrap`. If
you write a custom version of this policy function, it needs to
accept at least the first three parameters and is free to do
whatever else it wants beyond that.
Args:
module (nn.Module):
The module to be considered in this decision.
recurse (bool):
Indicates whether this call is deciding if we should recurse down
a subgraph of the module structure.
If False, this call is deciding whether we should wrap the given
module itself.
unwrapped_params (int):
The number of parameters yet to be wrapped in this module.
min_num_params (int):
Customizable policy input. It controls the size threshold: how
large a module must be before it is considered for wrapping.
force_leaf_modules (Set[Type[nn.Module]]): set of module types to
keep as leaves, i.e., their children will never be wrapped.
exclude_wrap_modules (Set[Type[nn.Module]]):
Customizable set of module types to be excluded in wrapping.
"""
force_leaf_modules = (
default_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore
if force_leaf_modules is None
else force_leaf_modules
)
exclude_wrap_modules = (
default_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore
if exclude_wrap_modules is None
else exclude_wrap_modules
)
is_large = unwrapped_params >= min_num_params
if recurse:
# We should recurse if the module is big enough and its type is not in force_leaf_modules.
return is_large and not isinstance(module, tuple(force_leaf_modules))
else:
# If we are not recursing, wrap the module unless its type is in the exclude list.
return is_large and not isinstance(module, tuple(exclude_wrap_modules))
# Attach the defaults to the default_auto_wrap_policy function so they are easy to import and override.
default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore
default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore
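As a sketch of the contract above (the 1,000-parameter threshold and the nn.Embedding exclusion are illustrative choices, not part of this PR), a custom policy only needs to accept the first three parameters and return a bool:

import torch.nn as nn

def my_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
    # Keep recursing into any module that still has enough unwrapped parameters.
    if recurse:
        return unwrapped_params >= 1000
    # Otherwise decide whether to wrap this module itself; here we skip embeddings.
    return unwrapped_params >= 1000 and not isinstance(module, nn.Embedding)

# It can then be passed to enable_wrap(auto_wrap_policy=my_policy, ...) or to
# auto_wrap(module, auto_wrap_policy=my_policy) directly.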
@contextlib.contextmanager
def enable_wrap(module_blocklist: Optional[List] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
"""
Context manager to wrap modules in FullyShardedDataParallel.
Context manager to wrap modules using a wrapper.
Useful for when you'd like to apply the same parameters to all child modules
that you wrap. A particularly important use case is wrapping large layers so
......@@ -34,26 +88,26 @@ def enable_wrap(module_blocklist: Optional[List] = None, **wrapper_kwargs: Any)
with enable_wrap(**params):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
# Wraps children modules by default based on min_num_params
self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8)
# Wraps children modules based on a different min_num_params
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=int(1e7))
self.l2 = auto_wrap(TransformerBlock(), auto_wrap_policy=my_auto_wrap_policy)
Args:
module_blocklist: List of additional Module Classes to not wrap when
using :func:`auto_wrap`. This is useful to exclude unsupported
modules when wrapping recursively.
**wrapper_kwargs: Configuration settings that will be passed to all ``wrap``
auto_wrap_policy (Callable, Optional):
Custom function that controls how :func:`auto_wrap` wraps modules. This is
useful to exclude unsupported modules or to wrap based on module size when
wrapping recursively. Note: modules annotated with :func:`wrap`
ignore this policy and will always be wrapped.
(default: :func:`default_auto_wrap_policy`)
**wrapper_kwargs:
Configuration settings that will be passed to all ``wrap``
instances inside the context
"""
with ConfigAutoWrap(module_blocklist, **wrapper_kwargs):
with ConfigAutoWrap(auto_wrap_policy, **wrapper_kwargs):
yield
def wrap(
module: nn.Module,
cls: Callable = FullyShardedDataParallel,
activation_checkpoint: bool = False,
**wrap_overrides: Any
) -> nn.Module:
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
"""
Annotate that a module should be wrapped. Annotated modules will only be
wrapped if inside of an :func:`enable_wrap` context manager. An important
......@@ -68,32 +122,22 @@ def wrap(
Args:
module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
cls (Callable): class wrapper to wrap the model with if in context
(default: :class:`FullyShardedDataParallel`)
activation_checkpoint (bool): use activation checkpointing wrapper
(default: False)
**wrap_overrides: configuration overrides that will take priority over
the values provided by the :func:`enable_wrap` context
"""
if ConfigAutoWrap.in_autowrap_context:
wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
if activation_checkpoint:
module = checkpoint_wrapper(module)
return cls(module, **wrap_overrides)
assert ConfigAutoWrap.wrapper_cls is not None
return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
return module
def auto_wrap(
module: nn.Module,
min_num_params: int = int(1e8),
cls: Callable = FullyShardedDataParallel,
activation_checkpoint: bool = False,
**kwargs: Any
) -> nn.Module:
def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any) -> nn.Module:
"""
Annotate that a module should be wrapped with *cls* and recursively wrap
children modules that meet the given criteria. This is useful for wrapping
large complex layers.
Annotate that a module should be wrapped with the *wrapper_cls* from the
:func:`enable_wrap` context (if the context exists) and recursively wrap
children modules that meet the criteria given by :func:`auto_wrap_policy`. This
is useful for wrapping large complex layers.
.. warning:: It is not recommended to use :func:`auto_wrap` with
:class:`FullyShardedDataParallel` on modules that have shared
......@@ -104,22 +148,18 @@ def auto_wrap(
Usage::
with enable_wrap(**params):
# Wraps children modules by default based on min_num_params
self.l1 = auto_wrap(TransformerBlock(), min_num_params=1e8)
# Wraps children modules.
self.l1 = auto_wrap(TransformerBlock())
Args:
module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
cls (Callable): class wrapper to wrap the model with if in context
(default: :class:`FullyShardedDataParallel`)
min_num_params (int, Optional): min number of parameters for a child
Module to be wrapped
activation_checkpoint (bool): use activation checkpointing wrapper
(default: False)
module (nn.Module):
module to wrap (if in :func:`enable_wrap` context)
auto_wrap_policy (Callable, Optional):
a function that determines whether a module should be wrapped.
(default: wrap if it has more than 100M parameters)
"""
if ConfigAutoWrap.in_autowrap_context:
wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(
module, cls=cls, activation_checkpoint=activation_checkpoint, min_num_params=min_num_params, **kwargs
)
wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(module, auto_wrap_policy=auto_wrap_policy, **kwargs)
return wrapped_module
return module
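A sketch of per-call overriding (TransformerBlock, OtherBlock, and pg are placeholder names, not from this PR): the policy passed to auto_wrap takes precedence over the one stored by enable_wrap for that call only.

with enable_wrap(wrapper_cls=FSDP, process_group=pg, flatten_parameters=False):
    # Uses the context's policy (default_auto_wrap_policy, > 100M params).
    big = auto_wrap(TransformerBlock())
    # Overrides the context's policy just for this call.
    small_policy = functools.partial(default_auto_wrap_policy, min_num_params=int(1e6))
    small = auto_wrap(OtherBlock(), auto_wrap_policy=small_policy)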
......@@ -130,60 +170,76 @@ class ConfigAutoWrap:
See :func:`enable_wrap` for more information.
"""
module_blocklist: List = []
in_autowrap_context: bool = False
kwargs: Dict[str, Any] = {}
in_autowrap_context: bool = False # Context flag
wrapper_cls: Optional[Callable] = None # The wrapper class
kwargs: Dict[str, Any] = {} # Wrapper's args
auto_wrap_policy: Optional[Callable] = None # Used only in auto_wrap
def __init__(self, module_blocklist: Optional[List] = None, **kwargs: Dict[str, Any]):
if module_blocklist:
self.module_blocklist += module_blocklist
def __init__(self, auto_wrap_policy: Optional[Callable] = None, **kwargs: Dict[str, Any]):
self.auto_wrap_policy = auto_wrap_policy
self.kwargs = kwargs
@staticmethod
def enable_autowrap_context(kwargs: Any) -> None:
def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
if ConfigAutoWrap.in_autowrap_context:
raise NotImplementedError(
"You are already within an autowrap context and we currently do not supported nested autowrap."
)
ConfigAutoWrap.in_autowrap_context = True
# Get and save the wrapper cls for the context.
assert "wrapper_cls" in kwargs.keys()
ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
del kwargs["wrapper_cls"]
# Save the rest.
ConfigAutoWrap.auto_wrap_policy = default_auto_wrap_policy if auto_wrap_policy is None else auto_wrap_policy
ConfigAutoWrap.kwargs = kwargs
ConfigAutoWrap.module_blocklist += FSDP_MODULE_BLOCKLIST
@staticmethod
def disable_autowrap_context() -> None:
ConfigAutoWrap.in_autowrap_context = False
ConfigAutoWrap.wrapper_cls = None
ConfigAutoWrap.kwargs = {}
ConfigAutoWrap.module_blocklist = []
ConfigAutoWrap.auto_wrap_policy = None
def __enter__(self) -> None:
self.enable_autowrap_context(self.kwargs)
self.enable_autowrap_context(self.auto_wrap_policy, self.kwargs)
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.disable_autowrap_context()
@staticmethod
def recursive_wrap(module: nn.Module, min_num_params: int, **kwargs: Any) -> Tuple[nn.Module, int]:
def recursive_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable], **kwargs: Any) -> Tuple[nn.Module, int]:
"""
Automatically wrap child modules of *module* that meet the given
criteria with :func:`auto_wrap`.
Args:
module (nn.Module): module to recursively wrap
min_num_params (int): min number of parameters for a child Module to
be wrapped
module (nn.Module):
module to recursively wrap
auto_wrap_policy (Callable, Optional):
if given, overrides the ``auto_wrap_policy`` set by the :func:`enable_wrap` context.
Returns:
(nn.Module, int):
Wrapped module and the number of parameters wrapped recursively.
"""
if isinstance(module, tuple(ConfigAutoWrap.module_blocklist)):
# If the module has been blocklisted from wrapping, we return
return module, 0
if auto_wrap_policy is None:
auto_wrap_policy = ConfigAutoWrap.auto_wrap_policy
# Make sure no child is already wrapped.
for _, child in module.named_modules():
assert not isinstance(child, cast(type, ConfigAutoWrap.wrapper_cls))
# We count all params, assuming none of them is already wrapped.
num_params = sum([p.numel() for p in module.parameters()])
if num_params >= min_num_params:
assert auto_wrap_policy is not None
if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):
total_wrapped_params = 0
# Iterate through the children, recursively wrap if necessary
for name, child in module.named_children():
wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
module=child, min_num_params=min_num_params, **kwargs
module=child, auto_wrap_policy=auto_wrap_policy, **kwargs
)
setattr(module, name, wrapped_child)
# Keep track of how many parameters have been wrapped
......@@ -191,7 +247,8 @@ class ConfigAutoWrap:
# decide if we need to wrap the current module,
# since the left over parameters exceed the number of params to wrap
remainder = num_params - total_wrapped_params
if remainder >= min_num_params and not isinstance(module, tuple(FSDP_MODULE_EXCLUDE_WRAP)):
if auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder):
# Leaf node or final wrapping of the remainder both happen here.
return wrap(module, **kwargs), num_params
else:
return module, total_wrapped_params
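To make the two policy calls concrete, here is a worked trace of recursive_wrap on the nested Sequential used in the tests further down, with min_num_params=40 (each nn.Linear(5, 5) holds 25 weights + 5 biases = 30 parameters):

seq = nn.Sequential(
    nn.Linear(5, 5),                                  # 30 params
    nn.Linear(5, 5),                                  # 30 params
    nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),  # 60 params
)
# Top level: policy(recurse=True, unwrapped_params=120) -> True, so recurse.
#   Each outer Linear: policy(recurse=True, 30) -> False, left unwrapped.
#   Inner Sequential: policy(recurse=True, 60) -> True; its Linears stay
#   unwrapped, then policy(recurse=False, remainder=60) -> True -> FSDP-wrapped.
# Top level again: remainder = 120 - 60 = 60; policy(recurse=False, 60) -> True,
# so the outer Sequential is wrapped in FSDP as well.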
......
......@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import os
import unittest
from unittest import mock
......@@ -12,7 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, enable_wrap, wrap
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup
......@@ -25,7 +26,7 @@ class TestAutoWrap(unittest.TestCase):
self.process_group = DummyProcessGroup(rank=0, size=1)
def test_wrap(self):
with enable_wrap(flatten_parameters=False, process_group=self.process_group):
with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
layer = wrap(nn.Linear(5, 5))
assert isinstance(layer, FSDP)
assert layer.flatten_parameters is False
......@@ -35,7 +36,7 @@ class TestAutoWrap(unittest.TestCase):
assert isinstance(layer, nn.Linear)
def test_wrap_override_defaults(self):
with enable_wrap(flatten_parameters=False, process_group=self.process_group):
with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
assert isinstance(layer, FSDP)
assert layer.flatten_parameters
......@@ -45,11 +46,12 @@ class TestAutoWrap(unittest.TestCase):
Test to ensure that, with auto_wrap, we wrap child modules correctly based on min_num_params.
``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
"""
with enable_wrap(process_group=self.process_group, flatten_parameters=False):
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential(
nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
)
model = auto_wrap(sequential, min_num_params=40)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, FSDP)
assert isinstance(model.module[0], nn.Linear)
assert isinstance(model.module[1], nn.Linear)
......@@ -62,9 +64,10 @@ class TestAutoWrap(unittest.TestCase):
Test to ensure excluded modules are not wrapped, regardless of whether the total param size is greater than
min_num_params.
"""
with enable_wrap(process_group=self.process_group, flatten_parameters=False):
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
model = auto_wrap(sequential, min_num_params=40)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, nn.ModuleList)
assert isinstance(model[0], nn.Linear)
assert isinstance(model[1], nn.Linear)
......@@ -74,31 +77,43 @@ class TestAutoWrap(unittest.TestCase):
Test to ensure excluded modules are not wrapped, but children are if param size is greater than
min_num_params
"""
with enable_wrap(process_group=self.process_group, flatten_parameters=False):
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.ModuleList([nn.Linear(10, 10)])
model = auto_wrap(sequential, min_num_params=40)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, nn.ModuleList)
assert isinstance(model[0], FSDP)
def test_auto_wrap_preset_blocklist(self):
def test_auto_wrap_preset_force_leaf(self):
"""
Test to ensure blocklisted modules are not wrapped, and children are not wrapped.
Test to ensure force-leaf modules are not wrapped, and children are not wrapped.
"""
with enable_wrap(process_group=self.process_group, flatten_parameters=False):
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
model = auto_wrap(sequential, min_num_params=40)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model.module[0], FSDP)
# Assert children of multihead attention are not wrapped
assert isinstance(model.module[1], nn.MultiheadAttention)
assert isinstance(model.module[1].out_proj, nn.Linear)
def test_auto_wrap_preset_blocklist_custom(self):
def test_auto_wrap_preset_force_leaf_custom(self):
"""
Test to ensure blocklisted modules are not wrapped.
Test to ensure force-leaf modules are not wrapped.
"""
with enable_wrap(module_blocklist=[nn.Linear], process_group=self.process_group, flatten_parameters=False):
my_auto_wrap_policy = functools.partial(
default_auto_wrap_policy,
min_num_params=40,
force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
)
with enable_wrap(
auto_wrap_policy=my_auto_wrap_policy,
wrapper_cls=FSDP,
process_group=self.process_group,
flatten_parameters=False,
):
sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
model = auto_wrap(sequential, min_num_params=40)
model = auto_wrap(sequential)
# Model was wrapped in FSDP as no inner modules were wrapped.
assert isinstance(model, FSDP)
assert isinstance(model.module[0], nn.Linear)
......@@ -123,11 +138,12 @@ class TestAutoWrap(unittest.TestCase):
torch.cuda.set_device(0)
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
with enable_wrap(mixed_precision=enable_mixed_precision):
with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
sequential = nn.Sequential(
nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
)
model = auto_wrap(sequential, min_num_params=40)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
model.to(device)
input = torch.rand((1, 5), dtype=torch.float).to(device)
......