Unverified Commit a05a79bc authored by Min Xu, committed by GitHub

[refactor] enhance wrap and auto_wrap (#467)



* [refactor] enhance wrap and auto_wrap

- Two things were done in this PR:
  1. We no longer need to import FSDP in wrap.py, since the wrapper class
     type is now stored in the context.
  2. We can use an `auto_wrap_policy` function to customize the wrapping policy
     for auto_wrap, including module size, a force-leaf list, and an exclude
     list (see the usage sketch after this list).
- The auto_wrap function got simplified a bit as a minor side effect.
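
For context, here is a minimal usage sketch of the API this PR introduces, based on the code in the diff below. The module shapes, the 1e5 threshold, and the omission of explicit process-group setup are illustrative only, not part of this change.

import functools

import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap

# Customize the default policy: wrap any module with at least 1e5 unwrapped parameters.
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=1e5)

# enable_wrap now takes the wrapper class and (optionally) the policy; FSDP itself
# still needs the usual distributed setup (or an explicit process_group=...).
with enable_wrap(wrapper_cls=FSDP, auto_wrap_policy=my_auto_wrap_policy):
    # Children large enough under the policy are wrapped in FSDP.
    model = auto_wrap(nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512)))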

* Update fairscale/nn/wrap/auto_wrap.py
Co-authored-by: Sean Naren <sean@grid.ai>

* addressed comments

* addressed more comments
Co-authored-by: Sean Naren <sean@grid.ai>
parent 131a5356
@@ -4,20 +4,22 @@
 # LICENSE file in the root directory of this source tree.
 from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
-from .misc import FlattenParamsWrapper
+from .misc import FlattenParamsWrapper, checkpoint_wrapper
 from .moe import MOELayer, Top2Gate
 from .pipe import Pipe, PipeRPCWrapper
-from .wrap import auto_wrap, enable_wrap, wrap
+from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap

 __all__ = [
     "FlattenParamsWrapper",
+    "checkpoint_wrapper",
     "FullyShardedDataParallel",
-    "LazyModule",
+    "ShardedDataParallel",
     "Pipe",
     "PipeRPCWrapper",
-    "ShardedDataParallel",
+    "MOELayer",
     "Top2Gate",
     "auto_wrap",
+    "default_auto_wrap_policy",
     "enable_wrap",
     "wrap",
 ]
-from .auto_wrap import auto_wrap, enable_wrap, wrap
+from .auto_wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
@@ -4,23 +4,77 @@
 # LICENSE file in the root directory of this source tree.

 import contextlib
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, cast

 import torch.nn as nn

-from fairscale.nn.data_parallel.fully_sharded_data_parallel import FullyShardedDataParallel
-from fairscale.nn.misc import checkpoint_wrapper

-# Modules that don't wrap.
-FSDP_MODULE_EXCLUDE_WRAP = {nn.ModuleList, nn.ModuleDict}
-# Modules that we don't recurse down to their children.
-FSDP_MODULE_BLOCKLIST = {nn.MultiheadAttention}
+def default_auto_wrap_policy(
+    module: nn.Module,
+    recurse: bool,
+    unwrapped_params: int,
+    # These are customizable for this default policy function.
+    min_num_params: int = int(1e8),
+    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
+    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
+) -> bool:
+    """Default policy function for :func:`auto_wrap`.
+
+    Returns whether a module should be wrapped during :func:`auto_wrap`.
+
+    The first three parameters are supplied by :func:`auto_wrap`. If you
+    write a custom version of this policy function, your version needs to
+    accept at least these first three parameters and is free to do whatever
+    else it wants inside the function.
+
+    Args:
+        module (nn.Module):
+            The module to be considered in this decision.
+        recurse (bool):
+            Indicates whether this call is made to decide if we should
+            recurse down a subgraph of the module structure.
+            If False, this function is being called to decide whether the
+            given module should be wrapped.
+        unwrapped_params (int):
+            The number of parameters yet to be wrapped in this module.
+        min_num_params (int):
+            Customizable policy input. It controls the size threshold:
+            how big a module must be before it is considered for wrapping.
+        force_leaf_modules (Set[Type[nn.Module]]):
+            Customizable set of module types to keep as leaves, i.e. their
+            children will never be wrapped.
+        exclude_wrap_modules (Set[Type[nn.Module]]):
+            Customizable set of module types to be excluded from wrapping.
+    """
+    force_leaf_modules = (
+        default_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore
+        if force_leaf_modules is None
+        else force_leaf_modules
+    )
+    exclude_wrap_modules = (
+        default_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore
+        if exclude_wrap_modules is None
+        else exclude_wrap_modules
+    )
+    is_large = unwrapped_params >= min_num_params
+    if recurse:
+        # We should recurse if the module is big enough, but never into force_leaf_modules.
+        return is_large and not isinstance(module, tuple(force_leaf_modules))
+    else:
+        # If we are not recursing, wrap the module unless its type is in the exclude list.
+        return is_large and not isinstance(module, tuple(exclude_wrap_modules))
+
+
+# Attach these defaults to the default_auto_wrap_policy function so they are easy to import.
+default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore
+default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore

 @contextlib.contextmanager
-def enable_wrap(module_blocklist: Optional[List] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
+def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
     """
-    Context manager to wrap modules in FullyShardedDataParallel.
+    Context manager to wrap modules using a wrapper.

     Useful for when you'd like to apply the same parameters to all child modules
     that you wrap. A particularly important use case is wrapping large layers so
@@ -34,26 +88,26 @@ def enable_wrap(module_blocklist: Optional[List] = None, **wrapper_kwargs: Any)
         with enable_wrap(**params):
             # Wraps layer in FSDP by default if within context
             self.l1 = wrap(torch.nn.Linear(5, 5))
-            # Wraps children modules by default based on min_num_params
-            self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8)
+            # Wraps children modules based on a different min_num_params
+            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=1e7)
+            self.l2 = auto_wrap(TransformerBlock(), auto_wrap_policy=my_auto_wrap_policy)

     Args:
-        module_blocklist: List of additional Module Classes to not wrap when
-            using :func:`auto_wrap`. This is useful to exclude unsupported
-            modules when wrapping recursively.
-        **wrapper_kwargs: Configuration settings that will be passed to all ``wrap``
+        auto_wrap_policy (Callable, Optional):
+            A custom function to control how :func:`auto_wrap` wraps. This is
+            useful to exclude unsupported modules or to wrap based on sizes when
+            wrapping recursively. Note: modules annotated with :func:`wrap`
+            ignore this policy and will always be wrapped.
+            (default: :func:`default_auto_wrap_policy`)
+        **wrapper_kwargs:
+            Configuration settings that will be passed to all ``wrap``
             instances inside the context
     """
-    with ConfigAutoWrap(module_blocklist, **wrapper_kwargs):
+    with ConfigAutoWrap(auto_wrap_policy, **wrapper_kwargs):
         yield

-def wrap(
-    module: nn.Module,
-    cls: Callable = FullyShardedDataParallel,
-    activation_checkpoint: bool = False,
-    **wrap_overrides: Any
-) -> nn.Module:
+def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
     """
     Annotate that a module should be wrapped. Annotated modules will only be
     wrapped if inside of an :func:`enable_wrap` context manager. An important
@@ -68,32 +122,22 @@ def wrap(
     Args:
         module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
-        cls (Callable): class wrapper to wrap the model with if in context
-            (default: :class:`FullyShardedDataParallel`)
-        activation_checkpoint (bool): use activation checkpointing wrapper
-            (default: False)
         **wrap_overrides: configuration overrides that will take priority over
             the values provided by the :func:`enable_wrap` context
     """
     if ConfigAutoWrap.in_autowrap_context:
         wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
-        if activation_checkpoint:
-            module = checkpoint_wrapper(module)
-        return cls(module, **wrap_overrides)
+        assert ConfigAutoWrap.wrapper_cls is not None
+        return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
     return module

-def auto_wrap(
-    module: nn.Module,
-    min_num_params: int = int(1e8),
-    cls: Callable = FullyShardedDataParallel,
-    activation_checkpoint: bool = False,
-    **kwargs: Any
-) -> nn.Module:
+def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **kwargs: Any) -> nn.Module:
     """
-    Annotate that a module should be wrapped with *cls* and recursively wrap
-    children modules that meet the given criteria. This is useful for wrapping
-    large complex layers.
+    Annotate that a module should be wrapped with the *wrapper_cls* from the
+    :func:`enable_wrap` context (if the context exists) and recursively wrap
+    children modules that meet the criteria given by *auto_wrap_policy*. This
+    is useful for wrapping large complex layers.

     .. warning:: It is not recommended to use :func:`auto_wrap` with
         :class:`FullyShardedDataParallel` on modules that have shared
@@ -104,22 +148,18 @@ def auto_wrap(
     Usage::

         with enable_wrap(**params):
-            # Wraps children modules by default based on min_num_params
-            self.l1 = auto_wrap(TransformerBlock(), min_num_params=1e8)
+            # Wraps children modules.
+            self.l1 = auto_wrap(TransformerBlock())

     Args:
-        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
-        cls (Callable): class wrapper to wrap the model with if in context
-            (default: :class:`FullyShardedDataParallel`)
-        min_num_params (int, Optional): min number of parameters for a child
-            Module to be wrapped (default: wrap if > 100M parameters)
-        activation_checkpoint (bool): use activation checkpointing wrapper
-            (default: False)
+        module (nn.Module):
+            module to wrap (if in :func:`enable_wrap` context)
+        auto_wrap_policy (Callable):
+            a function that decides whether a module should be wrapped.
     """
     if ConfigAutoWrap.in_autowrap_context:
-        wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(
-            module, cls=cls, activation_checkpoint=activation_checkpoint, min_num_params=min_num_params, **kwargs
-        )
+        wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(module, auto_wrap_policy=auto_wrap_policy, **kwargs)
         return wrapped_module
     return module
@@ -130,60 +170,76 @@ class ConfigAutoWrap:
     See :func:`enable_wrap` for more information.
     """

-    module_blocklist: List = []
-    in_autowrap_context: bool = False
-    kwargs: Dict[str, Any] = {}
+    in_autowrap_context: bool = False  # Context flag
+    wrapper_cls: Optional[Callable] = None  # The wrapper class
+    kwargs: Dict[str, Any] = {}  # Wrapper's args
+    auto_wrap_policy: Optional[Callable] = None  # Used only in auto_wrap

-    def __init__(self, module_blocklist: Optional[List] = None, **kwargs: Dict[str, Any]):
-        if module_blocklist:
-            self.module_blocklist += module_blocklist
+    def __init__(self, auto_wrap_policy: Optional[Callable] = None, **kwargs: Dict[str, Any]):
+        self.auto_wrap_policy = auto_wrap_policy
         self.kwargs = kwargs

     @staticmethod
-    def enable_autowrap_context(kwargs: Any) -> None:
+    def enable_autowrap_context(auto_wrap_policy: Optional[Callable], kwargs: Any) -> None:
         if ConfigAutoWrap.in_autowrap_context:
             raise NotImplementedError(
                 "You are already within an autowrap context and we currently do not support nested autowrap."
             )
         ConfigAutoWrap.in_autowrap_context = True
+        # Get and save the wrapper cls for the context.
+        assert "wrapper_cls" in kwargs.keys()
+        ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
+        del kwargs["wrapper_cls"]
+        # Save the rest.
+        ConfigAutoWrap.auto_wrap_policy = default_auto_wrap_policy if auto_wrap_policy is None else auto_wrap_policy
         ConfigAutoWrap.kwargs = kwargs
-        ConfigAutoWrap.module_blocklist += FSDP_MODULE_BLOCKLIST

     @staticmethod
     def disable_autowrap_context() -> None:
         ConfigAutoWrap.in_autowrap_context = False
+        ConfigAutoWrap.wrapper_cls = None
         ConfigAutoWrap.kwargs = {}
-        ConfigAutoWrap.module_blocklist = []
+        ConfigAutoWrap.auto_wrap_policy = None

     def __enter__(self) -> None:
-        self.enable_autowrap_context(self.kwargs)
+        self.enable_autowrap_context(self.auto_wrap_policy, self.kwargs)

     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         self.disable_autowrap_context()

     @staticmethod
-    def recursive_wrap(module: nn.Module, min_num_params: int, **kwargs: Any) -> Tuple[nn.Module, int]:
+    def recursive_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable], **kwargs: Any) -> Tuple[nn.Module, int]:
         """
         Automatically wrap child modules of *module* that meet the given
         criteria with :func:`auto_wrap`.

         Args:
-            module (nn.Module): module to recursively wrap
-            min_num_params (int): min number of parameters for a child Module to
-                be wrapped
+            module (nn.Module):
+                module to recursively wrap
+            auto_wrap_policy (Callable, Optional):
+                optionally, override the ``auto_wrap_policy`` from the context.
+
+        Returns:
+            (nn.Module, int):
+                Wrapped module and the number of parameters wrapped recursively.
         """
-        if isinstance(module, tuple(ConfigAutoWrap.module_blocklist)):
-            # If the module has been blocklisted from wrapping, we return
-            return module, 0
+        if auto_wrap_policy is None:
+            auto_wrap_policy = ConfigAutoWrap.auto_wrap_policy
+
+        # Make sure no child is already wrapped.
+        for _, child in module.named_modules():
+            assert not isinstance(child, cast(type, ConfigAutoWrap.wrapper_cls))
+
+        # We count all params, assuming none of them is already wrapped.
         num_params = sum([p.numel() for p in module.parameters()])
-        if num_params >= min_num_params:
+
+        assert auto_wrap_policy is not None
+        if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):
             total_wrapped_params = 0
             # Iterate through the children, recursively wrap if necessary
             for name, child in module.named_children():
                 wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
-                    module=child, min_num_params=min_num_params, **kwargs
+                    module=child, auto_wrap_policy=auto_wrap_policy, **kwargs
                 )
                 setattr(module, name, wrapped_child)
                 # Keep track of how many parameters have been wrapped
@@ -191,7 +247,8 @@ class ConfigAutoWrap:
             # decide if we need to wrap the current module,
             # since the left over parameters exceed the number of params to wrap
             remainder = num_params - total_wrapped_params
-            if remainder >= min_num_params and not isinstance(module, tuple(FSDP_MODULE_EXCLUDE_WRAP)):
+            if auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder):
+                # Leaf node or final wrapping of the remainder both happen here.
                 return wrap(module, **kwargs), num_params
             else:
                 return module, total_wrapped_params
...
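
The docstring of default_auto_wrap_policy above spells out the policy contract: a custom policy only has to accept module, recurse, and unwrapped_params. As a hedged illustration that is not part of this diff (the policy name and module shapes are made up), a policy that wraps every nn.Linear regardless of size could look like this:

import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, enable_wrap


def wrap_all_linears_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
    # Always recurse so every submodule is visited; wrap only nn.Linear leaves.
    if recurse:
        return True
    return isinstance(module, nn.Linear)


# Assumes the usual distributed setup for FSDP (or pass process_group=... explicitly).
with enable_wrap(wrapper_cls=FSDP, auto_wrap_policy=wrap_all_linears_policy):
    model = auto_wrap(nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)))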
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
 import os
 import unittest
 from unittest import mock
@@ -12,7 +13,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from fairscale.nn import FullyShardedDataParallel as FSDP
-from fairscale.nn import auto_wrap, enable_wrap, wrap
+from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
 from fairscale.utils.testing import DummyProcessGroup
@@ -25,7 +26,7 @@ class TestAutoWrap(unittest.TestCase):
         self.process_group = DummyProcessGroup(rank=0, size=1)

     def test_wrap(self):
-        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
+        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
             layer = wrap(nn.Linear(5, 5))
         assert isinstance(layer, FSDP)
         assert layer.flatten_parameters is False
@@ -35,7 +36,7 @@
         assert isinstance(layer, nn.Linear)

     def test_wrap_override_defaults(self):
-        with enable_wrap(flatten_parameters=False, process_group=self.process_group):
+        with enable_wrap(wrapper_cls=FSDP, flatten_parameters=False, process_group=self.process_group):
             layer = wrap(nn.Linear(5, 5), flatten_parameters=True)
         assert isinstance(layer, FSDP)
         assert layer.flatten_parameters
@@ -45,11 +46,12 @@
         Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
         ``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
         """
-        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
+        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
             sequential = nn.Sequential(
                 nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
             )
-            model = auto_wrap(sequential, min_num_params=40)
+            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
+            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
         assert isinstance(model, FSDP)
         assert isinstance(model.module[0], nn.Linear)
         assert isinstance(model.module[1], nn.Linear)
@@ -62,9 +64,10 @@
         Test to ensure excluded modules are not wrapped, regardless if the total param size is greater than the
         min_num_params.
         """
-        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
+        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
             sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
-            model = auto_wrap(sequential, min_num_params=40)
+            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
+            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
         assert isinstance(model, nn.ModuleList)
         assert isinstance(model[0], nn.Linear)
         assert isinstance(model[1], nn.Linear)
@@ -74,31 +77,43 @@
         Test to ensure excluded modules are not wrapped, but children are if param size is greater than
         min_num_params
         """
-        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
+        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
             sequential = nn.ModuleList([nn.Linear(10, 10)])
-            model = auto_wrap(sequential, min_num_params=40)
+            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
+            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
         assert isinstance(model, nn.ModuleList)
         assert isinstance(model[0], FSDP)

-    def test_auto_wrap_preset_blocklist(self):
+    def test_auto_wrap_preset_force_leaf(self):
         """
-        Test to ensure blocklisted modules are not wrapped, and children are not wrapped.
+        Test to ensure force-leaf modules are not wrapped, and children are not wrapped.
         """
-        with enable_wrap(process_group=self.process_group, flatten_parameters=False):
+        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
             sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
-            model = auto_wrap(sequential, min_num_params=40)
+            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
+            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
         assert isinstance(model.module[0], FSDP)
         # Assert children of multihead attention are not wrapped
         assert isinstance(model.module[1], nn.MultiheadAttention)
         assert isinstance(model.module[1].out_proj, nn.Linear)

-    def test_auto_wrap_preset_blocklist_custom(self):
+    def test_auto_wrap_preset_force_leaf_custom(self):
         """
-        Test to ensure blocklisted modules are not wrapped.
+        Test to ensure force-leaf modules are not wrapped.
         """
-        with enable_wrap(module_blocklist=[nn.Linear], process_group=self.process_group, flatten_parameters=False):
+        my_auto_wrap_policy = functools.partial(
+            default_auto_wrap_policy,
+            min_num_params=40,
+            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
+        )
+        with enable_wrap(
+            auto_wrap_policy=my_auto_wrap_policy,
+            wrapper_cls=FSDP,
+            process_group=self.process_group,
+            flatten_parameters=False,
+        ):
             sequential = nn.Sequential(nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]))
-            model = auto_wrap(sequential, min_num_params=40)
+            model = auto_wrap(sequential)
         # Model was wrapped in FSDP as no inner modules were wrapped.
         assert isinstance(model, FSDP)
         assert isinstance(model.module[0], nn.Linear)
@@ -123,11 +138,12 @@
         torch.cuda.set_device(0)
         torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

-        with enable_wrap(mixed_precision=enable_mixed_precision):
+        with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
             sequential = nn.Sequential(
                 nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
             )
-            model = auto_wrap(sequential, min_num_params=40)
+            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
+            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
             model.to(device)
             input = torch.rand((1, 5), dtype=torch.float).to(device)
...