Unverified Commit 6459fd85 authored by cyanguwa, committed by GitHub

[PyTorch] Miscellaneous fixes for FP8 DPA module (#804)



* initialize tp_group for FP8 DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuDNN version parsing in unit tests for cuDNN v9
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add hook to ignore missing fused_attn._extra_states when training from old checkpoints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove test and redundant implementation from the previous commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove warning message and replace it with a docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move core_attention.fused_attention._extra_state to core_attention._extra_state
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify post_state_dict_hooks shared between FusedAttention and DotProductAttention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add temporary test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove the test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable the pylint unused-argument check for the hook's self argument, which the hook signature requires
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent eed4dfc6
@@ -70,7 +70,8 @@ def reset_global_fp8_state():
 def _cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
     encoded_version = ext.get_cudnn_version()
-    major, encoded_version = divmod(encoded_version, 1000)
+    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+    major, encoded_version = divmod(encoded_version, major_version_magnitude)
     minor, patch = divmod(encoded_version, 100)
     return (major, minor, patch)
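For reference, the decoding logic above can be exercised on its own. A minimal sketch (the function name decode_cudnn_version is illustrative; the encodings are the standard cuDNN conventions of MAJOR*1000 + MINOR*100 + PATCH before v9 and MAJOR*10000 + MINOR*100 + PATCH from v9 onward):

def decode_cudnn_version(encoded_version: int) -> tuple:
    """Decode cuDNN's integer version into (major, minor, patch)."""
    # cuDNN 9.0.0 reports 90000 while 8.9.x reports 890x, so any value below
    # 90000 must be decoded with the old 1000-based major-version magnitude.
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

assert decode_cudnn_version(8904) == (8, 9, 4)   # cuDNN 8.9.4, old encoding
assert decode_cudnn_version(90100) == (9, 1, 0)  # cuDNN 9.1.0, new encoding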
@@ -2929,6 +2929,17 @@ class FusedAttention(TransformerEngineBaseModule):
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
+
+        def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
+            """
+            Temporarily remove fused_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if 'fused_attention._extra_state' in key:
+                    incompatible_keys.missing_keys.remove(key)
+        self.register_load_state_dict_post_hook(remove_extra_states_check)

     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
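To illustrate what the hook does, a hedged sketch of loading an older checkpoint (build_model and old_checkpoint.pt are hypothetical placeholders for any model containing this FusedAttention module and a checkpoint saved before _extra_state was introduced):

import torch

model = build_model()                        # hypothetical model with a FusedAttention submodule
old_state = torch.load("old_checkpoint.pt")  # hypothetical pre-_extra_state checkpoint

# During load_state_dict, PyTorch runs the registered post hook on the
# FusedAttention submodule, which strips any '...fused_attention._extra_state'
# entry from incompatible_keys.missing_keys, so older checkpoints load
# without a missing-key complaint.
incompatible = model.load_state_dict(old_state, strict=False)
assert all('fused_attention._extra_state' not in key
           for key in incompatible.missing_keys)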
@@ -3282,6 +3293,7 @@ class DotProductAttention(torch.nn.Module):
                 layer_number=layer_number,
                 deterministic=self.deterministic,
                 **attn_kwargs)
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)
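As this hunk shows, DotProductAttention builds both a FusedAttention and an UnfusedDotProductAttention backend up front and selects between them at call time. A hedged usage sketch (shapes follow the default 'sbhd' layout of [sequence, batch, heads, channels-per-head]; verify the constructor signature against your TransformerEngine version):

import torch
from transformer_engine.pytorch import DotProductAttention

# 16 attention heads, 64 channels per head; other arguments left at defaults.
dpa = DotProductAttention(num_attention_heads=16, kv_channels=64)

q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = dpa(q, k, v)  # the backend (fused or unfused) is chosen per forward pass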