Unverified commit 8487e506, authored by Charlene Yang, committed by GitHub
Browse files

[PyTorch] Fix fused attention backward's FP8 dtypes (#1566)



* fix dtypes in fused attn bwd for FP8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add comments for dtypes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove redundant qkv_dtype in fwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove Nones in bwd returns
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent ab4fd3cf
...@@ -6095,7 +6095,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6095,7 +6095,6 @@ class FusedAttnFunc(torch.autograd.Function):
q, q,
k, k,
v, v,
qkv_dtype,
attn_bias, attn_bias,
attn_scale, attn_scale,
dropout_p, dropout_p,
...@@ -6116,6 +6115,10 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6116,6 +6115,10 @@ class FusedAttnFunc(torch.autograd.Function):
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn
fake_dtype = q.dtype fake_dtype = q.dtype
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
...@@ -6154,6 +6157,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6154,6 +6157,7 @@ class FusedAttnFunc(torch.autograd.Function):
v_fp8 = QKV_quantizer(v) v_fp8 = QKV_quantizer(v)
case _: case _:
raise "Invalid qkv_layout " + qkv_layout raise "Invalid qkv_layout " + qkv_layout
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
out_fp8, aux_ctx_tensors = fused_attn_fwd( out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
...@@ -6183,6 +6187,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6183,6 +6187,8 @@ class FusedAttnFunc(torch.autograd.Function):
out_ret = out_fp8 out_ret = out_fp8
else: else:
out_ret = out_fp8.dequantize().view(out_fp8.shape) out_ret = out_fp8.dequantize().view(out_fp8.shape)
# is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16
# is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn
out_save = out_ret out_save = out_ret
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
...@@ -6211,7 +6217,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6211,7 +6217,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else: else:
# q, k, v, out_ret: torch.float16 or torch.bfloat16
out_ret, aux_ctx_tensors = fused_attn_fwd( out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
...@@ -6280,8 +6286,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6280,8 +6286,6 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv ctx.max_seqlen_kv = max_seqlen_kv
ctx.fake_dtype = fake_dtype
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill ctx.fast_zero_fill = fast_zero_fill
...@@ -6305,6 +6309,11 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6305,6 +6309,11 @@ class FusedAttnFunc(torch.autograd.Function):
d_out, Float8Tensor d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2
fake_dtype = d_out.dtype
d_out = d_out.contiguous() d_out = d_out.contiguous()
( (
q_fp8, q_fp8,
...@@ -6364,6 +6373,9 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6364,6 +6373,9 @@ class FusedAttnFunc(torch.autograd.Function):
d_out_fp8 = d_out d_out_fp8 = d_out
else: else:
d_out_fp8 = ctx.dO_quantizer(d_out) d_out_fp8 = ctx.dO_quantizer(d_out)
dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
# d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
...@@ -6374,8 +6386,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6374,8 +6386,8 @@ class FusedAttnFunc(torch.autograd.Function):
v_fp8, v_fp8,
out_fp8, out_fp8,
d_out_fp8, d_out_fp8,
ctx.fake_dtype, fake_dtype,
ctx.qkv_dtype, dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -6393,6 +6405,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6393,6 +6405,8 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.deterministic, ctx.deterministic,
) )
# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2
if not ctx.is_input_fp8: if not ctx.is_input_fp8:
qkv_group = len(ctx.qkv_layout.split("_")) qkv_group = len(ctx.qkv_layout.split("_"))
if qkv_group == 1: if qkv_group == 1:
...@@ -6423,6 +6437,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6423,6 +6437,8 @@ class FusedAttnFunc(torch.autograd.Function):
else: else:
if isinstance(d_out, QuantizedTensor): if isinstance(d_out, QuantizedTensor):
d_out = d_out.dequantize() d_out = d_out.dequantize()
dqkv_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16
dq, dk, dv, *rest = fused_attn_bwd( dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
...@@ -6433,8 +6449,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6433,8 +6449,8 @@ class FusedAttnFunc(torch.autograd.Function):
v, v,
out, out,
d_out, d_out,
ctx.fake_dtype, fake_dtype,
ctx.qkv_dtype, dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -6482,7 +6498,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6482,7 +6498,6 @@ class FusedAttnFunc(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
# else, return (dqkv, dbias) # else, return (dqkv, dbias)
return ( return (
...@@ -6496,7 +6511,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -6496,7 +6511,6 @@ class FusedAttnFunc(torch.autograd.Function):
dq, dq,
dk, dk,
dv, dv,
None,
rest[0], rest[0],
None, None,
None, None,
...@@ -6695,8 +6709,6 @@ class FusedAttention(torch.nn.Module): ...@@ -6695,8 +6709,6 @@ class FusedAttention(torch.nn.Module):
cu_seqlens_q_padded = cu_seqlens_q cu_seqlens_q_padded = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_kv cu_seqlens_kv_padded = cu_seqlens_kv
qkv_dtype = TE_DType[query_layer.dtype]
use_FAv2_bwd = ( use_FAv2_bwd = (
self.use_FAv2_bwd self.use_FAv2_bwd
and (core_attention_bias_type == "no_bias") and (core_attention_bias_type == "no_bias")
...@@ -6768,7 +6780,6 @@ class FusedAttention(torch.nn.Module): ...@@ -6768,7 +6780,6 @@ class FusedAttention(torch.nn.Module):
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
qkv_dtype,
core_attention_bias, core_attention_bias,
self.softmax_scale, self.softmax_scale,
self.attention_dropout if self.training else 0.0, self.attention_dropout if self.training else 0.0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment