"test/vscode:/vscode.git/clone" did not exist on "7d831a2f9b3ebab9eb8e5c899cf70b103ad6908a"
Unverified commit 9d2bbcf2 authored by Min Xu, committed by GitHub

[fix] auto_wrap: support wrapping based on wrapper_config (#685)



* [fix] auto_wrap: support wrapping based on wrapper_config

- users can use this to avoid the assert that fires when auto_wrap is applied to a module more than once
- users can traverse the modules in multiple passes, assign a wrapper_config
  to each module, and then call auto_wrap once to wrap them (see the sketch below, just before the diff)

fix #649
fix #585

* added changelog

* fix tests

* fix a test

* added an optional assert for collision based on discussions with Quentin

* added config_auto_wrap_policy

* lint
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 81c20f72
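
A minimal sketch of the multi-pass workflow this change enables. The toy model and the per-module configs are hypothetical; auto_wrap, enable_wrap, config_auto_wrap_policy, and FullyShardedDataParallel are the fairscale names touched by the diff below, and a default process group is assumed to be initialized so FSDP instances can be constructed.

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap

# Hypothetical model; assumes torch.distributed is already initialized.
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

# Pass 1: tag the BN layer with a config that keeps its params unflattened and in FP32.
for m in model.modules():
    if isinstance(m, nn.BatchNorm1d):
        m.wrapper_config = {"flatten_parameters": False, "mixed_precision": False}

# Pass 2 (possibly in a different part of the code base): tag the Linear layers differently.
for m in model.modules():
    if isinstance(m, nn.Linear):
        m.wrapper_config = {"flatten_parameters": True}

# A single auto_wrap call then wraps exactly the modules carrying a wrapper_config,
# each with its own config merged on top of the enable_wrap kwargs.
with enable_wrap(config_auto_wrap_policy, wrapper_cls=FSDP):
    model = auto_wrap(model)
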
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
### Fixed
- wrap: support wrapping based on `wrapper_config` [#685]
- FSDP: fix extra process groups being created by default. Old behavior can cause excessive GPU memory usage. [#678]
- FSDP: fix forward pass not overlapping compute and allgather [#671]
- FSDP: improved frozen weight support [#657]
@@ -10,6 +10,6 @@ from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
from .misc import FlattenParamsWrapper
from .moe import MOELayer, Top2Gate
from .pipe import Pipe, PipeRPCWrapper
from .wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
from .wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap
__all__: List[str] = []
@@ -23,7 +23,7 @@ from torch.nn import Parameter
import torch.nn.functional as F
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
chunk_and_pad,
@@ -1987,6 +1987,8 @@ def auto_wrap_bn(
single_rank_pg: bool = False,
process_group: Optional[ProcessGroup] = None,
fsdp_config: Optional[Dict[str, Any]] = None,
wrap_it: bool = True,
assert_on_collision: bool = True,
) -> nn.Module:
"""
Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert
@@ -2008,22 +2010,17 @@
Optional process group to be used.
fsdp_config (Dict):
Optional fsdp_config to be used.
wrap_it (bool):
Whether or not to wrap the module after setting the config.
Default: True
assert_on_collision (bool):
Whether or not to assert if a wrapper_config already exists on the module.
Default: True
Returns:
Processed module, where BNs are wrapped with a special FSDP instance.
"""
def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
is_bn = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
if recurse:
return not isinstance(
module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES) # type: ignore
)
else:
return is_bn and not isinstance(
module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES) # type: ignore
)
# Prepare a fsdp_config dict for BNs.
pg = process_group
if single_rank_pg:
# No sharding with this single member group.
@@ -2032,7 +2029,6 @@ def auto_wrap_bn(
if fsdp_config is None:
fsdp_config = {
"wrapper_cls": FullyShardedDataParallel,
"process_group": pg,
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
@@ -2047,5 +2043,17 @@
"force_input_to_fp32": False,
}
with enable_wrap(wrap_bn_only_policy, **fsdp_config):
# Assign the config dict to BNs.
for m in module.modules():
if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
if assert_on_collision:
assert not hasattr(
m, "wrapper_config"
), "Module shouldn't already have a wrapper_config. Is it tagged already by another policy?"
m.wrapper_config = fsdp_config
# Wrap it.
with (
enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress()
):
return auto_wrap(module)
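
An illustrative sketch of the new wrap_it and assert_on_collision arguments shown above. The toy model is hypothetical, the import path of auto_wrap_bn is assumed for this sketch, and a process group is assumed to be initialized.

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn  # import path assumed for this sketch
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Tag the BN layers with their special FSDP config, but defer the actual wrapping.
model = auto_wrap_bn(model, wrap_it=False)

# Tag other modules in a later pass; this does not collide with the BN tags, so the
# default assert_on_collision=True is fine. Pass assert_on_collision=False instead
# if the same modules may legitimately be tagged twice.
model[0].wrapper_config = {"flatten_parameters": True}

# Wrap everything that carries a wrapper_config in one auto_wrap call.
with enable_wrap(config_auto_wrap_policy, wrapper_cls=FSDP):
    model = auto_wrap(model)
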
@@ -5,6 +5,6 @@
from typing import List
from .auto_wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
from .auto_wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap
__all__: List[str] = []
@@ -59,10 +59,10 @@ def default_auto_wrap_policy(
is_large = unwrapped_params >= min_num_params
if recurse:
# We should recurse if the module is big enough but not force_leaf_modulesed.
# We should recurse if the module is big enough but not in the force_leaf_modules list.
return is_large and not isinstance(module, tuple(force_leaf_modules))
else:
# If we are not recursing, we should wrap but not the exclude list.
# If we are not recursing, determine if we should wrap.
return is_large and not isinstance(module, tuple(exclude_wrap_modules))
@@ -71,6 +71,32 @@ default_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}
default_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore
def config_auto_wrap_policy(module: nn.Module, recurse: bool, unwrapped_params: int,) -> bool:
"""Config based policy function for :func:`auto_wrap`.
Return true for a module to be wrapped if it is already tagged with
a ``wrapper_config`` attribute.
Args:
module (nn.Module):
The module to be considered in this decision.
recurse (bool):
Indicate if this is called to make a decision on whether we
should recurse down a subgraph of the module structure.
If False, it means this function is called to make a decision
on whether we should wrap the said module.
unwrapped_params (int):
The number of parameters yet to be wrapped in this module.
Unused by this function.
"""
if recurse:
# We should always recurse.
return True
else:
# If we are not recursing, determine if we should wrap.
return hasattr(module, "wrapper_config")
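
A small illustration of the decision contract documented above (the layers are hypothetical; unwrapped_params is ignored by this policy):

import torch.nn as nn
from fairscale.nn.wrap import config_auto_wrap_policy

tagged = nn.Linear(4, 4)
tagged.wrapper_config = {"flatten_parameters": False}
untagged = nn.Linear(4, 4)

# Recursion is always allowed, so the whole module tree gets visited.
assert config_auto_wrap_policy(tagged, recurse=True, unwrapped_params=0)
assert config_auto_wrap_policy(untagged, recurse=True, unwrapped_params=0)
# Only modules that carry a wrapper_config attribute are wrapped.
assert config_auto_wrap_policy(tagged, recurse=False, unwrapped_params=0)
assert not config_auto_wrap_policy(untagged, recurse=False, unwrapped_params=0)
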
@contextlib.contextmanager
def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: Any) -> Generator[None, None, None]:
"""
@@ -112,13 +138,20 @@ def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: A
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
"""
Annotate that a module should be wrapped. Annotated modules will only be
wrapped if inside of an :func:`enable_wrap` context manager. An important
use case is annotating large layers that should be sharded (in-place) during
initialization, to avoid running out of system memory.
wrapped if inside of an :func:`enable_wrap` context manager. This allows
a module to be initialized both with and without a wrapper without code
change.
Both wrapper_cls and wrapper_config can be taken from 3 sources with
increasing priority:
1. ConfigAutoWrap's context
2. module.wrapper_config
3. wrap_overrides argument of this function
Usage::
with enable_wrap(**params):
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
@@ -128,7 +161,11 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
the values provided by the :func:`enable_wrap` context
"""
if ConfigAutoWrap.in_autowrap_context:
wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
module_overrides = {}
if hasattr(module, "wrapper_config"):
module_overrides = module.wrapper_config
assert isinstance(module_overrides, dict)
wrap_overrides = {**ConfigAutoWrap.kwargs, **module_overrides, **wrap_overrides}
assert ConfigAutoWrap.wrapper_cls is not None
return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
return module
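
A sketch of the override priority documented above: call-site overrides beat the module's wrapper_config, which beats the enable_wrap context kwargs. The layer and the specific kwargs are hypothetical, and a default process group is assumed so FSDP can be constructed.

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

layer = nn.Linear(5, 5)
layer.wrapper_config = {"flatten_parameters": False}  # source 2: module attribute

with enable_wrap(wrapper_cls=FSDP, flatten_parameters=True, mixed_precision=False):  # source 1: context
    wrapped = wrap(layer, reshard_after_forward=False)  # source 3: call-site override
    # Effective FSDP kwargs: flatten_parameters=False (module attribute beats context),
    # mixed_precision=False (from the context), reshard_after_forward=False (call site).
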
@@ -141,6 +178,13 @@ def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **
children modules that meet the criteria given by :func:`auto_wrap_policy`. This
is useful for wrapping large complex layers.
.. note:: auto_wrap can only be applied to a module once because it
assumes none of the sub-modules is already wrapped and uses that
assumption to compute the wrapped vs. unwrapped parameters.
To get around this limitation, users can pre-assign ``wrapper_config``
attributes to the sub-modules they want to wrap (in multiple passes)
and then use the ``config_auto_wrap_policy``.
.. warning:: It is not recommended to use :func:`auto_wrap` with
:class:`FullyShardedDataParallel` on modules that have shared
parameters, as the parameter sharing may be broken (i.e. end up not
@@ -111,3 +111,6 @@ class Module(Generic[T_co]):
# This is added by torchgpipe
training: bool
# Added by auto_wrap.py.
wrapper_config: dict
@@ -76,7 +76,6 @@ def _create_model(
if with_sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
fsdp_config = {
"wrapper_cls": FSDP,
"mixed_precision": False,
"flatten_parameters": False,
"reshard_after_forward": False,