Unverified Commit fe5aa604 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Adjust checkpointing of FP8 metadata for attention (#917)



* subclass DPA with BaseModule and test with test_gpt_checkpointing
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test DPA only
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test save and load
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove debug info
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add hook in case core_attention._extra_state is missing
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* check named buffers in BaseModule; remove FP8 scratchpad override function; test FP8 for sm90+
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes: test size, interval in recipe, named_buffer loop
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move BaseModule from FusedAttention to DPA
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d71fc946
...@@ -8,6 +8,7 @@ from contextlib import nullcontext ...@@ -8,6 +8,7 @@ from contextlib import nullcontext
import torch import torch
import pytest import pytest
import io
from transformer_engine.pytorch.fp8 import ( from transformer_engine.pytorch.fp8 import (
fp8_autocast, fp8_autocast,
...@@ -15,6 +16,7 @@ from transformer_engine.pytorch.fp8 import ( ...@@ -15,6 +16,7 @@ from transformer_engine.pytorch.fp8 import (
fp8_model_init, fp8_model_init,
) )
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
is_bf16_compatible, is_bf16_compatible,
...@@ -86,6 +88,7 @@ model_configs = { ...@@ -86,6 +88,7 @@ model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12), "126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2), "small": ModelConfig(2, 32, 2, 64, 2),
"weird": ModelConfig(2, 37, 3, 69, 3), "weird": ModelConfig(2, 37, 3, 69, 3),
"large": ModelConfig(1, 128, 2, 512, 4, 128),
} }
fp8_recipes = [ fp8_recipes = [
...@@ -997,3 +1000,86 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): ...@@ -997,3 +1000,86 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
use_split_accumulator=False, use_split_accumulator=False,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
    """Round-trip checkpoint test for FP8 attention metadata.

    Runs one forward/backward pass of a ``TransformerLayer`` with FP8
    dot-product attention enabled, saves its ``state_dict`` to a checkpoint,
    reloads it into a freshly built layer, and verifies that the FP8 metadata
    stored in ``self_attention.core_attention._extra_state`` survives the
    round trip unchanged.  A bogus
    ``core_attention.fused_attention._extra_state`` entry is also injected
    into the checkpoint to confirm the load hook drops it silently instead of
    raising an 'unexpected key' error.

    NOTE(review): the skip condition `!= (9, 0)` also skips on architectures
    newer than Hopper — presumably intentional for this FP8 backend; confirm
    if sm100+ support is added.
    """
    import os
    import tempfile

    config = model_configs[model]
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=True,
        fp8_mha=False,
    )

    hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def _make_block():
        # Build the layer under fp8_model_init so parameters are created
        # directly as FP8 tensors; both the original and reloaded layers
        # must be constructed identically for the state_dicts to align.
        with fp8_model_init(enabled=True):
            return TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_attention_heads,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )

    block = _make_block()

    # One fwd/bwd pass so the FP8 metadata (scales/amax history) is populated.
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        output = block(hidden_states, is_first_microbatch=True)
    loss = output.sum()
    loss.backward()

    sd = block.state_dict()

    # Deserialize core_attention's FP8 metadata for later comparison.
    attn_extra_state = sd["self_attention.core_attention._extra_state"]
    attn_extra_state.seek(0)
    attn_extra_state = torch.load(attn_extra_state, map_location="cuda")

    # Inject a random fused_attention._extra_state entry; it should neither
    # be loaded nor cause any 'unexpected key' errors (handled by the hook).
    random_state = {"a": 1, "b": 2}
    fused_attn_extra_state = io.BytesIO()
    torch.save(random_state, fused_attn_extra_state)
    sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state

    # Write the checkpoint to a temporary directory instead of the CWD so the
    # test leaves no file behind and parallel runs cannot collide.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "checkpoint.pt")
        torch.save(sd, path)

        # Reinitialize the model and load from the checkpoint.
        del block
        block_new = _make_block()
        FP8GlobalStateManager.reset()
        block_new.load_state_dict(torch.load(path))

    # The reloaded FP8 metadata must match the saved metadata exactly.
    sd_new = block_new.state_dict()
    attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
    attn_extra_state_new.seek(0)
    attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
    for k, v in attn_extra_state_new.items():
        if k != "extra_fp8_variables":
            assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
        else:
            for ek, ev in v.items():
                assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"
