Unverified commit 3b8f445f, authored by ruanslv and committed by GitHub
Browse files

[fix] Add option to wrap root module in auto_wrap (#930)



* [fix] Add option to wrap root module in auto_wrap

* Fix unit-test comment

* adding a few more tests to make expected behavior clear

* move changes to wrap policy as suggested

* set default to false

* revert pre-commit change

* revert pre-commit change 2
Co-authored-by: Ruan Silva <ruanrms@fb.com>
parent fae29959
...@@ -13,18 +13,20 @@ def default_auto_wrap_policy( ...@@ -13,18 +13,20 @@ def default_auto_wrap_policy(
module: nn.Module, module: nn.Module,
recurse: bool, recurse: bool,
unwrapped_params: int, unwrapped_params: int,
module_is_root: bool,
# These are customizable for this default policy function. # These are customizable for this default policy function.
min_num_params: int = int(1e8), min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[nn.Module]]] = None, force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None, exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
skip_params_check_for_root: bool = False,
) -> bool: ) -> bool:
"""Default policy function for :func:`auto_wrap`. """Default policy function for :func:`auto_wrap`.
Return if a module should be wrapped during :func:`auto_wrap`. Return if a module should be wrapped during :func:`auto_wrap`.
The first three parameters are used by :func:`auto_wrap`. If The first four parameters are used by :func:`auto_wrap`. If
you write a custom version of this policy function, your version you write a custom version of this policy function, your version
needs to at least accept the first three parameters and free needs to at least accept the first four parameters and free
to do whatever you want in the function. to do whatever you want in the function.
Args: Args:
...@@ -37,6 +39,8 @@ def default_auto_wrap_policy( ...@@ -37,6 +39,8 @@ def default_auto_wrap_policy(
on whether we should wrap the said module. on whether we should wrap the said module.
unwrapped_params (int): unwrapped_params (int):
The number of parameters yet to be wrapped in this module. The number of parameters yet to be wrapped in this module.
module_is_root (bool):
Indicates if current module is the root.
min_num_params (int): min_num_params (int):
Customizable policy input. It controls the size threshold Customizable policy input. It controls the size threshold
...@@ -45,6 +49,9 @@ def default_auto_wrap_policy( ...@@ -45,6 +49,9 @@ def default_auto_wrap_policy(
keep as leaves, i.e., their children will never be wrapped. keep as leaves, i.e., their children will never be wrapped.
exclude_wrap_modules (Set[Type[nn.Module]]): exclude_wrap_modules (Set[Type[nn.Module]]):
Customizable set of module types to be excluded in wrapping. Customizable set of module types to be excluded in wrapping.
skip_params_check_for_root (bool):
If True, the root module (module_is_root=True) is wrapped
regardless of its number of unwrapped params, unless its type
is in exclude_wrap_modules.
""" """
force_leaf_modules = ( force_leaf_modules = (
default_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore default_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore
...@@ -63,7 +70,9 @@ def default_auto_wrap_policy( ...@@ -63,7 +70,9 @@ def default_auto_wrap_policy(
return is_large and not isinstance(module, tuple(force_leaf_modules)) return is_large and not isinstance(module, tuple(force_leaf_modules))
else: else:
# If we are not recursing, determine if we should wrap. # If we are not recursing, determine if we should wrap.
return is_large and not isinstance(module, tuple(exclude_wrap_modules)) return ((module_is_root and skip_params_check_for_root) or is_large) and not isinstance(
module, tuple(exclude_wrap_modules)
)
# Set those defaults to the default_auto_wrap_policy function. Make them easy to be imported. # Set those defaults to the default_auto_wrap_policy function. Make them easy to be imported.
...@@ -75,6 +84,7 @@ def config_auto_wrap_policy( ...@@ -75,6 +84,7 @@ def config_auto_wrap_policy(
module: nn.Module, module: nn.Module,
recurse: bool, recurse: bool,
unwrapped_params: int, unwrapped_params: int,
module_is_root: bool,
) -> bool: ) -> bool:
"""Config based policy function for :func:`auto_wrap`. """Config based policy function for :func:`auto_wrap`.
...@@ -92,6 +102,9 @@ def config_auto_wrap_policy( ...@@ -92,6 +102,9 @@ def config_auto_wrap_policy(
unwrapped_params (int): unwrapped_params (int):
The number of parameters yet to be wrapped in this module. The number of parameters yet to be wrapped in this module.
Unused by this function. Unused by this function.
module_is_root (bool):
Indicates if current module is the root.
Unused by this function.
""" """
if recurse: if recurse:
# We should always recurse. # We should always recurse.
...@@ -209,7 +222,9 @@ def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, ** ...@@ -209,7 +222,9 @@ def auto_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable] = None, **
(default: wrap if > 100M parameters) (default: wrap if > 100M parameters)
""" """
if ConfigAutoWrap.in_autowrap_context: if ConfigAutoWrap.in_autowrap_context:
wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(module, auto_wrap_policy=auto_wrap_policy, **kwargs) wrapped_module, remainder = ConfigAutoWrap.recursive_wrap(
module, auto_wrap_policy=auto_wrap_policy, module_is_root=True, **kwargs
)
return wrapped_module return wrapped_module
return module return module
...@@ -258,7 +273,9 @@ class ConfigAutoWrap: ...@@ -258,7 +273,9 @@ class ConfigAutoWrap:
self.disable_autowrap_context() self.disable_autowrap_context()
@staticmethod @staticmethod
def recursive_wrap(module: nn.Module, auto_wrap_policy: Optional[Callable], **kwargs: Any) -> Tuple[nn.Module, int]: def recursive_wrap(
module: nn.Module, auto_wrap_policy: Optional[Callable], module_is_root: bool, **kwargs: Any
) -> Tuple[nn.Module, int]:
""" """
Automatically wrap child modules of *module* that meet the given Automatically wrap child modules of *module* that meet the given
criteria with :func:`auto_wrap`. criteria with :func:`auto_wrap`.
...@@ -284,12 +301,12 @@ class ConfigAutoWrap: ...@@ -284,12 +301,12 @@ class ConfigAutoWrap:
num_params = sum([p.numel() for p in module.parameters()]) num_params = sum([p.numel() for p in module.parameters()])
assert auto_wrap_policy is not None assert auto_wrap_policy is not None
if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params): if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params, module_is_root=module_is_root):
total_wrapped_params = 0 total_wrapped_params = 0
# Iterate through the children, recursively wrap if necessary # Iterate through the children, recursively wrap if necessary
for name, child in module.named_children(): for name, child in module.named_children():
wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap( wrapped_child, num_wrapped_params = ConfigAutoWrap.recursive_wrap(
module=child, auto_wrap_policy=auto_wrap_policy, **kwargs module=child, auto_wrap_policy=auto_wrap_policy, module_is_root=False, **kwargs
) )
setattr(module, name, wrapped_child) setattr(module, name, wrapped_child)
# Keep track of how many parameters have been wrapped # Keep track of how many parameters have been wrapped
...@@ -297,7 +314,9 @@ class ConfigAutoWrap: ...@@ -297,7 +314,9 @@ class ConfigAutoWrap:
# decide if we need to wrap the current module, # decide if we need to wrap the current module,
# since the left over parameters exceed the number of params to wrap # since the left over parameters exceed the number of params to wrap
remainder = num_params - total_wrapped_params remainder = num_params - total_wrapped_params
if auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder): if auto_wrap_policy(
module=module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root
):
# Leaf node or final wrapping of the remainder both happen here. # Leaf node or final wrapping of the remainder both happen here.
return wrap(module, **kwargs), num_params return wrap(module, **kwargs), num_params
else: else:
......
...@@ -48,19 +48,35 @@ class TestAutoWrap(unittest.TestCase): ...@@ -48,19 +48,35 @@ class TestAutoWrap(unittest.TestCase):
""" """
Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params. Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do. ``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
The root is not wrapped because too few unwrapped params remain and skip_params_check_for_root
is not set.
""" """
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False): with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential( sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)) my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=60)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], nn.Linear)
assert isinstance(model[1], FSDP)
assert isinstance(model[1].module[0], nn.Linear)
assert isinstance(model[1].module[1], nn.Linear)
def test_auto_wrap_skip_root_checks(self):
"""
Same setup as the previous test, but this time we set skip_params_check_for_root=True in the wrap policy,
so the root is wrapped even though too few unwrapped params remain.
"""
with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group, flatten_parameters=False):
sequential = nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)))
my_auto_wrap_policy = functools.partial(
default_auto_wrap_policy, min_num_params=60, skip_params_check_for_root=True
) )
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy) model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
assert isinstance(model, FSDP) assert isinstance(model, FSDP)
assert isinstance(model.module[0], nn.Linear) assert isinstance(model.module[0], nn.Linear)
assert isinstance(model.module[1], nn.Linear) assert isinstance(model.module[1], FSDP)
assert isinstance(model.module[2], FSDP) assert isinstance(model.module[1].module[0], nn.Linear)
assert isinstance(model.module[2].module[0], nn.Linear) assert isinstance(model.module[1].module[1], nn.Linear)
assert isinstance(model.module[2].module[1], nn.Linear)
def test_auto_wrap_preset_exclude_wrap(self): def test_auto_wrap_preset_exclude_wrap(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment