Unverified Commit fe5aa604 authored by Charlene Yang, committed by GitHub

[PyTorch] Adjust checkpointing of FP8 metadata for attention (#917)



* subclass DPA with BaseModule and test with test_gpt_checkpointing
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test DPA only
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test save and load
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add hook in case core_attention._extra_state is missing (a sketch of such a hook follows below)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* check named buffers in BaseModule; remove FP8 scratchpad override function; test FP8 for sm90+
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes: test size, interval in recipe, named_buffer loop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move BaseModule from FusedAttention to DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d71fc946
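
One of the commits above adds a hook for the case where an older checkpoint has no core_attention._extra_state entry. That hook lives in the attention modules and is not part of the excerpted diff below, so what follows is only a hedged sketch of the general technique: a PyTorch load-state-dict pre-hook (registered via _register_load_state_dict_pre_hook) that injects a placeholder entry before the key checks run. The class name, the helper _add_missing_extra_state, and the empty-dict placeholder are illustrative assumptions, not the Transformer Engine implementation.

import io

import torch


class AttentionWithExtraState(torch.nn.Module):
    """Illustrative stand-in for a module that checkpoints FP8 metadata via _extra_state."""

    def __init__(self):
        super().__init__()
        self.fp8_meta = {}
        # The pre-hook runs before load_state_dict() validates keys, so it can
        # supply a default _extra_state entry for checkpoints that predate it.
        self._register_load_state_dict_pre_hook(self._add_missing_extra_state)

    def get_extra_state(self):
        # Serialize non-parameter state into a buffer, similar in spirit to the
        # _extra_state entries handled in the diff below.
        buffer = io.BytesIO()
        torch.save(self.fp8_meta, buffer)
        return buffer

    def set_extra_state(self, state):
        if state is None:
            return
        state.seek(0)
        self.fp8_meta = torch.load(state)

    def _add_missing_extra_state(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Hypothetical fallback: if the incoming checkpoint has no _extra_state
        # for this module, inject an empty serialized placeholder so loading
        # does not fail with a missing key.
        key = prefix + "_extra_state"
        if key not in state_dict:
            buffer = io.BytesIO()
            torch.save({}, buffer)
            state_dict[key] = buffer

With such a hook registered, load_state_dict(..., strict=True) on a checkpoint written before the module carried FP8 metadata no longer reports _extra_state as a missing key. The test added below also covers the opposite direction: a stale fused_attention._extra_state entry in the checkpoint must be ignored rather than surface as an unexpected key.
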
@@ -8,6 +8,7 @@ from contextlib import nullcontext
import torch
import pytest
import io
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
@@ -15,6 +16,7 @@ from transformer_engine.pytorch.fp8 import (
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
@@ -86,6 +88,7 @@ model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
    "large": ModelConfig(1, 128, 2, 512, 4, 128),
}
fp8_recipes = [
@@ -997,3 +1000,86 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
    config = model_configs[model]
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=True,
        fp8_mha=False,
    )
    hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    with fp8_model_init(enabled=True):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            fuse_qkv_params=True,
            params_dtype=dtype,
            device="cuda",
        )
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        output = block(hidden_states, is_first_microbatch=True)
    loss = output.sum()
    loss.backward()

    # call state_dict()
    sd = block.state_dict()

    # check core_attention._extra_state
    attn_extra_state = sd["self_attention.core_attention._extra_state"]
    attn_extra_state.seek(0)
    attn_extra_state = torch.load(attn_extra_state, map_location="cuda")

    # add random core_attention.fused_attention._extra_state
    # it should not be loaded or cause any 'unexpected key' errors
    random_state = {"a": 1, "b": 2}
    fused_attn_extra_state = io.BytesIO()
    torch.save(random_state, fused_attn_extra_state)
    sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state

    # save checkpoint
    path = "./checkpoint.pt"
    torch.save(sd, path)

    # reinit the model
    del block
    with fp8_model_init(enabled=True):
        block_new = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            fuse_qkv_params=True,
            params_dtype=dtype,
            device="cuda",
        )
    FP8GlobalStateManager.reset()

    # load from checkpoint
    block_new.load_state_dict(torch.load(path))

    # check state_dict
    sd_new = block_new.state_dict()
    attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
    attn_extra_state_new.seek(0)
    attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
    for k, v in attn_extra_state_new.items():
        if k != "extra_fp8_variables":
            assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
        else:
            for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
                assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"
@@ -430,7 +430,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Store other picklable values.
        extra = {}
        for k, v in self.fp8_meta.items():
            if k != "buffer_index_and_autocast_key" and isinstance(
                v, (bool, int, float, str, tuple, list)
            ):
                extra[k] = v
        state["extra_fp8_variables"] = extra
@@ -491,12 +493,6 @@
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        for name, buf in self.named_buffers():
            if buf is not None:
                assert dtype == buf.dtype, (
                    "Data types for buffers must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
                )
        self.activation_dtype = dtype

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
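
For reference, the picklable FP8 metadata that the updated get_extra_state stores under extra_fp8_variables (every bool/int/float/str/tuple/list entry of fp8_meta, now excluding buffer_index_and_autocast_key) can be inspected from a saved checkpoint much as the new test does. A minimal sketch, assuming a checkpoint written by a TransformerLayer as in the test above; the path and module key mirror that test and are not a general API:

import torch

# Load the full checkpoint; weights_only=False because the state dict holds
# pickled io.BytesIO buffers, not just tensors.
sd = torch.load("./checkpoint.pt", weights_only=False)

# Pull out the attention module's serialized _extra_state buffer.
buffer = sd["self_attention.core_attention._extra_state"]
buffer.seek(0)
extra_state = torch.load(buffer, map_location="cpu", weights_only=False)

# extra_fp8_variables holds the picklable fp8_meta entries, with
# buffer_index_and_autocast_key excluded by the BaseModule change above.
for name, value in extra_state.get("extra_fp8_variables", {}).items():
    print(name, value)
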