...@@ -4092,7 +4092,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -4092,7 +4092,7 @@ class FusedAttnFunc(torch.autograd.Function):
) )
class FusedAttention(TransformerEngineBaseModule): class FusedAttention(torch.nn.Module):
"""Dot product attention, with multiple backends: """Dot product attention, with multiple backends:
1. FusedAttnBackend["F16_max512_seqlen"] 1. FusedAttnBackend["F16_max512_seqlen"]
...@@ -4159,21 +4159,24 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -4159,21 +4159,24 @@ class FusedAttention(TransformerEngineBaseModule):
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
""" """
Temporarily remove fused_attention._extra_state as a missing key Temporarily remove fused_attention._extra_state as a missing key
when loading older TransformerEngine checkpoints. Will phase out or an unexpected key when loading TransformerEngine checkpoints.
this hook in TransformerEngine 2.0. Please store FP8 metadata as DotProductAttention's _extra_state,
rather than FusedAttention's _extra_state. This hook will be
phased out in TransformerEngine 2.0.
""" """
for key in incompatible_keys.missing_keys: for key in incompatible_keys.missing_keys:
if "fused_attention._extra_state" in key: if "fused_attention._extra_state" in key:
incompatible_keys.missing_keys.remove(key) incompatible_keys.missing_keys.remove(key)
for key in incompatible_keys.unexpected_keys:
if "fused_attention._extra_state" in key:
incompatible_keys.unexpected_keys.remove(key)
warnings.warn(
"fused_attention._extra_state is not loaded from checkpoint. Please map "
"FusedAttention's _extra_state to DotProductAttention's _extra_state."
)
self.register_load_state_dict_post_hook(remove_extra_states_check) self.register_load_state_dict_post_hook(remove_extra_states_check)
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[Float8Tensor]:
"""Needs override."""
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
self, self,
...@@ -4198,7 +4201,8 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -4198,7 +4201,8 @@ class FusedAttention(TransformerEngineBaseModule):
cp_group: Optional[dist_group_type] = None, cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None, cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None, cp_stream: torch.cuda.Stream = None,
is_first_microbatch: Optional[bool] = None, fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
assert ( assert (
...@@ -4337,57 +4341,48 @@ class FusedAttention(TransformerEngineBaseModule): ...@@ -4337,57 +4341,48 @@ class FusedAttention(TransformerEngineBaseModule):
use_fused_attention=True, use_fused_attention=True,
) )
else: else:
with self.prepare_forward( with self.attention_dropout_ctx():
query_layer, is_first_microbatch, num_gemms=3, allow_non_contiguous=True if fp8:
) as query_layer: assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
with self.attention_dropout_ctx(): f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
forced_fp8_dpa = "" " is required for FP8 attention!"
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
forced_fp8_dpa = " (forced)"
if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8:
self.logger.debug(
"Running with fp8_recipe.fp8_mha=%s, "
"fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
self.fp8_meta["recipe"].fp8_mha,
self.fp8_meta["recipe"].fp8_dpa,
forced_fp8_dpa,
int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
)
output = FusedAttnFunc.apply(
self.training,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
query_layer,
key_layer,
value_layer,
qkv_dtype,
core_attention_bias,
self.softmax_scale,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
attn_mask_type,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
self.fp8_meta,
) )
assert (
fp8_meta is not None
), "FP8 metadata fp8_meta is required for FP8 attention!"
output = FusedAttnFunc.apply(
self.training,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q,
cu_seqlens_kv,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
query_layer,
key_layer,
value_layer,
qkv_dtype,
core_attention_bias,
self.softmax_scale,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
attn_mask_type,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
fp8,
fp8_meta,
)
# ...hd -> ...(hd) # ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1) return output.view(*output.shape[:-2], -1)
class DotProductAttention(torch.nn.Module): class DotProductAttention(TransformerEngineBaseModule):
"""Allows the model to jointly attend to information from different """Allows the model to jointly attend to information from different
representation subspaces as described in the paper: representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_. `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
...@@ -4602,6 +4597,18 @@ class DotProductAttention(torch.nn.Module): ...@@ -4602,6 +4597,18 @@ class DotProductAttention(torch.nn.Module):
softmax_scale, **attn_kwargs, layer_number=layer_number softmax_scale, **attn_kwargs, layer_number=layer_number
) )
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
Temporarily remove core_attention._extra_state as a missing key
when loading older TransformerEngine checkpoints. Will phase out
this hook in TransformerEngine 2.0.
"""
for key in incompatible_keys.missing_keys:
if "core_attention._extra_state" in key:
incompatible_keys.missing_keys.remove(key)
self.register_load_state_dict_post_hook(remove_extra_states_check)
def _checkpointed_attention_forward( def _checkpointed_attention_forward(
self, self,
attention_func: Callable, attention_func: Callable,
...@@ -4805,480 +4812,562 @@ class DotProductAttention(torch.nn.Module): ...@@ -4805,480 +4812,562 @@ class DotProductAttention(torch.nn.Module):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
with self.prepare_forward(
query_layer,
is_first_microbatch,
num_gemms=3,
allow_non_contiguous=True,
) as query_layer:
if self.fp8:
forced_fp8_dpa = ""
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
forced_fp8_dpa = " (forced)"
if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
assert forward_dtype in [
tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2,
] and backward_dtype in [
tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2,
], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
assert ( assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), "DotProductAttention only supports CUDA tensors." ), "DotProductAttention only supports CUDA tensors."
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
if attn_mask_type is not None: if attn_mask_type is not None:
window_size = check_set_window_size(attn_mask_type, window_size) window_size = check_set_window_size(attn_mask_type, window_size)
if attn_mask_type is None: if attn_mask_type is None:
attn_mask_type = self.attn_mask_type attn_mask_type = self.attn_mask_type
else: else:
attn_mask_type = attn_mask_type.replace(",", "_") attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding": if attn_mask_type == "causal_padding":
attn_mask_type = "padding_causal" attn_mask_type = "padding_causal"
assert (
attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if qkv_format == "thd":
assert (
"padding" in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
if self.rng_states_tracker is not None and is_graph_capturing():
assert isinstance(
self.rng_states_tracker, CudaRNGStatesTracker
), "Unsupported RNG states tracker."
assert ( assert (
graph_safe_rng_available() attn_mask_type in AttnMaskTypes
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." ), f"Attention mask type {attn_mask_type} is not supported!"
if qkv_format == "thd":
assert (
"padding" in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
if window_size is None: if self.rng_states_tracker is not None and is_graph_capturing():
window_size = self.window_size assert isinstance(
self.rng_states_tracker, CudaRNGStatesTracker
), "Unsupported RNG states tracker."
assert (
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
if qkv_format is None: if window_size is None:
qkv_format = self.qkv_format window_size = self.window_size
if inference_params is not None: if qkv_format is None:
assert self.layer_number is not None, "Layer number must be set!" qkv_format = self.qkv_format
if qkv_format == "bshd": if inference_params is not None:
key_layer = key_layer.transpose(0, 1) assert self.layer_number is not None, "Layer number must be set!"
value_layer = value_layer.transpose(0, 1)
( if qkv_format == "bshd":
inference_key_memory, key_layer = key_layer.transpose(0, 1)
inference_value_memory, value_layer = value_layer.transpose(0, 1)
) = inference_params.key_value_memory_dict[self.layer_number]
batch_start = inference_params.batch_size_offset (
batch_end = batch_start + key_layer.size(1) inference_key_memory,
assert batch_end <= inference_key_memory.size(1) inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
sequence_start = inference_params.sequence_len_offset batch_start = inference_params.batch_size_offset
sequence_end = sequence_start + key_layer.size(0) batch_end = batch_start + key_layer.size(1)
assert sequence_end <= inference_key_memory.size(0) assert batch_end <= inference_key_memory.size(1)
# Copy keys and values into KV-cache sequence_start = inference_params.sequence_len_offset
inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( sequence_end = sequence_start + key_layer.size(0)
key_layer assert sequence_end <= inference_key_memory.size(0)
)
inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
value_layer
)
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
if qkv_format == "bshd": # Copy keys and values into KV-cache
key_layer = key_layer.transpose(0, 1) inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
value_layer = value_layer.transpose(0, 1) key_layer
)
inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
value_layer
)
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
key_layer = key_layer.contiguous() if qkv_format == "bshd":
value_layer = value_layer.contiguous() key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
assert ( key_layer = key_layer.contiguous()
key_layer.shape[-2] == self.num_gqa_groups_per_partition value_layer = value_layer.contiguous()
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
assert qkv_format in [
"sbhd",
"bshd",
"thd",
], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"
if qkv_format == "thd":
assert all(
len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
assert ( assert (
cu_seqlens_q is not None and cu_seqlens_kv is not None key_layer.shape[-2] == self.num_gqa_groups_per_partition
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" and value_layer.shape[-2] == self.num_gqa_groups_per_partition
assert ( ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
cu_seqlens_q.shape == cu_seqlens_kv.shape assert qkv_format in [
and len(cu_seqlens_q.shape) == 1 "sbhd",
and len(cu_seqlens_kv.shape) == 1 "bshd",
), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!" "thd",
assert ( ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"
cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" if qkv_format == "thd":
if max_seqlen_q is None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
if qkv_format in ["sbhd", "bshd"]:
assert all(
len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
if qkv_format == "sbhd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
if qkv_format == "bshd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert all( assert all(
seqlens_q <= max_seqlen_q len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
), """Sequence lengths indicated by cu_seqlens_q must be no greater than ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
the sequence dimention in 'query_layer'!""" assert (
if cu_seqlens_kv is not None: cu_seqlens_q is not None and cu_seqlens_kv is not None
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
and len(cu_seqlens_q.shape) == 1
and len(cu_seqlens_kv.shape) == 1
), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!"
assert (
cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
if max_seqlen_q is None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
if qkv_format in ["sbhd", "bshd"]:
assert all( assert all(
seqlens_kv <= max_seqlen_kv len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
the sequence dimention in 'key_layer' and 'value_layer'!""" if qkv_format == "sbhd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
if qkv_format == "bshd":
max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert all(
seqlens_q <= max_seqlen_q
), """Sequence lengths indicated by cu_seqlens_q must be no greater than
the sequence dimention in 'query_layer'!"""
if cu_seqlens_kv is not None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
assert all(
seqlens_kv <= max_seqlen_kv
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
the sequence dimention in 'key_layer' and 'value_layer'!"""
if ( if (
isinstance(query_layer, Float8Tensor) isinstance(query_layer, Float8Tensor)
and isinstance(key_layer, Float8Tensor) and isinstance(key_layer, Float8Tensor)
and isinstance(value_layer, Float8Tensor) and isinstance(value_layer, Float8Tensor)
): ):
qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout( qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
) )
else: else:
qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout( qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
query_layer, key_layer, value_layer, qkv_format=qkv_format query_layer, key_layer, value_layer, qkv_format=qkv_format
) )
# The priority for attention backends (subject to availability and clearing the filters)
# is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
use_unfused_attention = True
# The following section filters out some backends based on
# certain asserts before executing the forward pass.
# Filter: QKV layout.
if use_unfused_attention and qkv_format == "thd":
self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
# Filter: ONNX export. # The priority for attention backends (subject to availability and clearing the filters)
if is_in_onnx_export_mode(): # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
if use_flash_attention: use_flash_attention = self.use_flash_attention
self.logger.debug("Disabling FlashAttention for ONNX mode") use_fused_attention = self.use_fused_attention
use_flash_attention = False use_unfused_attention = True
if use_fused_attention:
self.logger.debug("Disabling FusedAttention for ONNX mode")
use_fused_attention = False
# Filter: Input type.
if use_flash_attention and (
query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
):
self.logger.debug(
"Disabling FlashAttention due to unsupported QKV data types. "
"Supported: [torch.bfloat16, torch.float16]. "
"Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
query_layer.dtype,
key_layer.dtype,
value_layer.dtype,
)
use_flash_attention = False
if use_fused_attention and (
query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
self.logger.debug(
"Disabling FusedAttention due to unsupported QKV data types. "
"Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
"Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
query_layer.dtype,
key_layer.dtype,
value_layer.dtype,
)
use_fused_attention = False
# Filter: Device and dimensions.
# FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
# FAv2 requires head_dim % 8 == 0
if use_flash_attention and (
query_layer.shape[-1] > 256
or query_layer.shape[-1] % 8 != 0
or (
query_layer.shape[-1] > 192
and self.device_compute_capability not in ((8, 0), (9, 0))
)
):
self.logger.debug(
"Disabling FlashAttention due to unsupported head_dim. "
"Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
"Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
query_layer.shape[-1],
key_layer.shape[-1],
".".join([str(i) for i in self.device_compute_capability]),
)
use_flash_attention = False
# Filter: cross attention + causal mask. # The following section filters out some backends based on
# (in training mode) # certain asserts before executing the forward pass.
if (
use_flash_attention
and inference_params is None
and _flash_attn_2_1_plus
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
self.logger.warning(
"In training mode, disable the use of FlashAttention since version 2.1+ has "
"changed its behavior for causal mask in cross attention. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
context_parallel = ( # Filter: QKV layout.
self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1 if use_unfused_attention and qkv_format == "thd":
) self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
# Filter: sliding window attention. # Filter: ONNX export.
# UnfusedDotProductAttention can support SWA via arbitrary attention mask. if is_in_onnx_export_mode():
if window_size not in ((-1, -1), (-1, 0)):
if use_fused_attention:
self.logger.debug("Disabling FusedAttention for SWA")
use_fused_attention = False
if (not _flash_attn_2_3_plus) or context_parallel:
if use_flash_attention: if use_flash_attention:
self.logger.debug( self.logger.debug("Disabling FlashAttention for ONNX mode")
"Disabling FusedAttention as it requires flash-attn 2.3+ "
"and no context parallelism"
)
use_flash_attention = False use_flash_attention = False
if use_fused_attention:
self.logger.debug("Disabling FusedAttention for ONNX mode")
use_fused_attention = False
# Filter: Attention mask type. # Filter: Input type.
# attn_mask_type(s) | supported backends if use_flash_attention and (
# ------------------------------------------------ query_layer.dtype not in [torch.bfloat16, torch.float16]
# no_mask | All or key_layer.dtype not in [torch.bfloat16, torch.float16]
# padding | UnfusedDotProductAttention, FlashAttention, FusedAttention or value_layer.dtype not in [torch.bfloat16, torch.float16]
# causal | All or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
# padding + causal | FlashAttention, FusedAttention ):
# arbitrary | UnfusedDotProductAttention self.logger.debug(
# "Disabling FlashAttention due to unsupported QKV data types. "
if attn_mask_type == "arbitrary": "Supported: [torch.bfloat16, torch.float16]. "
if use_flash_attention: "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
self.logger.debug("Disabling FlashAttention for arbitrary mask") query_layer.dtype,
use_flash_attention = False key_layer.dtype,
if use_fused_attention: value_layer.dtype,
self.logger.debug("Disabling FusedAttention for arbitrary mask") )
use_fused_attention = False use_flash_attention = False
if use_fused_attention and (
query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
self.logger.debug(
"Disabling FusedAttention due to unsupported QKV data types. "
"Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
"Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
query_layer.dtype,
key_layer.dtype,
value_layer.dtype,
)
use_fused_attention = False
if ( # Filter: Execution type.
use_unfused_attention if use_flash_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
and inference_params is None self.logger.debug("Disabling FlashAttention as it does not support FP8 execution.")
and "causal" in attn_mask_type use_flash_attention = False
and max_seqlen_q != max_seqlen_kv if use_unfused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
): self.logger.debug(
self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd") "Disabling UnfusedDotProductAttention as it does not support FP8 execution."
use_unfused_attention = False )
use_unfused_attention = False
# Filter: Device and dimensions.
# FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
# FAv2 requires head_dim % 8 == 0
if use_flash_attention and (
query_layer.shape[-1] > 256
or query_layer.shape[-1] % 8 != 0
or (
query_layer.shape[-1] > 192
and self.device_compute_capability not in ((8, 0), (9, 0))
)
):
self.logger.debug(
"Disabling FlashAttention due to unsupported head_dim. "
"Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
"Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
query_layer.shape[-1],
key_layer.shape[-1],
".".join([str(i) for i in self.device_compute_capability]),
)
use_flash_attention = False
# Filter: bias. # Filter: cross attention + causal mask.
global _alibi_cache # (in training mode)
if alibi_slopes is not None:
assert (
core_attention_bias_type == "alibi"
), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
if self.layer_number == 1:
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
if core_attention_bias_type == "alibi":
assert (
core_attention_bias is None
), "core_attention_bias must be None when core_attention_bias_type is alibi!"
if ( if (
_alibi_cache["_num_heads"] != query_layer.shape[-2] use_flash_attention
or _alibi_cache["_max_seqlen_q"] != max_seqlen_q and inference_params is None
or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv and _flash_attn_2_1_plus
or _alibi_cache["_alibi_slopes"] is None and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
): ):
_alibi_cache["_alibi_slopes_require_update"] = True self.logger.warning(
_alibi_cache["_alibi_bias_require_update"] = True "In training mode, disable the use of FlashAttention since version 2.1+ has "
"changed its behavior for causal mask in cross attention. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if use_flash_attention and ( context_parallel = (
core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
):
self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = get_alibi(
query_layer.shape[-2],
max_seqlen_q,
max_seqlen_kv,
alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype,
)
if (
use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias"
and (
fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
) )
):
if fu_core_attention_bias.requires_grad: # Filter: sliding window attention.
# remove this line when cuDNN adds bwd support for # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] if window_size not in ((-1, -1), (-1, 0)):
self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") if use_fused_attention:
self.logger.debug("Disabling FusedAttention for SWA")
use_fused_attention = False
if (not _flash_attn_2_3_plus) or context_parallel:
if use_flash_attention:
self.logger.debug(
"Disabling FusedAttention as it requires flash-attn 2.3+ "
"and no context parallelism"
)
use_flash_attention = False
# Filter: Attention mask type.
# attn_mask_type(s) | supported backends
# ------------------------------------------------
# no_mask | All
# padding | UnfusedDotProductAttention, FlashAttention, FusedAttention
# causal | All
# padding + causal | FlashAttention, FusedAttention
# arbitrary | UnfusedDotProductAttention
#
if attn_mask_type == "arbitrary":
if use_flash_attention:
self.logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
self.logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
if use_fused_attention: if (
fused_attention_backend = tex.get_fused_attn_backend( use_unfused_attention
( and inference_params is None
TE_DType[query_layer.dtype] and "causal" in attn_mask_type
if not isinstance(query_layer, Float8Tensor) and max_seqlen_q != max_seqlen_kv
else query_layer._fp8_dtype ):
), self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
( use_unfused_attention = False
TE_DType[key_layer.dtype]
if not isinstance(key_layer, Float8Tensor) # Filter: bias.
else key_layer._fp8_dtype global _alibi_cache
), if alibi_slopes is not None:
QKVLayout[qkv_layout], assert (
AttnBiasType[fu_core_attention_bias_type], core_attention_bias_type == "alibi"
AttnMaskType[attn_mask_type], ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
self.attention_dropout, if self.layer_number == 1:
query_layer.shape[-2], # num_attn_heads _alibi_cache["_alibi_slopes_require_update"] = True
key_layer.shape[-2], # num_gqa_groups _alibi_cache["_alibi_bias_require_update"] = True
max_seqlen_q, if core_attention_bias_type == "alibi":
max_seqlen_kv, assert (
query_layer.shape[-1], # head_dim core_attention_bias is None
) ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
# DPA does not support FP8; for FP8, use cpp_extensions modules directly if (
is_backend_avail = fused_attention_backend in [ _alibi_cache["_num_heads"] != query_layer.shape[-2]
FusedAttnBackend["F16_max512_seqlen"], or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
FusedAttnBackend["F16_arbitrary_seqlen"], or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
FusedAttnBackend["FP8"], or _alibi_cache["_alibi_slopes"] is None
] ):
use_fused_attention = ( _alibi_cache["_alibi_slopes_require_update"] = True
use_fused_attention _alibi_cache["_alibi_bias_require_update"] = True
and is_backend_avail
and ( if use_flash_attention and (
not context_parallel core_attention_bias_type not in ["no_bias", "alibi"]
or fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] or core_attention_bias is not None
):
self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if (
core_attention_bias_type == "alibi"
and use_fused_attention
and alibi_slopes is not None
):
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = get_alibi(
query_layer.shape[-2],
max_seqlen_q,
max_seqlen_kv,
alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype,
) )
)
if ( if (
fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias" and fu_core_attention_bias_type == "post_scale_bias"
and ( and (
fu_core_attention_bias.shape[0] != 1 fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2] or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
) )
): ):
self.logger.debug( if fu_core_attention_bias.requires_grad:
"Disabling FusedAttention as no backend supports the provided input" # remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
if use_fused_attention:
q_type = TE_DType[query_layer.dtype]
kv_type = TE_DType[key_layer.dtype]
if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
if isinstance(query_layer, Float8Tensor) and isinstance(
key_layer, Float8Tensor
):
q_type = query_layer._fp8_dtype
kv_type = value_layer._fp8_dtype
else:
q_type = forward_dtype
kv_type = forward_dtype
fused_attention_backend = tex.get_fused_attn_backend(
q_type,
kv_type,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
self.attention_dropout,
query_layer.shape[-2], # num_attn_heads
key_layer.shape[-2], # num_gqa_groups
max_seqlen_q,
max_seqlen_kv,
query_layer.shape[-1], # head_dim
) )
# DPA does not support FP8; for FP8, use cpp_extensions modules directly
is_backend_avail = fused_attention_backend in [
FusedAttnBackend["F16_max512_seqlen"],
FusedAttnBackend["F16_arbitrary_seqlen"],
FusedAttnBackend["FP8"],
]
use_fused_attention = (
use_fused_attention
and is_backend_avail
and (
not context_parallel
or fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
)
)
if (
fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and (
fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
)
):
self.logger.debug(
"Disabling FusedAttention as no backend supports the provided input"
)
use_fused_attention = False
# Filter: determinism.
# backend | deterministic
# ---------------------------------------------------------
# flash-attn v1 | yes
# flash-attn v2 | no
# FusedAttnBackend["F16_max512_seqlen"] | yes
# FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no
# UnfusedDotProductAttention | yes
#
# Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
# on sm90 architectures.
#
if (
use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and self.deterministic
and self.device_compute_capability != (9, 0)
):
self.logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False use_fused_attention = False
# Filter: determinism. # Select FusedAttention on sm90 and FlashAttention on others for performance
# backend | deterministic if (
# --------------------------------------------------------- use_flash_attention
# flash-attn v1 | yes and use_fused_attention
# flash-attn v2 | no and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
# FusedAttnBackend["F16_max512_seqlen"] | yes ):
# FusedAttnBackend["F16_arbitrary_seqlen"] | workspace optimization path: yes; otherwise: no if self.device_compute_capability == (9, 0):
# UnfusedDotProductAttention | yes self.logger.debug(
# "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
# Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path "for performance reasons"
# on sm90 architectures. )
# use_flash_attention = False
if (
use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and self.deterministic
and self.device_compute_capability != (9, 0)
):
self.logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
# Select FusedAttention on sm90 and FlashAttention on others for performance run_config = {
if ( "compute_capability": "sm"
use_flash_attention + str(
and use_fused_attention (lambda x, y: x * 10 + y)(
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] self.device_compute_capability[0], self.device_compute_capability[1]
): )
if self.device_compute_capability == (9, 0): ),
self.logger.debug( "q_dtype": query_layer.dtype,
"Disabling FlashAttention to give FusedAttention preference on Hopper+ " "k_dtype": key_layer.dtype,
"for performance reasons" "v_dtype": value_layer.dtype,
) "q_shape": list(query_layer.shape),
use_flash_attention = False "k_shape": list(key_layer.shape),
"v_shape": list(value_layer.shape),
"qkv_format": qkv_format,
"qkv_layout": qkv_layout,
"mask_type": attn_mask_type,
"bias_type": core_attention_bias_type,
"bias_shape": (
core_attention_bias.shape if core_attention_bias is not None else None
),
"dropout": self.attention_dropout,
"context_parallel": context_parallel,
"is_training": self.training,
"transformer_engine_version": te.__version__,
"flash_attn_version": _flash_attn_version,
"cudnn_version": ".".join([str(i) for i in get_cudnn_version()]),
}
run_config = { if use_flash_attention:
"compute_capability": "sm" self.logger.info("Running with FlashAttention backend ")
+ str( self.logger.debug("Running with config=%s", run_config)
(lambda x, y: x * 10 + y)( if core_attention_bias_type == "alibi":
self.device_compute_capability[0], self.device_compute_capability[1] alibi_slopes, _ = get_alibi(
query_layer.shape[-2],
max_seqlen_q,
max_seqlen_kv,
alibi_slopes=alibi_slopes,
)
return self.flash_attention(
query_layer,
key_layer,
value_layer,
attention_mask=attention_mask,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
window_size=window_size,
alibi_slopes=alibi_slopes,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
) )
),
"q_dtype": query_layer.dtype,
"k_dtype": key_layer.dtype,
"v_dtype": value_layer.dtype,
"q_shape": list(query_layer.shape),
"k_shape": list(key_layer.shape),
"v_shape": list(value_layer.shape),
"qkv_format": qkv_format,
"qkv_layout": qkv_layout,
"mask_type": attn_mask_type,
"bias_type": core_attention_bias_type,
"bias_shape": core_attention_bias.shape if core_attention_bias is not None else None,
"dropout": self.attention_dropout,
"context_parallel": context_parallel,
"is_training": self.training,
"transformer_engine_version": te.__version__,
"flash_attn_version": _flash_attn_version,
"cudnn_version": ".".join([str(i) for i in get_cudnn_version()]),
}
if use_flash_attention: if use_fused_attention:
self.logger.info("Running with FlashAttention backend ") self.logger.info(
self.logger.debug("Running with config=%s", run_config) "Running with FusedAttention backend (sub-backend %s)",
if core_attention_bias_type == "alibi": int(fused_attention_backend),
alibi_slopes, _ = get_alibi(
query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes
) )
return self.flash_attention( if self.fp8:
query_layer, self.logger.debug(
key_layer, "Running with fp8_recipe.fp8_mha=%s, "
value_layer, "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
attention_mask=attention_mask, self.fp8_meta["recipe"].fp8_mha,
qkv_layout=qkv_layout, self.fp8_meta["recipe"].fp8_dpa,
cu_seqlens_q=cu_seqlens_q, forced_fp8_dpa,
cu_seqlens_kv=cu_seqlens_kv, int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
attn_mask_type=attn_mask_type, )
window_size=window_size, self.logger.debug("Running with config=%s", run_config)
alibi_slopes=alibi_slopes, if checkpoint_core_attention:
cp_group=self.cp_group, return self._checkpointed_attention_forward(
cp_global_ranks=self.cp_global_ranks, self.fused_attention,
cp_stream=self.cp_stream, query_layer,
max_seqlen_q=max_seqlen_q, key_layer,
max_seqlen_kv=max_seqlen_kv, value_layer,
) qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
if use_fused_attention: cu_seqlens_kv=cu_seqlens_kv,
self.logger.info( seq_offsets_q=seq_offsets_q,
"Running with FusedAttention backend (sub-backend %s)", int(fused_attention_backend) seq_offsets_k=seq_offsets_k,
) seq_offsets_v=seq_offsets_v,
self.logger.debug("Running with config=%s", run_config) seq_offsets_o=seq_offsets_o,
if checkpoint_core_attention: max_seqlen_q=max_seqlen_q,
return self._checkpointed_attention_forward( max_seqlen_kv=max_seqlen_kv,
self.fused_attention, attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
fast_zero_fill=fast_zero_fill,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
)
return self.fused_attention(
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
...@@ -5300,51 +5389,41 @@ class DotProductAttention(torch.nn.Module): ...@@ -5300,51 +5389,41 @@ class DotProductAttention(torch.nn.Module):
cp_group=self.cp_group, cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks, cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream, cp_stream=self.cp_stream,
is_first_microbatch=is_first_microbatch, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
) )
return self.fused_attention(
query_layer,
key_layer,
value_layer,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
fast_zero_fill=fast_zero_fill,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
is_first_microbatch=is_first_microbatch,
)
assert ( assert (
not context_parallel not context_parallel
), "Context parallelism is only implemented with Flash Attention and Fused Attention!" ), "Context parallelism is only implemented with Flash Attention and Fused Attention!"
from .cpu_offload import CPUOffloadEnabled from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled: if CPUOffloadEnabled:
warnings.warn( warnings.warn(
"Attention activation Offloading is only implemented" "Attention activation Offloading is only implemented"
"with Flash Attention and Fused Attention!" "with Flash Attention and Fused Attention!"
) )
if use_unfused_attention: if use_unfused_attention:
self.logger.info("Running with UnfusedDotProductAttention backend") self.logger.info("Running with UnfusedDotProductAttention backend")
self.logger.debug("Running with config=%s", run_config) self.logger.debug("Running with config=%s", run_config)
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
self.unfused_attention, self.unfused_attention,
query_layer,
key_layer,
value_layer,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
)
return self.unfused_attention(
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
...@@ -5357,21 +5436,8 @@ class DotProductAttention(torch.nn.Module): ...@@ -5357,21 +5436,8 @@ class DotProductAttention(torch.nn.Module):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
) )
return self.unfused_attention(
query_layer,
key_layer,
value_layer,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
)
raise Exception("No dot product attention support for the provided inputs!") raise Exception("No dot product attention support for the provided inputs!")
class MultiheadAttention(torch.nn.Module): class MultiheadAttention(torch.nn.Module):
......
...@@ -430,7 +430,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -430,7 +430,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Store other pickelable values. # Store other pickelable values.
extra = {} extra = {}
for k, v in self.fp8_meta.items(): for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str, tuple, list)): if k != "buffer_index_and_autocast_key" and isinstance(
v, (bool, int, float, str, tuple, list)
):
extra[k] = v extra[k] = v
state["extra_fp8_variables"] = extra state["extra_fp8_variables"] = extra
...@@ -491,12 +493,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -491,12 +493,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"Data types for parameters must match when outside of autocasted region. " "Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
) )
for name, buf in self.named_buffers():
if buf is not None:
assert dtype == buf.dtype, (
"Data types for buffers must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
)
self.activation_dtype = dtype self.activation_dtype = dtype
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment