Unverified Commit 3b8f445f authored by ruanslv's avatar ruanslv Committed by GitHub
Browse files

[fix] Add option to wrap root module in auto_wrap (#930)



* [fix] Add option to wrap root module in auto_wrap

* Fix unit-test comment

* adding a few more tests to make expected behavior clear

* move changes to wrap policy as suggested

* set default to false

* revert pre-commit change

* revert pre-commit change 2
Co-authored-by: default avatarRuan Silva <ruanrms@fb.com>
parent fae29959
......@@ -13,18 +13,20 @@ def default_auto_wrap_policy(
module: nn.Module,
recurse: bool,
unwrapped_params: int,
module_is_root: bool,
# These are customizable for this default policy function.
min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
skip_params_check_for_root: bool = False,
) -> bool:
"""Default policy function for :func:`auto_wrap`.
Return if a module should be wrapped during :func:`auto_wrap`.
The first three parameters are used by :func:`auto_wrap`. If
The first four parameters are used by :func:`auto_wrap`. If
you write a custom version of this policy function, your version
needs to at least accept the first three parameters and free
needs to at least accept the first four parameters and free
to do whatever you want in the function.
Args:
......@@ -37,6 +39,8 @@ def default_auto_wrap_policy(
on whether we should wrap the said module.
unwrapped_params (int):
The number of parameters yet to be wrapped in this module.
module_is_root (bool):
Indicates if current module is the root.
min_num_params (int):
Customizable policy input. It controls the size threshold
......@@ -45,6 +49,9 @@ def default_auto_wrap_policy(
keep as leaves, i.e., their children will never be wrapped.
exclude_wrap_modules (Set[Type[nn.Module]]):
Customizable set of module types to be excluded in wrapping.
skip_params_check_for_root (bool):
If module_is_root is True, then this includes the root in
wrapping regardless of their number of unwrapped params.
"""
force_leaf_modules = (
default_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore
......@@ -63,7 +70,9 @@ def default_auto_wrap_policy(
return is_large and not isinstance(module, tuple(force_leaf_modules))
else:
# If we are not recursing, determine if we should wrap.
return is_large and not isinstance(module, tuple(exclude_wrap_modules))
return ((module_is_root and skip_params_check_for_root) or is_large) and not isinstance(
module, tuple(exclude_wrap_modules)
)
# Set those defaults to the default_auto_wrap_policy function. Make them easy to be imported.
......@@ -75,6 +84,7 @@ def config_auto_wrap_policy(
module: nn.Module,
recurse: bool,
unwrapped_params: int,
module_is_root: bool,
) -> bool:
"""Config based policy function for :func:`auto_wrap`.
......@@ -92,6 +102,9 @@ def config_auto_wrap_policy(
unwrapped_params (int):
The number of parameters yet to be wrapped in this module.
Unused by this function.
module_is_root (bool):
Indicates if current module is the root.
Unused by this function.
"""
if recurse:
# We should always recurse.
......@@ -209,7 +222,9 @@ def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **
(default: wrap if > 100M parameters)
"""
if ConfigAutoWrap.in_autowrap_context:
wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(module, auto_wrap_policy=auto_wrap_policy, **kwargs)
wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(
module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
)
return wrapped_module
return module
......@@ -258,7 +273,9 @@ class ConfigAutoWrap:
self.disable_autowrap_context()
@staticmethod
def recursive_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable], **kwargs: Any) -> Tuple[nn.Module, int]:
def recursive_wrap(
module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
) -> Tuple[nn.Module, int]:
"""
Automatically wrap child modules of *module* that meet the given
criteria with :func:`auto_wrap`.
......@@ -284,12 +301,12 @@ class ConfigAutoWrap:
num_params = sum([p.numel() for p in module.parameters()])
assert auto_wrap_policy is not None
if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):
if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root):
total_wrapped_params = 0
# Iterate through the children, recursively wrap if necessary
for name, child in module.named_children():
wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
module=child, auto_wrap_policy=auto_wrap_policy, **kwargs
module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
)
setattr(module, name, wrapped_child)
# Keep track of how many parameters have been wrapped
......@@ -297,7 +314,9 @@ class ConfigAutoWrap:
# decide if we need to wrap the current module,
# since the left over parameters exceed the number of params to wrap
remainder = num_params - total_wrapped_params
if auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder):
if auto_wrap_policy(
module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root
):
# Leaf node or final wrapping of the remainder both happen here.
return wrap(module, **kwargs), num_params
else:
......
......@@ -48,19 +48,35 @@ class TestAutoWrap(unittest.TestCase):
"""
Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
Root is not wrapped given there are not enough unwrapped params left and skip_params_check_for_root
is not set.
"""
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential(
nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=60)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], nn.Linear)
assert isinstance(model[1], FSDP)
assert isinstance(model[1].module[0], nn.Linear)
assert isinstance(model[1].module[1], nn.Linear)
def test_auto_wrap_skip_root_checks(self):
"""
Similar test as before but this time we set skip_params_check_for_root=True in the wrap policy.
So in this case the root is wrapped even without enough remaining unwrapped params.
"""
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
my_auto_wrap_policy = functools.partial(
default_auto_wrap_policy, min_num_params=60, skip_params_check_for_root=True
)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, FSDP)
assert isinstance(model.module[0], nn.Linear)
assert isinstance(model.module[1], nn.Linear)
assert isinstance(model.module[2], FSDP)
assert isinstance(model.module[2].module[0], nn.Linear)
assert isinstance(model.module[2].module[1], nn.Linear)
assert isinstance(model.module[1], FSDP)
assert isinstance(model.module[1].module[0], nn.Linear)
assert isinstance(model.module[1].module[1], nn.Linear)
def test_auto_wrap_preset_exclude_wrap(self):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment