Unverified Commit 8487e506 authored by Charlene Yang, committed by GitHub

[PyTorch] Fix fused attention backward's FP8 dtypes (#1566)



* fix dtypes in fused attn bwd for FP8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add comments for dtypes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove redundant qkv_dtype in fwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove Nones in bwd returns
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent ab4fd3cf
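
A minimal sketch of the dtype rules this change encodes, separate from the library code: in the FP8 backward pass the incoming gradient (and therefore dq/dk/dv) uses torch.float8_e5m2, while the forward q/k/v/out tensors are torch.float8_e4m3fn, so the backward dtypes must be derived from d_out rather than reusing the forward qkv_dtype. The helper name and flag below are illustrative assumptions, not part of the Transformer Engine API.

# Sketch only; not the library implementation.
import torch

def backward_dtypes(d_out: torch.Tensor, fp8_dpa_bwd: bool):
    # fake_dtype always follows the incoming gradient:
    #   FP16/BF16 attention                        -> torch.float16 or torch.bfloat16
    #   FP8 attention with FP8 output (fp8_mha)    -> torch.float8_e5m2
    fake_dtype = d_out.dtype
    if fp8_dpa_bwd:
        # Gradients are quantized to e5m2, so dq/dk/dv are e5m2 as well,
        # not the e4m3 dtype that q/k/v/out used in the forward pass.
        dqkv_dtype = torch.float8_e5m2
    else:
        dqkv_dtype = d_out.dtype
    return fake_dtype, dqkv_dtype
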
@@ -6095,7 +6095,6 @@ class FusedAttnFunc(torch.autograd.Function):
q,
k,
v,
qkv_dtype,
attn_bias,
attn_scale,
dropout_p,
@@ -6116,6 +6115,10 @@
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn
fake_dtype = q.dtype
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
@@ -6154,6 +6157,7 @@
v_fp8 = QKV_quantizer(v)
case _:
raise "Invalid qkv_layout " + qkv_layout
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -6183,6 +6187,8 @@
out_ret = out_fp8
else:
out_ret = out_fp8.dequantize().view(out_fp8.shape)
# is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16
# is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn
out_save = out_ret
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
@@ -6211,7 +6217,7 @@
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else:
# q, k, v, out_ret: torch.float16 or torch.bfloat16
out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -6280,8 +6286,6 @@
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.fake_dtype = fake_dtype
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
@@ -6305,6 +6309,11 @@
d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2
fake_dtype = d_out.dtype
d_out = d_out.contiguous()
(
q_fp8,
@@ -6364,6 +6373,9 @@
d_out_fp8 = d_out
else:
d_out_fp8 = ctx.dO_quantizer(d_out)
dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
# d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
@@ -6374,8 +6386,8 @@
v_fp8,
out_fp8,
d_out_fp8,
ctx.fake_dtype,
ctx.qkv_dtype,
fake_dtype,
dqkv_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
cu_seqlens_q_padded,
@@ -6393,6 +6405,8 @@
ctx.deterministic,
)
# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2
if not ctx.is_input_fp8:
qkv_group = len(ctx.qkv_layout.split("_"))
if qkv_group == 1:
@@ -6423,6 +6437,8 @@
else:
if isinstance(d_out, QuantizedTensor):
d_out = d_out.dequantize()
dqkv_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
@@ -6433,8 +6449,8 @@
v,
out,
d_out,
ctx.fake_dtype,
ctx.qkv_dtype,
fake_dtype,
dqkv_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
cu_seqlens_q_padded,
@@ -6482,7 +6498,6 @@
None,
None,
None,
None,
)
# else, return (dqkv, dbias)
return (
@@ -6496,7 +6511,6 @@
dq,
dk,
dv,
None,
rest[0],
None,
None,
@@ -6695,8 +6709,6 @@ class FusedAttention(torch.nn.Module):
cu_seqlens_q_padded = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_kv
qkv_dtype = TE_DType[query_layer.dtype]
use_FAv2_bwd = (
self.use_FAv2_bwd
and (core_attention_bias_type == "no_bias")
@@ -6768,7 +6780,6 @@
query_layer,
key_layer,
value_layer,
qkv_dtype,
core_attention_bias,
self.softmax_scale,
self.attention_dropout if self.training else 0.0,
......