"examples/vscode:/vscode.git/clone" did not exist on "0edf30b87159e82048b5f248e4b379aebb8f364a"
Unverified commit c42beef4 authored by Xin Yao, committed by GitHub

[PyTorch] Make FP8 MHA work with RoPE when CP is on (#1297)



* Let FP8 MHA work with RoPE when CP is on
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Fix and update unit tests
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
parent a6a9141b
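
The substance of the change is in the attention.py hunks below: instead of trusting the `fp8_mha` recipe flag, the context-parallel attention functions now infer whether q/k/v are FP8 from their runtime type, since RoPE rotates q/k in higher precision before attention is called. A minimal sketch of that dispatch pattern (the helper name is hypothetical; the commit inlines the same logic in each CP attention function):

```python
from transformer_engine.pytorch.float8_tensor import Float8Tensor


def infer_fp8_io(q, k, v, fp8_meta):
    """Hypothetical helper mirroring the logic this commit adds inline."""
    # "fp8_mha" only decides whether the *outputs* are FP8; the *inputs* are
    # inferred from the real dtype, so BF16/FP16 q/k/v (e.g. after RoPE) work.
    assert isinstance(k, q.__class__) and isinstance(
        v, q.__class__
    ), "q, k, and v must have the same type."
    is_input_fp8 = isinstance(q, Float8Tensor)
    is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
    return is_input_fp8, is_output_fp8
```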
@@ -11,16 +11,24 @@ from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
 import transformer_engine_torch as tex
 from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
+from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.common.recipe import DelayedScaling

 dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}


 def run_dpa_with_cp(
-    dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
+    dtype="bf16",
+    model=None,
+    qkv_format="bshd",
+    kernel_backend="FlashAttention",
+    cp_comm_type="p2p",
+    fp8_mha=False,
 ):
     """Test DotProductAttention module with context parallelism"""
+    # args are passed as strings
+    fp8_mha = fp8_mha == "True"
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
@@ -72,7 +80,7 @@ def run_dpa_with_cp(
         cp_comm_sub_groups.append(sub_group)

     if dtype == "fp8":
-        fp8_recipe = DelayedScaling(fp8_dpa=True)
+        fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)

     # instantiate core attn module
     core_attn = DotProductAttention(
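
For context, the recipe built here is what a user would pass to fp8_autocast to exercise this path; a minimal sketch, assuming the standard DelayedScaling/fp8_autocast API (module construction and CP group setup omitted):

```python
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.fp8 import fp8_autocast

# fp8_dpa runs the attention computation in FP8; fp8_mha additionally makes the
# attention module exchange FP8 (Float8Tensor) inputs/outputs at its boundary.
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=True)

with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # out = core_attn(q, k, v, ...)  # DotProductAttention forward, as in this test
    pass
```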
@@ -201,7 +209,11 @@ def run_dpa_with_cp(
                 None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
             ),
         )
-    out.backward(dout)
+    if fp8_mha:
+        dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
+        out.backward(dout_fp8)
+    else:
+        out.backward(dout)

     # run core_attn wit CP
     q_, k_, v_, dout_, *rest = [
@@ -269,7 +281,11 @@ def run_dpa_with_cp(
                 None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
             ),
         )
-    out_.backward(dout_)
+    if fp8_mha:
+        dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
+        out_.backward(dout_fp8_)
+    else:
+        out_.backward(dout_)

     for x in [out_, q_.grad, k_.grad, v_.grad]:
         assert torch.all(~torch.isnan(x))
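
When fp8_mha is on, the module returns a Float8Tensor, so the gradient passed to backward() must be FP8 as well; the test therefore casts dout to E5M2, the format conventionally used for gradients. A standalone illustration of that cast (shapes are arbitrary; requires a CUDA device):

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Arbitrary example gradient; in the test above this role is played by `dout`.
grad = torch.randn(2, 4096, 12, 128, dtype=torch.bfloat16, device="cuda")
# E5M2 trades mantissa bits for range, which suits backward-pass tensors.
grad_fp8 = Float8Tensor.to_float8(grad, fp8_dtype=tex.DType.kFloat8E5M2)
```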
......
@@ -1356,8 +1356,6 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     config = model_configs_fp8_vs_f16[model]

     if _flash_attn_3_is_installed and not is_training:
-        if RoPE:
-            pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
......
@@ -113,7 +113,8 @@ model_configs_fused_attn = {
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
 @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
-def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
+@pytest.mark.parametrize("fp8_mha", [False, True])
+def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
     if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
         pytest.skip("THD format is only supported on sm90+!")
     if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
@@ -153,6 +154,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
             f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
             f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
         )
+    if dtype != "fp8" and fp8_mha:
+        pytest.skip("Only fp8 works with fp8_mha=True!")

     subprocess.run(
         get_bash_arguments(
@@ -162,6 +165,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
             qkv_format=qkv_format,
             kernel_backend="FusedAttention",
             cp_comm_type=cp_comm_type,
+            fp8_mha=fp8_mha,
         ),
         check=True,
     )
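
Each pytest case launches the CP worker as a subprocess, so the fp8_mha flag travels as a command-line string and is parsed back to a bool inside run_dpa_with_cp (the `fp8_mha == "True"` line in the first hunk). A generic sketch of that round trip, with hypothetical helper and script names (the real launcher is the get_bash_arguments helper in the test file and may differ):

```python
import subprocess


def launch_cp_worker(script_path, **kwargs):
    # Hypothetical launcher: kwargs are serialized as "key=value" strings, so
    # fp8_mha=True reaches the worker's argv as the string "fp8_mha=True".
    cmd = ["python", script_path] + [f"{key}={value}" for key, value in kwargs.items()]
    subprocess.run(cmd, check=True)


def parse_bool(value):
    # Worker side: convert the string form back to a real bool before use.
    return value == "True"
```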
@@ -1729,17 +1729,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         fused_attn_qkv_dtype = None
         fused_attn_backend = None
         amax_per_step = None
+        qkv_dtype = q.dtype
+        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
+        is_input_fp8 = False
+        is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
         if fp8:
             if use_fused_attention:
                 fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                 fused_attn_qkv_dtype = fp8_dtype_forward
                 fused_attn_backend = FusedAttnBackend["FP8"]
-                if fp8_meta["recipe"].fp8_mha:
-                    assert (
-                        isinstance(q, Float8Tensor)
-                        and isinstance(k, Float8Tensor)
-                        and isinstance(v, Float8Tensor)
-                    ), "q/k/v must be Float8Tensors for FP8 MHA!"
+                assert isinstance(k, q.__class__) and isinstance(
+                    v, q.__class__
+                ), "q, k, and v must have the same type."
+                is_input_fp8 = isinstance(q, Float8Tensor)
+                if is_input_fp8:
                     fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                     q_fp8, k_fp8, v_fp8 = q, k, v
                     q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
@@ -1778,7 +1781,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             )

         if not fp8:
             q_f16 = q
-        elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+        elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
             q_f16 = q
             q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
@@ -1880,11 +1883,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     batch_p2p_comm,
                 )

-                if (
-                    not fp8
-                    or fp8_meta["recipe"].fp8_mha
-                    or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
-                ):
+                if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                     kv_inputs[i % 2] = p2p_comm_buffers[i]
                 else:
                     # KV exchange is in BF16/FP16, cast received KV in each step
@@ -2436,18 +2435,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]

         out_fp8 = None
-        out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
-        if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
+        out_f16 = out.to(qkv_dtype)
+        if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
             out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)

-        if fp8 and fp8_meta["recipe"].fp8_mha:
+        if fp8 and is_output_fp8:
             out_ret = Float8Tensor(
                 data=out_fp8,
                 fp8_meta=fp8_meta,
                 fp8_meta_forward=True,
                 fp8_meta_index=META_O,
                 fp8_dtype=fp8_dtype_forward,
-                dtype=q_fp8.dtype,
+                dtype=qkv_dtype,
             )
         else:
             out_ret = out_f16
@@ -2456,7 +2455,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             q_save, kv_save, out_save = q, kv, out_fp8
             fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
             fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
-        elif fp8 and fp8_meta["recipe"].fp8_mha:
+        elif fp8 and is_input_fp8:
             q_fp8 = Float8Tensor(
                 data=q,
                 fp8_meta=fp8_meta,
@@ -2513,6 +2512,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         ctx.use_fused_attention = use_fused_attention
         ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
         ctx.fp8_meta = fp8_meta
+        ctx.is_input_fp8 = is_input_fp8
+        ctx.is_output_fp8 = is_output_fp8
         return out_ret

     @staticmethod
@@ -2595,7 +2596,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
                 dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
                 dkv_fp8_ = torch.empty_like(dkv_fp8)
-                if ctx.fp8_meta["recipe"].fp8_mha:
+                if ctx.is_output_fp8:
                     assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                     ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                     dout = dout._data
@@ -2617,7 +2618,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             else:
                 assert False, "FP8 is only supported with Fused Attention!"
         else:
-            if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
+            if ctx.fp8_meta is not None and ctx.is_input_fp8:
                 q, kv = [x.from_float8(x.dtype) for x in [q, kv]]
             if cp_size_a2a == 1:
                 dout = dout.from_float8(dout_dtype)
@@ -2653,7 +2654,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 ctx.cp_stream,
                 True,
             )
-            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
+            if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
                 dout = cast_from_fp8(
                     dout,
                     None,
@@ -3260,7 +3261,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
             dkv = dkv_

-        if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
+        if ctx.fp8 and ctx.is_input_fp8:
             dq, dkv = [
                 cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
                 for x in [dq, dkv]
@@ -3283,7 +3284,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         elif ctx.qkv_format == "sbhd":
             dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

-        if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
+        if ctx.fp8 and ctx.is_input_fp8:
             dq, dk, dv = [
                 Float8Tensor(
                     data=x,
@@ -3852,19 +3853,22 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
         ), "Sequence length per GPU needs to be divisible by 2!"

+        qkv_dtype = q.dtype
         fused_attn_backend = None
         fused_attn_qkv_dtype = None
+        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
+        is_input_fp8 = False
+        is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
         if fp8:
             if use_fused_attention:
                 fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                 fused_attn_qkv_dtype = fp8_dtype_forward
                 fused_attn_backend = FusedAttnBackend["FP8"]
-                if fp8_meta["recipe"].fp8_mha:
-                    assert (
-                        isinstance(q, Float8Tensor)
-                        and isinstance(k, Float8Tensor)
-                        and isinstance(v, Float8Tensor)
-                    ), "q/k/v must be Float8Tensors for FP8 MHA!"
+                assert isinstance(k, q.__class__) and isinstance(
+                    v, q.__class__
+                ), "q, k, and v must have the same type."
+                is_input_fp8 = isinstance(q, Float8Tensor)
+                if is_input_fp8:
                     fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
                     q_fp8, k_fp8, v_fp8 = q, k, v
                     q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
@@ -3900,7 +3904,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
         )

-        if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+        if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
             q_f16, k_f16, v_f16 = q, k, v
             q, k, v = [
                 cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
@@ -3965,14 +3969,14 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             out = out.view(-1, batch_size, *out.shape[-2:])

         if fp8:
-            if fp8_meta["recipe"].fp8_mha:
+            if is_output_fp8:
                 out_fp8 = Float8Tensor(
                     data=out,
                     fp8_meta=fp8_meta,
                     fp8_meta_forward=True,
                     fp8_meta_index=META_O,
                     fp8_dtype=fp8_dtype_forward,
-                    dtype=q_fp8.dtype,
+                    dtype=qkv_dtype,
                 )
                 out = out_fp8._data
                 out_ret = out_fp8
@@ -3991,7 +3995,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         if fp8:
             if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
                 q_save, k_save, v_save, out_save = q, k, v, out
-            elif fp8_meta["recipe"].fp8_mha:
+            elif is_input_fp8:
                 q_fp8, k_fp8, v_fp8 = [
                     Float8Tensor(
                         data=x,
@@ -4043,6 +4047,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         ctx.use_fused_attention = use_fused_attention
         ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
         ctx.fp8_meta = fp8_meta
+        ctx.is_input_fp8 = is_input_fp8
+        ctx.is_output_fp8 = is_output_fp8
         return out_ret

     @staticmethod
@@ -4064,6 +4070,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         fused_attn_backend = None
         fused_attn_dqkv_dtype = None
         fused_attn_qkv_dtype = None
+        dout_dtype = dout.dtype
         if ctx.fp8:
             if ctx.use_fused_attention:
                 fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
@@ -4071,7 +4078,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                 fused_attn_qkv_dtype = fp8_dtype_forward
                 fused_attn_dqkv_dtype = fp8_dtype_backward
                 fused_attn_backend = FusedAttnBackend["FP8"]
-                if ctx.fp8_meta["recipe"].fp8_mha:
+                if ctx.is_output_fp8:
                     assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                     ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
                     dout_fp8 = dout
@@ -4097,7 +4104,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             else:
                 assert False, "FP8 is only supported with Fused Attention!"
         else:
-            if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
+            if ctx.fp8_meta is not None and ctx.is_output_fp8:
                 assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                 q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]]
         if ctx.use_fused_attention:
@@ -4194,7 +4201,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

         if ctx.fp8:
-            if ctx.fp8_meta["recipe"].fp8_mha:
+            if ctx.is_input_fp8:
                 dq, dk, dv = [
                     Float8Tensor(
                         data=x,
@@ -4202,7 +4209,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                         fp8_meta_forward=False,
                         fp8_meta_index=META_DQKV,
                         fp8_dtype=fp8_dtype_backward,
-                        dtype=dout_fp8.dtype,
+                        dtype=dout_dtype,
                     )
                     for x in [dq, dk, dv]
                 ]
@@ -4213,7 +4220,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                         ctx.fp8_meta["scaling_bwd"],
                         META_DQKV,
                         fp8_dtype_backward,
-                        TE_DType[dout_f16.dtype],
+                        TE_DType[dout_dtype],
                     )
                     for x in [dq, dk, dv]
                 ]
@@ -5434,11 +5441,12 @@ class FlashAttention(torch.nn.Module):
                 )
                 return out

-            if fp8_meta["recipe"].fp8_mha:
-                assert all(
-                    isinstance(x, Float8Tensor)
-                    for x in [query_layer, key_layer, value_layer]
-                ), "q/k/v must be Float8Tensors for FP8 MHA."
+            # "fp8_mha" decides outputs in fp8, while inputs are inferred from
+            # the real dtype
+            assert isinstance(key_layer, query_layer.__class__) and isinstance(
+                value_layer, query_layer.__class__
+            ), "q, k, and v must have the same type."
+            if isinstance(query_layer, Float8Tensor):
                 fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
             else:
                 query_layer, key_layer, value_layer = (
@@ -5580,6 +5588,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
         deterministic,
     ):
         # pylint: disable=missing-function-docstring
+        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
         is_input_fp8 = False
         is_output_fp8 = fp8_meta["recipe"].fp8_mha
         if fp8:
@@ -5970,6 +5979,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
         deterministic,
     ):
         # pylint: disable=missing-function-docstring
+        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
         is_input_fp8 = False
         is_output_fp8 = fp8_meta["recipe"].fp8_mha
         if fp8:
@@ -6424,6 +6434,7 @@ class FusedAttnFunc(torch.autograd.Function):
         deterministic,
     ):
         # pylint: disable=missing-function-docstring
+        # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
         is_input_fp8 = False
         is_output_fp8 = fp8_meta["recipe"].fp8_mha
         if fp8:
......