Unverified Commit fe5aa604 authored by Charlene Yang, committed by GitHub

[PyTorch] Adjust checkpointing of FP8 metadata for attention (#917)



* subclass DPA with BaseModule and test with test_gpt_checkpointing
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test DPA only
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test save and load
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add hook in case core_attention._extra_state is missing
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* check named buffers in BaseModule; remove FP8 scratchpad override function; test FP8 for sm90+
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes: test size, interval in recipe, named_buffer loop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move BaseModule from FusedAttention to DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d71fc946
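
For context on the hooks this commit adds: PyTorch's `register_load_state_dict_post_hook` passes the hook an `incompatible_keys` object whose `missing_keys` and `unexpected_keys` lists can be edited in place before `strict=True` checking raises. Below is a minimal, self-contained sketch of that pattern; the toy module and key names are illustrative, not TransformerEngine's real classes:

    import torch

    class Toy(torch.nn.Module):
        """Illustrative stand-in for a module whose _extra_state layout changed."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

            def drop_stale_extra_state(module, incompatible_keys):
                # Same pattern as the hooks added in this commit: prune
                # missing/unexpected "_extra_state" entries so old and new
                # checkpoint layouts load interchangeably under strict=True.
                for key in list(incompatible_keys.missing_keys):
                    if "_extra_state" in key:
                        incompatible_keys.missing_keys.remove(key)
                for key in list(incompatible_keys.unexpected_keys):
                    if "_extra_state" in key:
                        incompatible_keys.unexpected_keys.remove(key)

            self.register_load_state_dict_post_hook(drop_stale_extra_state)

    model = Toy()
    sd = model.state_dict()
    sd["old_attention._extra_state"] = torch.tensor(0)  # stale key from an old layout
    model.load_state_dict(sd, strict=True)  # loads without an "unexpected key" error
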
@@ -8,6 +8,7 @@ from contextlib import nullcontext
 import torch
 import pytest
+import io
 from transformer_engine.pytorch.fp8 import (
     fp8_autocast,
@@ -15,6 +16,7 @@ from transformer_engine.pytorch.fp8 import (
     fp8_model_init,
 )
 from transformer_engine.pytorch.utils import (
+    get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
@@ -86,6 +88,7 @@ model_configs = {
     "126m": ModelConfig(12, 2048, 2, 768, 12),
     "small": ModelConfig(2, 32, 2, 64, 2),
     "weird": ModelConfig(2, 37, 3, 69, 3),
+    "large": ModelConfig(1, 128, 2, 512, 4, 128),
 }
 
 fp8_recipes = [
@@ -997,3 +1000,86 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
         use_split_accumulator=False,
     )
     torch.cuda.synchronize()
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.parametrize("model", ["large"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_sanity_attention_extra_state(model, dtype):
+    config = model_configs[model]
+    fp8_recipe = recipe.DelayedScaling(
+        margin=0,
+        fp8_format=recipe.Format.HYBRID,
+        amax_history_len=1,
+        amax_compute_algo="most_recent",
+        fp8_dpa=True,
+        fp8_mha=False,
+    )
+    hidden_states = torch.randn(
+        (config.seq_len, config.batch_size, config.hidden_size),
+        dtype=dtype,
+        device="cuda",
+        requires_grad=True,
+    )
+    with fp8_model_init(enabled=True):
+        block = TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            fuse_qkv_params=True,
+            params_dtype=dtype,
+            device="cuda",
+        )
+    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+        output = block(hidden_states, is_first_microbatch=True)
+    loss = output.sum()
+    loss.backward()
+    # call state_dict()
+    sd = block.state_dict()
+    # check core_attention._extra_state
+    attn_extra_state = sd["self_attention.core_attention._extra_state"]
+    attn_extra_state.seek(0)
+    attn_extra_state = torch.load(attn_extra_state, map_location="cuda")
+    # add random core_attention.fused_attention._extra_state;
+    # it should not be loaded or cause any 'unexpected key' errors
+    random_state = {"a": 1, "b": 2}
+    fused_attn_extra_state = io.BytesIO()
+    torch.save(random_state, fused_attn_extra_state)
+    sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state
+    # save checkpoint
+    path = "./checkpoint.pt"
+    torch.save(sd, path)
+    # reinit the model
+    del block
+    with fp8_model_init(enabled=True):
+        block_new = TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            fuse_qkv_params=True,
+            params_dtype=dtype,
+            device="cuda",
+        )
+    FP8GlobalStateManager.reset()
+    # load from checkpoint
+    block_new.load_state_dict(torch.load(path))
+    # check state_dict
+    sd_new = block_new.state_dict()
+    attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
+    attn_extra_state_new.seek(0)
+    attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
+    for k, v in attn_extra_state_new.items():
+        if k != "extra_fp8_variables":
+            assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
+        else:
+            for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
+                assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"
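
The `seek(0)`-then-`torch.load` steps in the test reflect how `_extra_state` is stored: the FP8 metadata is serialized into an `io.BytesIO` buffer, and that buffer object is what appears in the state dict. A minimal sketch of the round trip, with made-up dictionary contents:

    import io
    import torch

    # Serialize a metadata dict into an in-memory buffer, as the extra-state
    # machinery effectively does.
    buffer = io.BytesIO()
    torch.save({"scale_fwd": torch.ones(1)}, buffer)

    # Reading it back requires rewinding the buffer before unpickling.
    buffer.seek(0)
    state = torch.load(buffer)
    assert torch.equal(state["scale_fwd"], torch.ones(1))
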
@@ -4092,7 +4092,7 @@ class FusedAttnFunc(torch.autograd.Function):
     )
 
 
-class FusedAttention(TransformerEngineBaseModule):
+class FusedAttention(torch.nn.Module):
     """Dot product attention, with multiple backends:
 
     1. FusedAttnBackend["F16_max512_seqlen"]
@@ -4159,21 +4159,24 @@ class FusedAttention(TransformerEngineBaseModule):
         def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
             """
             Temporarily remove fused_attention._extra_state as a missing key
-            when loading older TransformerEngine checkpoints. Will phase out
-            this hook in TransformerEngine 2.0.
+            or an unexpected key when loading TransformerEngine checkpoints.
+            Please store FP8 metadata as DotProductAttention's _extra_state,
+            rather than FusedAttention's _extra_state. This hook will be
+            phased out in TransformerEngine 2.0.
             """
             for key in incompatible_keys.missing_keys:
                 if "fused_attention._extra_state" in key:
                     incompatible_keys.missing_keys.remove(key)
+            for key in incompatible_keys.unexpected_keys:
+                if "fused_attention._extra_state" in key:
+                    incompatible_keys.unexpected_keys.remove(key)
+                    warnings.warn(
+                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
+                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
+                    )
 
         self.register_load_state_dict_post_hook(remove_extra_states_check)
 
-    def get_fp8_weights_scratchpad(
-        self,
-        is_first_microbatch: Union[bool, None],
-    ) -> List[Float8Tensor]:
-        """Needs override."""
-
     @no_torch_dynamo()
     def forward(
         self,
@@ -4198,7 +4201,8 @@ class FusedAttention(TransformerEngineBaseModule):
         cp_group: Optional[dist_group_type] = None,
         cp_global_ranks: List[int] = None,
         cp_stream: torch.cuda.Stream = None,
-        is_first_microbatch: Optional[bool] = None,
+        fp8: bool = False,
+        fp8_meta: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         """fused attention fprop"""
         assert (
@@ -4337,24 +4341,15 @@ class FusedAttention(TransformerEngineBaseModule):
                 use_fused_attention=True,
             )
         else:
-            with self.prepare_forward(
-                query_layer, is_first_microbatch, num_gemms=3, allow_non_contiguous=True
-            ) as query_layer:
-                with self.attention_dropout_ctx():
-                    forced_fp8_dpa = ""
-                    if self.fp8_meta["recipe"].fp8_mha:
-                        if not self.fp8_meta["recipe"].fp8_dpa:
-                            self.fp8_meta["recipe"].fp8_dpa = True
-                            forced_fp8_dpa = " (forced)"
-                    if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8:
-                        self.logger.debug(
-                            "Running with fp8_recipe.fp8_mha=%s, "
-                            "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
-                            self.fp8_meta["recipe"].fp8_mha,
-                            self.fp8_meta["recipe"].fp8_dpa,
-                            forced_fp8_dpa,
-                            int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
-                        )
+            with self.attention_dropout_ctx():
+                if fp8:
+                    assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
+                        f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
+                        " is required for FP8 attention!"
+                    )
+                    assert (
+                        fp8_meta is not None
+                    ), "FP8 metadata fp8_meta is required for FP8 attention!"
                 output = FusedAttnFunc.apply(
                     self.training,
                     max_seqlen_q,
@@ -4379,15 +4374,15 @@ class FusedAttention(TransformerEngineBaseModule):
                     None,  # rng_gen
                     fused_attention_backend,
                     use_FAv2_bwd,
-                    self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
-                    self.fp8_meta,
+                    fp8,
+                    fp8_meta,
                 )
 
         # ...hd -> ...(hd)
         return output.view(*output.shape[:-2], -1)
 
 
-class DotProductAttention(torch.nn.Module):
+class DotProductAttention(TransformerEngineBaseModule):
     """Allows the model to jointly attend to information from different
     representation subspaces as described in the paper:
     `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
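
The hunks above capture the heart of the refactor: `FusedAttention` becomes a plain `torch.nn.Module` that receives `fp8`/`fp8_meta` per call, while `DotProductAttention` (now the `TransformerEngineBaseModule` subclass) owns the FP8 state, so a single `_extra_state` blob per attention layer lands in checkpoints. A condensed sketch of the ownership pattern, with illustrative class names rather than TransformerEngine's real ones:

    import torch
    from typing import Any, Dict, Optional

    class InnerAttention(torch.nn.Module):
        """Plain submodule: no fp8_meta of its own, so no _extra_state entry."""

        def forward(
            self,
            x: torch.Tensor,
            fp8: bool = False,
            fp8_meta: Optional[Dict[str, Any]] = None,
        ) -> torch.Tensor:
            if fp8:
                assert fp8_meta is not None, "FP8 metadata is required for FP8 attention!"
            return x

    class OuterAttention(torch.nn.Module):
        """Stand-in for the module that owns the FP8 state and forwards it down."""

        def __init__(self):
            super().__init__()
            self.inner = InnerAttention()
            self.fp8 = False                    # set by the fp8_autocast machinery
            self.fp8_meta: Dict[str, Any] = {}  # checkpointed via _extra_state

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.inner(x, fp8=self.fp8, fp8_meta=self.fp8_meta)

    out = OuterAttention()(torch.ones(2))  # runs the non-FP8 path
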
@@ -4602,6 +4597,18 @@ class DotProductAttention(torch.nn.Module):
             softmax_scale, **attn_kwargs, layer_number=layer_number
         )
 
+        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
+            """
+            Temporarily remove core_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if "core_attention._extra_state" in key:
+                    incompatible_keys.missing_keys.remove(key)
+
+        self.register_load_state_dict_post_hook(remove_extra_states_check)
+
     def _checkpointed_attention_forward(
         self,
         attention_func: Callable,
@@ -4805,6 +4812,30 @@ class DotProductAttention(torch.nn.Module):
                                        first microbatch (since it is the first gradient being
                                        produced)
         """
+        with self.prepare_forward(
+            query_layer,
+            is_first_microbatch,
+            num_gemms=3,
+            allow_non_contiguous=True,
+        ) as query_layer:
+            if self.fp8:
+                forced_fp8_dpa = ""
+                if self.fp8_meta["recipe"].fp8_mha:
+                    if not self.fp8_meta["recipe"].fp8_dpa:
+                        self.fp8_meta["recipe"].fp8_dpa = True
+                        forced_fp8_dpa = " (forced)"
+
+        if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+            forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
+            backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
+            assert forward_dtype in [
+                tex.DType.kFloat8E4M3,
+                tex.DType.kFloat8E5M2,
+            ] and backward_dtype in [
+                tex.DType.kFloat8E4M3,
+                tex.DType.kFloat8E5M2,
+            ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
+
         assert (
             query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
@@ -5000,6 +5031,16 @@ class DotProductAttention(torch.nn.Module):
             )
             use_fused_attention = False
 
+        # Filter: Execution type.
+        if use_flash_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+            self.logger.debug("Disabling FlashAttention as it does not support FP8 execution.")
+            use_flash_attention = False
+        if use_unfused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+            self.logger.debug(
+                "Disabling UnfusedDotProductAttention as it does not support FP8 execution."
+            )
+            use_unfused_attention = False
+
         # Filter: Device and dimensions.
         # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
         # FAv2 requires head_dim % 8 == 0
@@ -5104,14 +5145,19 @@ class DotProductAttention(torch.nn.Module):
             _alibi_cache["_alibi_bias_require_update"] = True
 
         if use_flash_attention and (
-            core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None
+            core_attention_bias_type not in ["no_bias", "alibi"]
+            or core_attention_bias is not None
         ):
             self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
             use_flash_attention = False
 
         fu_core_attention_bias_type = core_attention_bias_type
         fu_core_attention_bias = core_attention_bias
-        if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
+        if (
+            core_attention_bias_type == "alibi"
+            and use_fused_attention
+            and alibi_slopes is not None
+        ):
             fu_core_attention_bias_type = "post_scale_bias"
             _, fu_core_attention_bias = get_alibi(
                 query_layer.shape[-2],
@@ -5138,17 +5184,20 @@ class DotProductAttention(torch.nn.Module):
                 os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
 
         if use_fused_attention:
+            q_type = TE_DType[query_layer.dtype]
+            kv_type = TE_DType[key_layer.dtype]
+            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+                if isinstance(query_layer, Float8Tensor) and isinstance(
+                    key_layer, Float8Tensor
+                ):
+                    q_type = query_layer._fp8_dtype
+                    kv_type = value_layer._fp8_dtype
+                else:
+                    q_type = forward_dtype
+                    kv_type = forward_dtype
             fused_attention_backend = tex.get_fused_attn_backend(
-                (
-                    TE_DType[query_layer.dtype]
-                    if not isinstance(query_layer, Float8Tensor)
-                    else query_layer._fp8_dtype
-                ),
-                (
-                    TE_DType[key_layer.dtype]
-                    if not isinstance(key_layer, Float8Tensor)
-                    else key_layer._fp8_dtype
-                ),
+                q_type,
+                kv_type,
                 QKVLayout[qkv_layout],
                 AttnBiasType[fu_core_attention_bias_type],
                 AttnMaskType[attn_mask_type],
@@ -5237,7 +5286,9 @@ class DotProductAttention(torch.nn.Module):
                 "qkv_layout": qkv_layout,
                 "mask_type": attn_mask_type,
                 "bias_type": core_attention_bias_type,
-                "bias_shape": core_attention_bias.shape if core_attention_bias is not None else None,
+                "bias_shape": (
+                    core_attention_bias.shape if core_attention_bias is not None else None
+                ),
                 "dropout": self.attention_dropout,
                 "context_parallel": context_parallel,
                 "is_training": self.training,
@@ -5251,7 +5302,10 @@ class DotProductAttention(torch.nn.Module):
             self.logger.debug("Running with config=%s", run_config)
             if core_attention_bias_type == "alibi":
                 alibi_slopes, _ = get_alibi(
-                    query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes
+                    query_layer.shape[-2],
+                    max_seqlen_q,
+                    max_seqlen_kv,
+                    alibi_slopes=alibi_slopes,
                 )
             return self.flash_attention(
                 query_layer,
@@ -5273,7 +5327,17 @@ class DotProductAttention(torch.nn.Module):
         if use_fused_attention:
             self.logger.info(
-                "Running with FusedAttention backend (sub-backend %s)", int(fused_attention_backend)
+                "Running with FusedAttention backend (sub-backend %s)",
+                int(fused_attention_backend),
             )
+            if self.fp8:
+                self.logger.debug(
+                    "Running with fp8_recipe.fp8_mha=%s, "
+                    "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
+                    self.fp8_meta["recipe"].fp8_mha,
+                    self.fp8_meta["recipe"].fp8_dpa,
+                    forced_fp8_dpa,
+                    int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
+                )
             self.logger.debug("Running with config=%s", run_config)
 
             if checkpoint_core_attention:
@@ -5300,7 +5364,8 @@ class DotProductAttention(torch.nn.Module):
                     cp_group=self.cp_group,
                     cp_global_ranks=self.cp_global_ranks,
                     cp_stream=self.cp_stream,
-                    is_first_microbatch=is_first_microbatch,
+                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
+                    fp8_meta=self.fp8_meta,
                 )
             return self.fused_attention(
                 query_layer,
@@ -5324,7 +5389,8 @@ class DotProductAttention(torch.nn.Module):
                 cp_group=self.cp_group,
                 cp_global_ranks=self.cp_global_ranks,
                 cp_stream=self.cp_stream,
-                is_first_microbatch=is_first_microbatch,
+                fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
+                fp8_meta=self.fp8_meta,
             )
 
         assert (
...
@@ -430,7 +430,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
 
         # Store other pickelable values.
         extra = {}
         for k, v in self.fp8_meta.items():
-            if isinstance(v, (bool, int, float, str, tuple, list)):
+            if k != "buffer_index_and_autocast_key" and isinstance(
+                v, (bool, int, float, str, tuple, list)
+            ):
                 extra[k] = v
         state["extra_fp8_variables"] = extra
@@ -491,12 +493,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                     "Data types for parameters must match when outside of autocasted region. "
                     f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                 )
-        for name, buf in self.named_buffers():
-            if buf is not None:
-                assert dtype == buf.dtype, (
-                    "Data types for buffers must match when outside of autocasted region. "
-                    f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
-                )
         self.activation_dtype = dtype
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
...