Unverified Commit 9d2bbcf2 authored by Min Xu, committed by GitHub

[fix] auto_wrap: support wrapping based on wrapper_config (#685)



* [fix] auto_wrap: support wrapping based on wrapper_config

- users can use this to avoid the assert that fires when auto_wrap is applied to a module multiple times
- users can traverse the modules multiple times, assign a wrapper_config to the
  modules they want wrapped, and then call auto_wrap once to wrap them all
  (see the sketch after the commit metadata below)

fix #649
fix #585

* added changelog

* fix tests

* fix a test

* added an optional assert for collision based on discussions with Quentin

* added config_auto_wrap_policy

* lint
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 81c20f72
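For readers new to this API, here is a minimal sketch (not part of the commit) of the tag-then-wrap flow described in the commit message. It assumes a toy `nn.Sequential` model, an already-initialized torch.distributed process group, and illustrative per-module FSDP kwargs:

```python
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap

# Toy model used only for illustration.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# Pass 1..N: traverse the module tree as many times as you like and tag the
# sub-modules that should be wrapped. Here every Linear layer gets a config.
for m in model.modules():
    if isinstance(m, nn.Linear):
        # wrapper_config holds per-module kwargs for the eventual FSDP wrapper.
        m.wrapper_config = {"flatten_parameters": False}

# Final pass: a single auto_wrap call wraps exactly the tagged sub-modules.
with enable_wrap(config_auto_wrap_policy, wrapper_cls=FSDP):
    model = auto_wrap(model)
```

This mirrors what the updated `auto_wrap_bn` in the diff below does for BatchNorm layers.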
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- wrap: support wrapping based on `wrapper_config` [#685]
 - FSDP: fix extra process groups being created by default. Old behavior can cause excessive GPU memory usage. [#678]
 - FSDP: fix forward pass not overlapping compute and allgather [#671]
 - FSDP: improved frozen weight support [#657]
...
@@ -10,6 +10,6 @@ from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
 from .misc import FlattenParamsWrapper
 from .moe import MOELayer, Top2Gate
 from .pipe import Pipe, PipeRPCWrapper
-from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
+from .wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap

 __all__: List[str] = []
@@ -23,7 +23,7 @@ from torch.nn import Parameter
 import torch.nn.functional as F

 from fairscale.nn.misc import FlattenParamsWrapper
-from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
+from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import (
     chunk_and_pad,
@@ -218,7 +218,7 @@ class FullyShardedDataParallel(nn.Module):
         cpu_offload (bool, Optional):
             if ``True``, offload FP32 params to CPU. This is only relevant when
             *``mixed_precision``* is ``True``. Note: This arg will be deprecated in favor of
             *``move_params_to_cpu``* in an upcoming release.
     """

     def __init__(
@@ -1987,6 +1987,8 @@ def auto_wrap_bn(
     single_rank_pg: bool = False,
     process_group: Optional[ProcessGroup] = None,
     fsdp_config: Optional[Dict[str, Any]] = None,
+    wrap_it: bool = True,
+    assert_on_collision: bool = True,
 ) -> nn.Module:
     """
     Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert
@@ -2008,22 +2010,17 @@
             Optional process group to be used.
         fsdp_config (Dict):
             Optional fsdp_config to be used.
+        wrap_it (bool):
+            Whether or not to wrap the module after setting the config.
+            Default: True
+        assert_on_collision (bool):
+            Whether or not to assert if a wrapper_config already exists on the module.
+            Default: True

     Returns:
         Processed module, where BNs are wrapped with a special FSDP instance.
     """

-    # Prepare a fsdp_config dict for BNs.
-    def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
-        is_bn = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
-        if recurse:
-            return not isinstance(
-                module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES)  # type: ignore
-            )
-        else:
-            return is_bn and not isinstance(
-                module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)  # type: ignore
-            )
     pg = process_group
     if single_rank_pg:
         # No sharding with this single member group.
@@ -2032,7 +2029,6 @@
     if fsdp_config is None:
         fsdp_config = {
-            "wrapper_cls": FullyShardedDataParallel,
             "process_group": pg,
             "mixed_precision": False,  # Keep the weights in FP32.
             "flatten_parameters": False,  # Do not flatten.
@@ -2047,5 +2043,17 @@
             "force_input_to_fp32": False,
         }

-    with enable_wrap(wrap_bn_only_policy, **fsdp_config):
+    # Assign the config dict to BNs.
+    for m in module.modules():
+        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
+            if assert_on_collision:
+                assert not hasattr(
+                    m, "wrapper_config"
+                ), "Module shouldn't already have a wrapper_config. Is it tagged already by another policy?"
+            m.wrapper_config = fsdp_config
+
+    # Wrap it.
+    with (
+        enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress()
+    ):
         return auto_wrap(module)
...
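A hedged usage sketch (not part of the diff) of the updated `auto_wrap_bn` with its new arguments; the import path and the toy model are assumptions, and actually wrapping still requires an initialized torch.distributed process group:

```python
import torch.nn as nn

# Assumed import path for auto_wrap_bn; adjust to your fairscale version.
from fairscale.nn.data_parallel.fully_sharded_data_parallel import auto_wrap_bn

# Hypothetical model containing BatchNorm layers.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Default behavior: tag each BN with a wrapper_config and wrap it right away.
model = auto_wrap_bn(model)

# Alternative: only tag the BNs now (wrap_it=False); a later, single auto_wrap
# pass with config_auto_wrap_policy picks them up together with any modules
# tagged elsewhere. assert_on_collision=False skips the safety assert when a
# module already carries a wrapper_config.
model2 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model2 = auto_wrap_bn(model2, wrap_it=False, assert_on_collision=False)
```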
@@ -5,6 +5,6 @@
 from typing import List

-from .auto_wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
+from .auto_wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap

 __all__: List[str] = []
@@ -59,10 +59,10 @@ def default_auto_wrap_policy(
     is_large = unwrapped_params >= min_num_params
     if recurse:
-        # We should recurse if the module is big enough but not force_leaf_modulesed.
+        # We should recurse if the module is big enough but not in force_leaf_modules list.
         return is_large and not isinstance(module, tuple(force_leaf_modules))
     else:
-        # If we are not recursing, we should wrap but not the exclude list.
+        # If we are not recursing, determine if we should wrap.
         return is_large and not isinstance(module, tuple(exclude_wrap_modules))
@@ -71,6 +71,32 @@ default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}
 default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore


+def config_auto_wrap_policy(module: nn.Module, recurse: bool, unwrapped_params: int,) -> bool:
+    """Config based policy function for :func:`auto_wrap`.
+
+    Return true for a module to be wrapped if it is already tagged with
+    a ``wrapper_config`` attribute.
+
+    Args:
+        module (nn.Module):
+            The module to be considered in this decision.
+        recurse (bool):
+            Indicate if this is called to make a decision on whether we
+            should recurse down a subgraph of the module structure.
+            If False, it means this function is called to make a decision
+            on whether we should wrap the said module.
+        unwrapped_params (int):
+            The number of parameters yet to be wrapped in this module.
+            Unused by this function.
+    """
+    if recurse:
+        # We should always recurse.
+        return True
+    else:
+        # If we are not recursing, determine if we should wrap.
+        return hasattr(module, "wrapper_config")
+
+
 @contextlib.contextmanager
 def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
     """
@@ -112,13 +138,20 @@ def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
 def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
     """
     Annotate that a module should be wrapped. Annotated modules will only be
-    wrapped if inside of an :func:`enable_wrap` context manager. An important
-    use case is annotating large layers that should be sharded (in-place) during
-    initialization, to avoid running out of system memory.
+    wrapped if inside of an :func:`enable_wrap` context manager. This allows
+    a module to be initialized both with and without a wrapper, without code
+    changes.
+
+    Both wrapper_cls and wrapper_config can be taken from 3 sources with
+    increasing priority:
+
+    1. ConfigAutoWrap's context
+    2. module.wrapper_config
+    3. wrap_overrides argument of this function

     Usage::

-        with enable_wrap(**params):
+        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             # Wraps layer in FSDP by default if within context
             self.l1 = wrap(torch.nn.Linear(5, 5))
@@ -128,7 +161,11 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
         the values provided by the :func:`enable_wrap` context
     """
     if ConfigAutoWrap.in_autowrap_context:
-        wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
+        module_overrides = {}
+        if hasattr(module, "wrapper_config"):
+            module_overrides = module.wrapper_config
+            assert isinstance(module_overrides, dict)
+        wrap_overrides = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides}
         assert ConfigAutoWrap.wrapper_cls is not None
         return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
     return module
@@ -141,6 +178,13 @@ def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **
     children modules that meet the criteria given by :func:`auto_wrap_policy`. This
     is useful for wrapping large complex layers.

+    .. note:: auto_wrap can only be applied to a module once because it
+        assumes none of the sub-modules is already wrapped and uses that
+        assumption to compute the wrapped vs. unwrapped parameters.
+        To get around this limitation, users can pre-assign ``wrapper_config``
+        attributes to the sub-modules they want to wrap (in multiple passes)
+        and then use the ``config_auto_wrap_policy``.
+
     .. warning:: It is not recommended to use :func:`auto_wrap` with
         :class:`FullyShardedDataParallel` on modules that have shared
         parameters, as the parameter sharing may be broken (i.e. end up not
...
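To make the new precedence in `wrap()` concrete, here is a small sketch (not part of the diff), assuming FSDP as the wrapper class, an initialized torch.distributed process group, and illustrative kwargs: values from the `enable_wrap` context are overridden by `module.wrapper_config`, which is in turn overridden by explicit `wrap()` arguments.

```python
import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

linear = nn.Linear(4, 4)
# Per-module config: overrides the enable_wrap kwargs for this module only.
linear.wrapper_config = {"flatten_parameters": False}

with enable_wrap(wrapper_cls=FSDP, flatten_parameters=True, mixed_precision=False):
    # Effective kwargs: mixed_precision=False (from the context),
    # flatten_parameters=False (from wrapper_config), and
    # reshard_after_forward=False (explicit override, highest priority).
    wrapped = wrap(linear, reshard_after_forward=False)
```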
@@ -111,3 +111,6 @@ class Module(Generic[T_co]):
     # This is added torchgpipe
     training: bool
+
+    # Added by auto_wrap.py.
+    wrapper_config: dict
@@ -76,7 +76,6 @@ def _create_model(
     if with_sync_bn:
         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
     fsdp_config = {
-        "wrapper_cls": FSDP,
         "mixed_precision": False,
         "flatten_parameters": False,
         "reshard_after_forward": False,
...