Unverified Commit fe80ca06 authored by Alp Dener, committed by GitHub

[PyTorch] Fixed assert on primary Fp8 weights in `prepare_te_modules_for_fsdp()` (#916)



Restricted FSDP asserts on primary FP8 weights to TE modules.
Signed-off-by: Alp Dener <adener@nvidia.com>
parent 236a2030
@@ -940,13 +940,14 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
         FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
     """
     assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
-    assert not fsdp_root.primary_weights_in_fp8, (
-        "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
-        "Please initialize your model without the te.fp8_model_init(...) context."
-    )
     # If the root module is a TE module, inject FSDP information into it
     if _is_te_module(fsdp_root.module):
+        if hasattr(fsdp_root, "primary_weights_in_fp8"):
+            assert not fsdp_root.primary_weights_in_fp8, (
+                "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
+                "Please initialize your model without the te.fp8_model_init(...) context."
+            )
         root_state = _get_module_fsdp_state(fsdp_root)
         assert root_state is not None, "Root module does not have a valid _FSDPState."
         setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
@@ -954,11 +955,12 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
     # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
     fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
     for state, fsdp_module in zip(fsdp_states, fsdp_modules):
-        assert not fsdp_module.module.primary_weights_in_fp8, (
-            "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
-            "Please initialize your model without the te.fp8_model_init(...) context."
-        )
         if _is_te_module(fsdp_module.module):
+            if hasattr(fsdp_module.module, "primary_weights_in_fp8"):
+                assert not fsdp_module.module.primary_weights_in_fp8, (
+                    "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
+                    "Please initialize your model without the te.fp8_model_init(...) context."
+                )
             setattr(fsdp_module.module, "fsdp_group", state.process_group)
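For context, a minimal sketch of the situation this commit addresses is below. Before the change, prepare_te_modules_for_fsdp() read primary_weights_in_fp8 on every FSDP-wrapped module, including plain PyTorch modules that never define that attribute; the guarded check now runs only for TE modules. The model layout, sizes, and import path are illustrative assumptions, not taken from the commit, and the script assumes a CUDA-capable distributed launch (e.g., via torchrun).

# Illustrative sketch, not from the commit: a mixed model in which a TE
# module and a plain PyTorch module are wrapped together under FSDP.
import torch
import transformer_engine.pytorch as te
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

# Assumes launch via `torchrun --nproc_per_node=<N> this_script.py`.
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

model = torch.nn.Sequential(
    te.Linear(1024, 1024),        # TE module: carries primary_weights_in_fp8
    torch.nn.Linear(1024, 1024),  # plain module: no such attribute
).cuda()
fsdp_model = FSDP(model, use_orig_params=True)

# Before this commit, the unguarded assert touched primary_weights_in_fp8 on
# non-TE modules like the nn.Sequential root and could fail here with an
# AttributeError; with the fix, the FP8-primary-weights check is restricted
# to TE modules (and guarded with hasattr).
prepare_te_modules_for_fsdp(fsdp_model)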