Unverified commit fe5aa604 authored by Charlene Yang, committed by GitHub

[PyTorch] Adjust checkpointing of FP8 metadata for attention (#917)



* subclass DPA with BaseModule and test with test_gpt_checkpointing
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test DPA only
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test save and load
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add hook in case core_attention._extra_state is missing
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* check named buffers in BaseModule; remove FP8 scratchpad override function; test FP8 for sm90+
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes: test size, interval in recipe, named_buffer loop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move BaseModule from FusedAttention to DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d71fc946
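With this change, FP8 attention metadata moves from `...core_attention.fused_attention._extra_state` to `...core_attention._extra_state` in checkpoints; the load hooks added below only suppress the resulting missing/unexpected keys and emit a warning. A minimal sketch of remapping an older checkpoint into the new layout — `remap_fused_attn_extra_state` is a hypothetical helper written for this note, not part of the PR:

import torch

def remap_fused_attn_extra_state(state_dict):
    # Hypothetical: move FP8 metadata saved under the pre-change key
    # "...core_attention.fused_attention._extra_state" to the
    # "...core_attention._extra_state" location used by this PR.
    old_suffix = "core_attention.fused_attention._extra_state"
    new_suffix = "core_attention._extra_state"
    for key in [k for k in state_dict if k.endswith(old_suffix)]:
        new_key = key[: -len(old_suffix)] + new_suffix
        # keep an entry already present under the new key, if any
        state_dict.setdefault(new_key, state_dict.pop(key))
    return state_dict

# usage: block.load_state_dict(remap_fused_attn_extra_state(torch.load(path)))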
@@ -8,6 +8,7 @@ from contextlib import nullcontext
import torch
import pytest
+import io

from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
@@ -15,6 +16,7 @@ from transformer_engine.pytorch.fp8 import (
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
+    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
@@ -86,6 +88,7 @@ model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
+    "large": ModelConfig(1, 128, 2, 512, 4, 128),
}

fp8_recipes = [
@@ -997,3 +1000,86 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.parametrize("model", ["large"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_sanity_attention_extra_state(model, dtype):
+    config = model_configs[model]
+    fp8_recipe = recipe.DelayedScaling(
+        margin=0,
+        fp8_format=recipe.Format.HYBRID,
+        amax_history_len=1,
+        amax_compute_algo="most_recent",
+        fp8_dpa=True,
+        fp8_mha=False,
+    )
+    hidden_states = torch.randn(
+        (config.seq_len, config.batch_size, config.hidden_size),
+        dtype=dtype,
+        device="cuda",
+        requires_grad=True,
+    )
+    with fp8_model_init(enabled=True):
+        block = TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            fuse_qkv_params=True,
+            params_dtype=dtype,
+            device="cuda",
+        )
+    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+        output = block(hidden_states, is_first_microbatch=True)
+    loss = output.sum()
+    loss.backward()
+
+    # call state_dict()
+    sd = block.state_dict()
+    # check core_attention._extra_state
+    attn_extra_state = sd["self_attention.core_attention._extra_state"]
+    attn_extra_state.seek(0)
+    attn_extra_state = torch.load(attn_extra_state, map_location="cuda")
+
+    # add random core_attention.fused_attention._extra_state;
+    # it should not be loaded or cause any 'unexpected key' errors
+    random_state = {"a": 1, "b": 2}
+    fused_attn_extra_state = io.BytesIO()
+    torch.save(random_state, fused_attn_extra_state)
+    sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state
+
+    # save checkpoint
+    path = "./checkpoint.pt"
+    torch.save(sd, path)
+
+    # reinit the model
+    del block
+    with fp8_model_init(enabled=True):
+        block_new = TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            fuse_qkv_params=True,
+            params_dtype=dtype,
+            device="cuda",
+        )
+    FP8GlobalStateManager.reset()
+
+    # load from checkpoint
+    block_new.load_state_dict(torch.load(path))
+
+    # check state_dict
+    sd_new = block_new.state_dict()
+    attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
+    attn_extra_state_new.seek(0)
+    attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
+    for k, v in attn_extra_state_new.items():
+        if k != "extra_fp8_variables":
+            assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
+        else:
+            for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
+                assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"
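
For reference when reading the assertions above: each `_extra_state` entry is an `io.BytesIO` holding a pickled dict of FP8 bookkeeping tensors plus a nested `extra_fp8_variables` dict of picklable scalars (see the base.py hunk at the end of this diff). A small inspection helper, as a sketch; the exact tensor key names (scale/amax entries) are assumptions, not guaranteed names:

import torch

def dump_extra_state(state_dict, key="self_attention.core_attention._extra_state"):
    # Print whatever FP8 metadata a serialized _extra_state entry carries.
    buf = state_dict[key]
    buf.seek(0)
    state = torch.load(buf, map_location="cpu")
    for k, v in state.items():
        if k == "extra_fp8_variables":
            print(k, v)  # picklable scalars collected by get_extra_state()
        else:
            print(k, tuple(v.shape), v.dtype)  # e.g. scale / amax_history tensors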
@@ -4092,7 +4092,7 @@ class FusedAttnFunc(torch.autograd.Function):
    )


-class FusedAttention(TransformerEngineBaseModule):
+class FusedAttention(torch.nn.Module):
    """Dot product attention, with multiple backends:

    1. FusedAttnBackend["F16_max512_seqlen"]
@@ -4159,21 +4159,24 @@ class FusedAttention(TransformerEngineBaseModule):
        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            """
            Temporarily remove fused_attention._extra_state as a missing key
-            when loading older TransformerEngine checkpoints. Will phase out
-            this hook in TransformerEngine 2.0.
+            or an unexpected key when loading TransformerEngine checkpoints.
+            Please store FP8 metadata as DotProductAttention's _extra_state,
+            rather than FusedAttention's _extra_state. This hook will be
+            phased out in TransformerEngine 2.0.
            """
            for key in incompatible_keys.missing_keys:
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)
+            for key in incompatible_keys.unexpected_keys:
+                if "fused_attention._extra_state" in key:
+                    incompatible_keys.unexpected_keys.remove(key)
+                    warnings.warn(
+                        "fused_attention._extra_state is not loaded from checkpoint. Please map "
+                        "FusedAttention's _extra_state to DotProductAttention's _extra_state."
+                    )

        self.register_load_state_dict_post_hook(remove_extra_states_check)

-    def get_fp8_weights_scratchpad(
-        self,
-        is_first_microbatch: Union[bool, None],
-    ) -> List[Float8Tensor]:
-        """Needs override."""
-
    @no_torch_dynamo()
    def forward(
        self,
@@ -4198,7 +4201,8 @@ class FusedAttention(TransformerEngineBaseModule):
        cp_group: Optional[dist_group_type] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
-        is_first_microbatch: Optional[bool] = None,
+        fp8: bool = False,
+        fp8_meta: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """fused attention fprop"""
        assert (
@@ -4337,57 +4341,48 @@ class FusedAttention(TransformerEngineBaseModule):
                use_fused_attention=True,
            )
        else:
-            with self.prepare_forward(
-                query_layer, is_first_microbatch, num_gemms=3, allow_non_contiguous=True
-            ) as query_layer:
-                with self.attention_dropout_ctx():
-                    forced_fp8_dpa = ""
-                    if self.fp8_meta["recipe"].fp8_mha:
-                        if not self.fp8_meta["recipe"].fp8_dpa:
-                            self.fp8_meta["recipe"].fp8_dpa = True
-                            forced_fp8_dpa = " (forced)"
-                    if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8:
-                        self.logger.debug(
-                            "Running with fp8_recipe.fp8_mha=%s, "
-                            "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
-                            self.fp8_meta["recipe"].fp8_mha,
-                            self.fp8_meta["recipe"].fp8_dpa,
-                            forced_fp8_dpa,
-                            int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
-                        )
-                    output = FusedAttnFunc.apply(
-                        self.training,
-                        max_seqlen_q,
-                        max_seqlen_kv,
-                        cu_seqlens_q,
-                        cu_seqlens_kv,
-                        seq_offsets_q,
-                        seq_offsets_k,
-                        seq_offsets_v,
-                        seq_offsets_o,
-                        query_layer,
-                        key_layer,
-                        value_layer,
-                        qkv_dtype,
-                        core_attention_bias,
-                        self.softmax_scale,
-                        self.attention_dropout if self.training else 0.0,
-                        fast_zero_fill,
-                        qkv_layout,
-                        core_attention_bias_type,
-                        attn_mask_type,
-                        None,  # rng_gen
-                        fused_attention_backend,
-                        use_FAv2_bwd,
-                        self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
-                        self.fp8_meta,
-                    )
+            with self.attention_dropout_ctx():
+                if fp8:
+                    assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
+                        f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
+                        " is required for FP8 attention!"
+                    )
+                    assert (
+                        fp8_meta is not None
+                    ), "FP8 metadata fp8_meta is required for FP8 attention!"
+                output = FusedAttnFunc.apply(
+                    self.training,
+                    max_seqlen_q,
+                    max_seqlen_kv,
+                    cu_seqlens_q,
+                    cu_seqlens_kv,
+                    seq_offsets_q,
+                    seq_offsets_k,
+                    seq_offsets_v,
+                    seq_offsets_o,
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    qkv_dtype,
+                    core_attention_bias,
+                    self.softmax_scale,
+                    self.attention_dropout if self.training else 0.0,
+                    fast_zero_fill,
+                    qkv_layout,
+                    core_attention_bias_type,
+                    attn_mask_type,
+                    None,  # rng_gen
+                    fused_attention_backend,
+                    use_FAv2_bwd,
+                    fp8,
+                    fp8_meta,
+                )

        # ...hd -> ...(hd)
        return output.view(*output.shape[:-2], -1)
-class DotProductAttention(torch.nn.Module):
+class DotProductAttention(TransformerEngineBaseModule):
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
@@ -4602,6 +4597,18 @@ class DotProductAttention(torch.nn.Module):
                softmax_scale, **attn_kwargs, layer_number=layer_number
            )

+        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
+            """
+            Temporarily remove core_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if "core_attention._extra_state" in key:
+                    incompatible_keys.missing_keys.remove(key)
+
+        self.register_load_state_dict_post_hook(remove_extra_states_check)
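
Both hooks lean on PyTorch's `register_load_state_dict_post_hook` protocol: the hook receives the module and an `incompatible_keys` object whose `missing_keys`/`unexpected_keys` lists can be edited in place, which also suppresses the error that `load_state_dict(strict=True)` would otherwise raise. A self-contained sketch of the mechanism (toy module, not TransformerEngine code):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.register_buffer("aux", torch.zeros(1))

        def ignore_aux(module, incompatible_keys):
            # editing the lists in place changes what strict=True reports
            for key in list(incompatible_keys.missing_keys):
                if key.endswith("aux"):
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(ignore_aux)

old_ckpt = {"linear.weight": torch.eye(2), "linear.bias": torch.zeros(2)}
Toy().load_state_dict(old_ckpt, strict=True)  # would raise without the hook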
    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
@@ -4805,480 +4812,562 @@ class DotProductAttention(torch.nn.Module):
                first microbatch (since it is the first gradient being
                produced)
        """

+        with self.prepare_forward(
+            query_layer,
+            is_first_microbatch,
+            num_gemms=3,
+            allow_non_contiguous=True,
+        ) as query_layer:
+
+            if self.fp8:
+                forced_fp8_dpa = ""
+                if self.fp8_meta["recipe"].fp8_mha:
+                    if not self.fp8_meta["recipe"].fp8_dpa:
+                        self.fp8_meta["recipe"].fp8_dpa = True
+                        forced_fp8_dpa = " (forced)"
+
+            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
+                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
+                assert forward_dtype in [
+                    tex.DType.kFloat8E4M3,
+                    tex.DType.kFloat8E5M2,
+                ] and backward_dtype in [
+                    tex.DType.kFloat8E4M3,
+                    tex.DType.kFloat8E5M2,
+                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""

            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."

            assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"

            if attn_mask_type is not None:
                window_size = check_set_window_size(attn_mask_type, window_size)
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
                attn_mask_type = attn_mask_type.replace(",", "_")
                if attn_mask_type == "causal_padding":
                    attn_mask_type = "padding_causal"
            assert (
                attn_mask_type in AttnMaskTypes
            ), f"Attention mask type {attn_mask_type} is not supported!"
            if qkv_format == "thd":
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"

            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."

            if window_size is None:
                window_size = self.window_size

            if qkv_format is None:
                qkv_format = self.qkv_format

            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"

                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)

                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

                batch_start = inference_params.batch_size_offset
                batch_end = batch_start + key_layer.size(1)
                assert batch_end <= inference_key_memory.size(1)

                sequence_start = inference_params.sequence_len_offset
                sequence_end = sequence_start + key_layer.size(0)
                assert sequence_end <= inference_key_memory.size(0)

                # Copy keys and values into KV-cache
                inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    key_layer
                )
                inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
                    value_layer
                )
                key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
                value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

                if qkv_format == "bshd":
                    key_layer = key_layer.transpose(0, 1)
                    value_layer = value_layer.transpose(0, 1)

                key_layer = key_layer.contiguous()
                value_layer = value_layer.contiguous()

            assert (
                key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
            ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            if qkv_format == "thd":
                assert all(
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
                assert (
                    cu_seqlens_q.shape == cu_seqlens_kv.shape
                    and len(cu_seqlens_q.shape) == 1
                    and len(cu_seqlens_kv.shape) == 1
                ), "cu_seqlens_q and cu_seqlens_kv must both have shape [batch_size + 1]!"
                assert (
                    cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32
                ), "cu_seqlens_q and cu_seqlens_kv must both be in dtype torch.int32!"
                if max_seqlen_q is None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
                if max_seqlen_kv is None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))

            if qkv_format in ["sbhd", "bshd"]:
                assert all(
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
                if qkv_format == "sbhd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0])
                if qkv_format == "bshd":
                    max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1])
                if cu_seqlens_q is not None:
                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                    assert all(
                        seqlens_q <= max_seqlen_q
                    ), """Sequence lengths indicated by cu_seqlens_q must be no greater than
                          the sequence dimension in 'query_layer'!"""
                if cu_seqlens_kv is not None:
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    assert all(
                        seqlens_kv <= max_seqlen_kv
                    ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
                          the sequence dimension in 'key_layer' and 'value_layer'!"""

            if (
                isinstance(query_layer, Float8Tensor)
                and isinstance(key_layer, Float8Tensor)
                and isinstance(value_layer, Float8Tensor)
            ):
                qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
                    query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
                )
            else:
                qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
                    query_layer, key_layer, value_layer, qkv_format=qkv_format
                )

            # The priority for attention backends (subject to availability and clearing the filters)
            # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
            use_flash_attention = self.use_flash_attention
            use_fused_attention = self.use_fused_attention
            use_unfused_attention = True

            # The following section filters out some backends based on
            # certain asserts before executing the forward pass.

            # Filter: QKV layout.
            if use_unfused_attention and qkv_format == "thd":
                self.logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
                use_unfused_attention = False

            # Filter: ONNX export.
            if is_in_onnx_export_mode():
                if use_flash_attention:
                    self.logger.debug("Disabling FlashAttention for ONNX mode")
                    use_flash_attention = False
                if use_fused_attention:
                    self.logger.debug("Disabling FusedAttention for ONNX mode")
                    use_fused_attention = False

            # Filter: Input type.
            if use_flash_attention and (
                query_layer.dtype not in [torch.bfloat16, torch.float16]
                or key_layer.dtype not in [torch.bfloat16, torch.float16]
                or value_layer.dtype not in [torch.bfloat16, torch.float16]
                or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
            ):
                self.logger.debug(
                    "Disabling FlashAttention due to unsupported QKV data types. "
                    "Supported: [torch.bfloat16, torch.float16]. "
                    "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                    query_layer.dtype,
                    key_layer.dtype,
                    value_layer.dtype,
                )
                use_flash_attention = False
            if use_fused_attention and (
                query_layer.dtype not in [torch.bfloat16, torch.float16]
                or key_layer.dtype not in [torch.bfloat16, torch.float16]
                or value_layer.dtype not in [torch.bfloat16, torch.float16]
            ):
                self.logger.debug(
                    "Disabling FusedAttention due to unsupported QKV data types. "
                    "Supported: [torch.bfloat16, torch.float16, Float8Tensor]. "
                    "Found: query_layer.dtype=%s, key_layer.dtype=%s, value_layer.dtype=%s.",
                    query_layer.dtype,
                    key_layer.dtype,
                    value_layer.dtype,
                )
                use_fused_attention = False

            # Filter: Device and dimensions.
            # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
            # FAv2 requires head_dim % 8 == 0
            if use_flash_attention and (
                query_layer.shape[-1] > 256
                or query_layer.shape[-1] % 8 != 0
                or (
                    query_layer.shape[-1] > 192
                    and self.device_compute_capability not in ((8, 0), (9, 0))
                )
            ):
                self.logger.debug(
                    "Disabling FlashAttention due to unsupported head_dim. "
                    "Supported: %%8 == 0, and <= 256; sm80/90 for >192. "
                    "Found: query_layer.shape[-1]=%s, key_layer.shape[-1]=%s, sm=%s",
                    query_layer.shape[-1],
                    key_layer.shape[-1],
                    ".".join([str(i) for i in self.device_compute_capability]),
                )
                use_flash_attention = False

            # Filter: cross attention + causal mask.
            # (in training mode)
            if (
                use_flash_attention
                and inference_params is None
                and _flash_attn_2_1_plus
                and "causal" in attn_mask_type
                and max_seqlen_q != max_seqlen_kv
            ):
                self.logger.warning(
                    "In training mode, disable the use of FlashAttention since version 2.1+ has "
                    "changed its behavior for causal mask in cross attention. See "
                    "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
                )
                use_flash_attention = False

            context_parallel = (
                self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
            )

            # Filter: sliding window attention.
            # UnfusedDotProductAttention can support SWA via arbitrary attention mask.
            if window_size not in ((-1, -1), (-1, 0)):
                if use_fused_attention:
                    self.logger.debug("Disabling FusedAttention for SWA")
                    use_fused_attention = False
                if (not _flash_attn_2_3_plus) or context_parallel:
                    if use_flash_attention:
                        self.logger.debug(
                            "Disabling FlashAttention as it requires flash-attn 2.3+ "
                            "and no context parallelism"
                        )
                        use_flash_attention = False

            # Filter: Attention mask type.
            #   attn_mask_type(s)  |  supported backends
            # ------------------------------------------------
            #   no_mask            |  All
            #   padding            |  UnfusedDotProductAttention, FlashAttention, FusedAttention
            #   causal             |  All
            #   padding + causal   |  FlashAttention, FusedAttention
            #   arbitrary          |  UnfusedDotProductAttention
            #
            if attn_mask_type == "arbitrary":
                if use_flash_attention:
                    self.logger.debug("Disabling FlashAttention for arbitrary mask")
                    use_flash_attention = False
                if use_fused_attention:
                    self.logger.debug("Disabling FusedAttention for arbitrary mask")
                    use_fused_attention = False

            if (
                use_unfused_attention
                and inference_params is None
                and "causal" in attn_mask_type
                and max_seqlen_q != max_seqlen_kv
            ):
                self.logger.debug(
                    "Disabling UnfusedDotProductAttention for causal mask with "
                    "max_seqlen_q != max_seqlen_kv"
                )
                use_unfused_attention = False

+            # Filter: Execution type.
+            if use_flash_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+                self.logger.debug("Disabling FlashAttention as it does not support FP8 execution.")
+                use_flash_attention = False
+            if use_unfused_attention and self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+                self.logger.debug(
+                    "Disabling UnfusedDotProductAttention as it does not support FP8 execution."
+                )
+                use_unfused_attention = False

            # Filter: bias.
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
                    core_attention_bias_type == "alibi"
                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
                if self.layer_number == 1:
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True
            if core_attention_bias_type == "alibi":
                assert (
                    core_attention_bias is None
                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
                if (
                    _alibi_cache["_num_heads"] != query_layer.shape[-2]
                    or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                    or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
                    or _alibi_cache["_alibi_slopes"] is None
                ):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

            if use_flash_attention and (
                core_attention_bias_type not in ["no_bias", "alibi"]
                or core_attention_bias is not None
            ):
                self.logger.debug("Disabling FlashAttention for pre/post_scale_bias")
                use_flash_attention = False

            fu_core_attention_bias_type = core_attention_bias_type
            fu_core_attention_bias = core_attention_bias
            if (
                core_attention_bias_type == "alibi"
                and use_fused_attention
                and alibi_slopes is not None
            ):
                fu_core_attention_bias_type = "post_scale_bias"
                _, fu_core_attention_bias = get_alibi(
                    query_layer.shape[-2],
                    max_seqlen_q,
                    max_seqlen_kv,
                    alibi_slopes=alibi_slopes,
                    bias_dtype=query_layer.dtype,
                )
            if (
                use_fused_attention
                and fu_core_attention_bias_type == "post_scale_bias"
                and (
                    fu_core_attention_bias.shape[0] != 1
                    or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
                )
            ):
                if fu_core_attention_bias.requires_grad:
                    # remove this line when cuDNN adds bwd support for
                    # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
                    self.logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
                    use_fused_attention = False
                else:
                    # max512 backend will only support [1, h, s, s]
                    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

            if use_fused_attention:
+                q_type = TE_DType[query_layer.dtype]
+                kv_type = TE_DType[key_layer.dtype]
+                if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
+                    if isinstance(query_layer, Float8Tensor) and isinstance(
+                        key_layer, Float8Tensor
+                    ):
+                        q_type = query_layer._fp8_dtype
+                        kv_type = value_layer._fp8_dtype
+                    else:
+                        q_type = forward_dtype
+                        kv_type = forward_dtype
                fused_attention_backend = tex.get_fused_attn_backend(
-                    (
-                        TE_DType[query_layer.dtype]
-                        if not isinstance(query_layer, Float8Tensor)
-                        else query_layer._fp8_dtype
-                    ),
-                    (
-                        TE_DType[key_layer.dtype]
-                        if not isinstance(key_layer, Float8Tensor)
-                        else key_layer._fp8_dtype
-                    ),
+                    q_type,
+                    kv_type,
                    QKVLayout[qkv_layout],
                    AttnBiasType[fu_core_attention_bias_type],
                    AttnMaskType[attn_mask_type],
                    self.attention_dropout,
                    query_layer.shape[-2],  # num_attn_heads
                    key_layer.shape[-2],  # num_gqa_groups
                    max_seqlen_q,
                    max_seqlen_kv,
                    query_layer.shape[-1],  # head_dim
                )
                # DPA does not support FP8; for FP8, use cpp_extensions modules directly
                is_backend_avail = fused_attention_backend in [
                    FusedAttnBackend["F16_max512_seqlen"],
                    FusedAttnBackend["F16_arbitrary_seqlen"],
                    FusedAttnBackend["FP8"],
                ]
                use_fused_attention = (
                    use_fused_attention
                    and is_backend_avail
                    and (
                        not context_parallel
                        or fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
                    )
                )
                if (
                    fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
                    and fu_core_attention_bias_type == "post_scale_bias"
                    and (
                        fu_core_attention_bias.shape[0] != 1
                        or fu_core_attention_bias.shape[1] != query_layer.shape[-2]
                    )
                ):
                    self.logger.debug(
                        "Disabling FusedAttention as no backend supports the provided input"
                    )
                    use_fused_attention = False

            # Filter: determinism.
            #   backend                                   |  deterministic
            # ---------------------------------------------------------
            #   flash-attn v1                             |  yes
            #   flash-attn v2                             |  no
            #   FusedAttnBackend["F16_max512_seqlen"]     |  yes
            #   FusedAttnBackend["F16_arbitrary_seqlen"]  |  workspace optimization path: yes; otherwise: no
            #   UnfusedDotProductAttention                |  yes
            #
            # Note that FusedAttnBackend["F16_arbitrary_seqlen"] only has workspace optimization path
            # on sm90 architectures.
            #
            if (
                use_fused_attention
                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
                and self.deterministic
                and self.device_compute_capability != (9, 0)
            ):
                self.logger.debug("Disabling FusedAttention for determinism reasons")
                use_fused_attention = False

            # Select FusedAttention on sm90 and FlashAttention on others for performance
            if (
                use_flash_attention
                and use_fused_attention
                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            ):
                if self.device_compute_capability == (9, 0):
                    self.logger.debug(
                        "Disabling FlashAttention to give FusedAttention preference on Hopper+ "
                        "for performance reasons"
                    )
                    use_flash_attention = False

            run_config = {
                "compute_capability": "sm"
                + str(
                    (lambda x, y: x * 10 + y)(
                        self.device_compute_capability[0], self.device_compute_capability[1]
                    )
                ),
                "q_dtype": query_layer.dtype,
                "k_dtype": key_layer.dtype,
                "v_dtype": value_layer.dtype,
                "q_shape": list(query_layer.shape),
                "k_shape": list(key_layer.shape),
                "v_shape": list(value_layer.shape),
                "qkv_format": qkv_format,
                "qkv_layout": qkv_layout,
                "mask_type": attn_mask_type,
                "bias_type": core_attention_bias_type,
                "bias_shape": (
                    core_attention_bias.shape if core_attention_bias is not None else None
                ),
                "dropout": self.attention_dropout,
                "context_parallel": context_parallel,
                "is_training": self.training,
                "transformer_engine_version": te.__version__,
                "flash_attn_version": _flash_attn_version,
                "cudnn_version": ".".join([str(i) for i in get_cudnn_version()]),
            }

            if use_flash_attention:
                self.logger.info("Running with FlashAttention backend")
                self.logger.debug("Running with config=%s", run_config)
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = get_alibi(
                        query_layer.shape[-2],
                        max_seqlen_q,
                        max_seqlen_kv,
                        alibi_slopes=alibi_slopes,
                    )
                return self.flash_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask=attention_mask,
                    qkv_layout=qkv_layout,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    attn_mask_type=attn_mask_type,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
                )

            if use_fused_attention:
                self.logger.info(
                    "Running with FusedAttention backend (sub-backend %s)",
                    int(fused_attention_backend),
                )
+                if self.fp8:
+                    self.logger.debug(
+                        "Running with fp8_recipe.fp8_mha=%s, "
+                        "fp8_recipe.fp8_dpa=%s%s, and NVTE_FP8_DPA_BWD=%s",
+                        self.fp8_meta["recipe"].fp8_mha,
+                        self.fp8_meta["recipe"].fp8_dpa,
+                        forced_fp8_dpa,
+                        int(os.getenv("NVTE_FP8_DPA_BWD", "1")),
+                    )
                self.logger.debug("Running with config=%s", run_config)
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        seq_offsets_q=seq_offsets_q,
                        seq_offsets_k=seq_offsets_k,
                        seq_offsets_v=seq_offsets_v,
                        seq_offsets_o=seq_offsets_o,
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_kv=max_seqlen_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        fused_attention_backend=fused_attention_backend,
                        core_attention_bias_type=fu_core_attention_bias_type,
                        core_attention_bias=fu_core_attention_bias,
                        fast_zero_fill=fast_zero_fill,
                        cp_group=self.cp_group,
                        cp_global_ranks=self.cp_global_ranks,
                        cp_stream=self.cp_stream,
-                        is_first_microbatch=is_first_microbatch,
+                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
+                        fp8_meta=self.fp8_meta,
                    )
                return self.fused_attention(
                    query_layer,
                    key_layer,
                    value_layer,
@@ -5300,51 +5389,41 @@ class DotProductAttention(torch.nn.Module):
                    cp_group=self.cp_group,
                    cp_global_ranks=self.cp_global_ranks,
                    cp_stream=self.cp_stream,
-                    is_first_microbatch=is_first_microbatch,
+                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
+                    fp8_meta=self.fp8_meta,
                )

            assert (
                not context_parallel
            ), "Context parallelism is only implemented with Flash Attention and Fused Attention!"

            from .cpu_offload import CPUOffloadEnabled

            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation Offloading is only implemented "
                    "with Flash Attention and Fused Attention!"
                )

            if use_unfused_attention:
                self.logger.info("Running with UnfusedDotProductAttention backend")
                self.logger.debug("Running with config=%s", run_config)
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.unfused_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        qkv_layout=qkv_layout,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_kv=cu_seqlens_kv,
                        attn_mask_type=attn_mask_type,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        alibi_slopes=alibi_slopes,
                    )
                return self.unfused_attention(
                    query_layer,
                    key_layer,
                    value_layer,
@@ -5357,21 +5436,8 @@ class DotProductAttention(torch.nn.Module):
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                )

            raise Exception("No dot product attention support for the provided inputs!")
class MultiheadAttention(torch.nn.Module):
@@ -430,7 +430,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Store other picklable values.
        extra = {}
        for k, v in self.fp8_meta.items():
-            if isinstance(v, (bool, int, float, str, tuple, list)):
+            if k != "buffer_index_and_autocast_key" and isinstance(
+                v, (bool, int, float, str, tuple, list)
+            ):
                extra[k] = v
        state["extra_fp8_variables"] = extra
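
The filtered dict above is what gets pickled into `_extra_state`. In outline, the `get_extra_state`/`set_extra_state` pattern works as below; this is a condensed sketch of the round-trip, assuming the `io.BytesIO` layout used elsewhere in this diff, not the full TransformerEngine implementation:

import io

import torch

class WithExtraState(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.meta = {"amax_history": torch.zeros(1, 2), "recipe_name": "delayed_scaling"}

    def get_extra_state(self):
        # serialize a dict of tensors/scalars; filter unpicklable entries here
        buf = io.BytesIO()
        torch.save(self.meta, buf)
        return buf

    def set_extra_state(self, state):
        # state_dict() carries the BytesIO; rewind before loading
        state.seek(0)
        self.meta = torch.load(state)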
@@ -491,12 +493,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
-        for name, buf in self.named_buffers():
-            if buf is not None:
-                assert dtype == buf.dtype, (
-                    "Data types for buffers must match when outside of autocasted region. "
-                    f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
-                )

        self.activation_dtype = dtype

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: