Unverified Commit 2d875521 authored by Charlene Yang, committed by GitHub

[PyTorch] Add documentation for FP8 attention checkpointing (#1223)



* add extra_state change description for different TE versions
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FAQ page
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FAQ page
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix extra_state tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5b89f1ad
..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.
Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine has supported FP8 attention since version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. Its FP8 attention metadata in Transformer Engine 1.11 is stored as `core_attention._extra_state`, as shown below.
.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])
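To see what is actually stored under that key, the entry can be deserialized the same way Transformer Engine's own tests do. The following is a minimal sketch; it assumes the `._extra_state` entry is a file-like buffer written by `torch.save`, and the exact keys inside (beyond `extra_fp8_variables`) may differ between versions.

.. code-block:: python

    >>> blob = state_dict["core_attention._extra_state"]
    >>> blob.seek(0)  # rewind the serialized buffer before loading
    >>> fp8_meta = torch.load(blob, map_location="cuda")
    >>> sorted(fp8_meta.keys())  # scaling factors, amax histories, 'extra_fp8_variables', ...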
Here is a full list of the checkpoint save/load behaviors for all Transformer Engine versions.

.. list-table::

   * - **Version: <= 1.5**

     - Saves no FP8 metadata since FP8 attention is not supported

     - Loading behavior for checkpoints created by the following versions:

       :<= 1.5: Loads no FP8 metadata
       :> 1.5: Error: unexpected key
   * - **Version: 1.6, 1.7**

     - Saves FP8 metadata to `core_attention.fused_attention._extra_state`

     - Loading behavior for checkpoints created by the following versions:

       :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
       :1.6, 1.7: Loads FP8 metadata from checkpoint
       :>= 1.8: Error: unexpected key
   * - **Version: >= 1.8, <= 1.11**

     - Saves FP8 metadata to `core_attention._extra_state`

     - Loading behavior for checkpoints created by the following versions:

       :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
       :1.6, 1.7: This save/load combination relies on users to map the 1.6/1.7 key to the 1.8-1.11 key. Otherwise, it initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes. The mapping can be done, in this `MultiheadAttention` example, by

         .. code-block:: python

             >>> state_dict["core_attention._extra_state"] = \
                     state_dict["core_attention.fused_attention._extra_state"]
             >>> del state_dict["core_attention.fused_attention._extra_state"]

         A more general remapping sketch for full-model checkpoints follows this table.
       :>= 1.8: Loads FP8 metadata from checkpoint
   * - **Version: >= 1.12**

     - Saves FP8 metadata to `core_attention._extra_state`

     - Loading behavior for checkpoints created by the following versions:

       :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
       :>= 1.6: Loads FP8 metadata from checkpoint
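Note that with Transformer Engine >= 1.12, checkpoints from 1.6/1.7 load without any manual remapping. For loading such a checkpoint into 1.8-1.11, where many attention layers may be involved (e.g. a stack of `TransformerLayer` modules), the per-key mapping above can be generalized. The following is a minimal, illustrative sketch, not a Transformer Engine API; it assumes a flat `state_dict` whose old-style keys end in `core_attention.fused_attention._extra_state`, and `model` stands for your own module.

.. code-block:: python

    >>> old_suffix = "core_attention.fused_attention._extra_state"
    >>> for key in [k for k in state_dict.keys() if k.endswith(old_suffix)]:
    ...     new_key = key.replace("fused_attention._extra_state", "_extra_state")
    ...     state_dict[new_key] = state_dict.pop(key)  # move metadata to the >= 1.8 key
    ...
    >>> model.load_state_dict(state_dict)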
@@ -30,6 +30,7 @@ Transformer Engine documentation
 
    installation
    examples/quickstart.ipynb
+   faq
 
 .. toctree::
    :hidden:
...
@@ -9,6 +9,7 @@ from contextlib import nullcontext
 import torch
 import pytest
 import io
+import os
 
 from transformer_engine.pytorch.fp8 import (
     fp8_autocast,
@@ -42,6 +43,7 @@ from transformer_engine.pytorch.cpp_extensions import (
 )
 from transformer_engine.pytorch.module.base import get_workspace
 from test_onnx_export import create_meta
+from test_numerics import reset_rng_states, dtype_tols
 
 # Only run FP8 tests on H100.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -1004,20 +1006,50 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
 
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
 @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
 @pytest.mark.parametrize("model", ["large"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_sanity_attention_extra_state(model, dtype):
     config = model_configs[model]
+    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
+    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
+    outputs_checkpoint_v1_6 = _run_attention_extra_state(
+        dtype, config, mimic_v1_6=True, checkpoint=True
+    )
+
+    # Check that results match
+    tols = dtype_tols(dtype)
+    if dtype in (torch.float16, torch.bfloat16):
+        tols.update(dict(rtol=2e-2, atol=2e-3))
+    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
+        torch.testing.assert_close(
+            test,
+            ref,
+            **tols,
+        )
+    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
+        torch.testing.assert_close(
+            test,
+            ref,
+            **tols,
+        )
+
+
+def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
+    steps = 10
+    path = "checkpoint.pt"
+    fp8_enabled = True
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
         fp8_format=recipe.Format.HYBRID,
         amax_history_len=1,
         amax_compute_algo="most_recent",
-        fp8_dpa=True,
+        fp8_dpa=fp8_enabled,
         fp8_mha=False,
     )
+
+    reset_rng_states()
 
     hidden_states = torch.randn(
         (config.seq_len, config.batch_size, config.hidden_size),
         dtype=dtype,
@@ -1025,63 +1057,74 @@ def test_sanity_attention_extra_state(model, dtype):
         requires_grad=True,
     )
 
-    with fp8_model_init(enabled=True):
-        block = TransformerLayer(
-            config.hidden_size,
-            4 * config.hidden_size,
-            config.num_attention_heads,
-            fuse_qkv_params=True,
-            params_dtype=dtype,
-            device="cuda",
-        )
-
-    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-        output = block(hidden_states, is_first_microbatch=True)
-    loss = output.sum()
-    loss.backward()
-
-    # call state_dict()
-    sd = block.state_dict()
-
-    # check core_attention._extra_state
-    attn_extra_state = sd["self_attention.core_attention._extra_state"]
-    attn_extra_state.seek(0)
-    attn_extra_state = torch.load(attn_extra_state, map_location="cuda")
-
-    # add random core_attention.fused_attention._extra_state
-    # it should not be loaded or cause any 'unexpected key' errors
-    random_state = {"a": 1, "b": 2}
-    fused_attn_extra_state = io.BytesIO()
-    torch.save(random_state, fused_attn_extra_state)
-    sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state
-
-    # save checkpoint
-    path = "./checkpoint.pt"
-    torch.save(sd, path)
-
-    # reinit the model
-    del block
-    with fp8_model_init(enabled=True):
-        block_new = TransformerLayer(
-            config.hidden_size,
-            4 * config.hidden_size,
-            config.num_attention_heads,
-            fuse_qkv_params=True,
-            params_dtype=dtype,
-            device="cuda",
-        )
-    FP8GlobalStateManager.reset()
-
-    # load from checkpoint
-    block_new.load_state_dict(torch.load(path))
-
-    # check state_dict
-    sd_new = block_new.state_dict()
-    attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
-    attn_extra_state_new.seek(0)
-    attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
-    for k, v in attn_extra_state_new.items():
-        if k != "extra_fp8_variables":
-            assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
-        else:
-            for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
-                assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"
+    def get_model(dtype, config):
+        sigma = 0.023
+        init_method = init_method_normal(sigma)
+        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+
+        with fp8_model_init(enabled=fp8_enabled):
+            block = TransformerLayer(
+                config.hidden_size,
+                4 * config.hidden_size,
+                config.num_attention_heads,
+                init_method=init_method,
+                output_layer_init_method=output_layer_init_method,
+                fuse_qkv_params=True,
+                params_dtype=dtype,
+                device="cuda",
+            )
+        return block
+
+    block = get_model(dtype, config)
+    for i in range(steps // 2):
+        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
+            output = block(hidden_states, None)
+        loss = output.sum()
+        loss.backward()
+
+    if checkpoint:
+        sd = block.state_dict()
+        if mimic_v1_6:
+            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
+                "self_attention.core_attention._extra_state"
+            ]
+            del sd["self_attention.core_attention._extra_state"]
+        torch.save(sd, path)
+
+        param_grads = []
+        for p in block.parameters():
+            if p.requires_grad:
+                param_grads.append(p.grad.clone())
+
+        _cpu_rng_state_new = torch.get_rng_state()
+        _cuda_rng_state_new = torch.cuda.get_rng_state()
+
+        del block
+        block = get_model(dtype, config)
+        block.load_state_dict(torch.load(path))
+        torch.set_rng_state(_cpu_rng_state_new)
+        torch.cuda.set_rng_state(_cuda_rng_state_new)
+
+        for p in block.parameters():
+            if p.requires_grad:
+                p.grad = param_grads.pop(0)
+
+        assert not param_grads, "Oops!"
+
+    for i in range((steps + 1) // 2):
+        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
+            output = block(hidden_states, None)
+        loss = output.sum()
+        loss.backward()
+
+    torch.cuda.synchronize()
+
+    if os.path.exists(path):
+        os.remove(path)
+
+    outputs = [output, hidden_states.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            outputs.append(p.grad)
+
+    return outputs
@@ -6790,10 +6790,10 @@ class FusedAttention(torch.nn.Module):
         def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
             """
             Temporarily remove fused_attention._extra_state as a missing key
-            or an unexpected key when loading TransformerEngine checkpoints.
+            or an unexpected key when loading Transformer Engine checkpoints.
             Please store FP8 metadata as DotProductAttention's _extra_state,
             rather than FusedAttention's _extra_state. This hook will be
-            phased out in TransformerEngine 2.0.
+            phased out in Transformer Engine 2.0.
             """
             for key in incompatible_keys.missing_keys:
                 if "fused_attention._extra_state" in key:
@@ -7023,6 +7023,13 @@ class DotProductAttention(TransformerEngineBaseModule):
     and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
     to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
 
+    .. note::
+
+        Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
+        As the FP8 attention support expands from one backend to multiple backends, the location
+        of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).
+
     Parameters
     ----------
     num_attention_heads : int
@@ -7051,7 +7058,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                  e.g. a different mask for training and inference.
                  1. For "`no_mask`", no attention mask is applied.
                  2. For "`causal`", "`causal_bottom_right`", or the causal mask in
-                    "`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
+                    "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
                     calculates and applies an upper triangular mask to the softmax input.
                     No user input is needed. Causal masks without the "`bottom_right`" appendix align
                     the diagonal line to the top left corner of the softmax matrix. With
@@ -7264,8 +7271,8 @@ class DotProductAttention(TransformerEngineBaseModule):
         def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
             """
             Temporarily remove core_attention._extra_state as a missing key
-            when loading older TransformerEngine checkpoints. Will phase out
-            this hook in TransformerEngine 2.0.
+            when loading older Transformer Engine checkpoints. Will phase out
+            this hook in Transformer Engine 2.0.
             """
             for key in incompatible_keys.missing_keys:
                 if "core_attention._extra_state" in key:
@@ -7273,6 +7280,28 @@ class DotProductAttention(TransformerEngineBaseModule):
 
         self.register_load_state_dict_post_hook(remove_extra_states_check)
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        """
+        This function helps to load Transformer Engine 1.6 and 1.7 checkpoints, where FP8 attention
+        metadata is stored under the `core_attention.fused_attention._extra_state` key and not the
+        `core_attention._extra_state` key. Please see `FP8 checkpoint compatibility
+        <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
+        """
+        fused_attn_key = False
+        dot_product_attn_key = False
+        for k in state_dict.keys():
+            if "core_attention.fused_attention._extra_state" in k:
+                fused_attn_key = True
+            if "core_attention._extra_state" in k:
+                dot_product_attn_key = True
+        if fused_attn_key and not dot_product_attn_key:
+            prefix = prefix + "fused_attention."
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def _checkpointed_attention_forward(
         self,
         attention_func: Callable,
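As a usage-level illustration of what this hook changes for end users (the model and checkpoint path below are hypothetical; the behavior is the one described in the FAQ table above): a checkpoint written by Transformer Engine 1.6/1.7 can now be loaded without manually remapping the `core_attention.fused_attention._extra_state` key.

.. code-block:: python

    import torch

    # Hypothetical checkpoint saved by Transformer Engine 1.6/1.7; its FP8 attention metadata
    # sits under "...core_attention.fused_attention._extra_state". The override above redirects
    # the prefix, so no manual key remapping is needed before load_state_dict().
    state_dict = torch.load("old_te_1_6_checkpoint.pt")
    model.load_state_dict(state_dict)  # `model` is any module built on DotProductAttention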
@@ -7382,14 +7411,14 @@ class DotProductAttention(TransformerEngineBaseModule):
         Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
         and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
-        and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
+        and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
         FlashAttention over FusedAttention and over UnfusedDotProductAttention.
 
         If FusedAttention is being used, users can also choose to switch to flash-attn's
         implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
         (default: 0), because of the performance differences between various versions of
         flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
         can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
-        optimizations in FusedAttention. When unset, TransformerEngine determines the code path
+        optimizations in FusedAttention. When unset, Transformer Engine determines the code path
         based on its internal logic. These optimizations trade memory for performance
         and should be used with care.
...
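As a rough, editor-added illustration of the backend-selection controls described in the docstring above (only the variable names come from that documentation; the values chosen here are arbitrary examples):

.. code-block:: python

    import os

    # Prefer FusedAttention by turning off the flash-attn backend; Transformer Engine
    # reads these variables when it selects a DotProductAttention backend.
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    # Optionally restrict backends to deterministic algorithms.
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"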