Unverified commit 257345a5 authored by Xiaowei Ren, committed by GitHub

[PyTorch] Fix CP implementation with FP8 (#1483)



* commit some debug code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more debug info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* debug code commit and typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* a typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove debug info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not return lse
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add amax_per_step for quantizers of CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FP8 + CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* dtype fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@login-preos01.a51.clusters.nvidia.com>
parent b612cdeb
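For orientation before the diff: the core of this fix is that each of the cp_size point-to-point steps in the forward pass now gets its own copy of the S and O quantizers, with every copy's amax aliased into a shared (2, cp_size) amax_per_step buffer that is reduced via amax(dim=1) once all steps finish; the backward pass mirrors this for the dP and dQKV quantizers. A minimal, self-contained sketch of that bookkeeping pattern, using plain tensors in place of Transformer Engine's quantizer objects (the helper names here are illustrative, not TE's API):

import torch

cp_size = 4
# One amax slot per step, for two tracked quantities (e.g. S and partial O).
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32)

def record_amax(x: torch.Tensor, row: int, step: int) -> None:
    # Stand-in for a per-step quantizer copy whose .amax is a view into
    # amax_per_step[row][step]: each step writes only its own slot, so the
    # overlapped steps never race on a single shared amax scalar.
    amax_per_step[row][step] = x.abs().amax()

for i in range(cp_size):
    record_amax(torch.randn(8, 8), 0, i)  # softmax (S) partial result
    record_amax(torch.randn(8, 8), 1, i)  # output (O) partial result

# After the loop, collapse the per-step amaxes into one amax per quantity,
# matching amax_per_step.amax(dim=1) in the diff below.
amax_cp = amax_per_step.amax(dim=1)  # shape (2,)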
@@ -1894,11 +1894,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fused_attn_backend = None
qkv_dtype = q.dtype
amax_per_step = None
S_quantizer_per_step = [None for _ in range(cp_size)]
O_CP_quantizer_per_step = [None for _ in range(cp_size)]
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = False
if fp8:
is_output_fp8 = fp8_meta["recipe"].fp8_mha
(
QKV_quantizer,
@@ -1919,28 +1920,30 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
if not is_input_fp8:
is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
if is_input_fp8:
QKV_quantizer = q._quantizer
q, k, v = q._data, k._data, v._data
else:
q_f16, k_f16, v_f16 = q, k, v
if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q = QKV_quantizer(q_f16)
q = QKV_quantizer(q_f16)._data
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
k, v = [QKV_quantizer(x) for x in [k_f16, v_f16]]
fp8_meta_kwargs = {}
fp8_meta_kwargs["s_quantizer"] = S_quantizer
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer # partial result quantizer
k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]]
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
# partial result quantizer
for i in range(cp_size):
S_quantizer_per_step[i] = S_quantizer.copy()
S_quantizer_per_step[i].amax = amax_per_step[0][i]
O_CP_quantizer_per_step[i] = O_CP_quantizer.copy()
O_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
q_f16 = q
if use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if fp8:
q = q._data
k = k._data
v = v._data
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
@@ -2067,7 +2070,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_inputs[i % 2] = p2p_comm_buffers[i]
else:
# KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
if causal:
if i == 0:
if pad_between_seqs_q:
@@ -2120,6 +2123,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
fp8_meta_kwargs = {}
if fp8:
q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True
@@ -2130,6 +2134,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True
)
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
@@ -2243,6 +2249,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
fp8_meta_kwargs = {}
if fp8:
q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True
@@ -2253,6 +2260,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True
)
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -2385,6 +2394,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
fp8_meta_kwargs = {}
if fp8:
q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True
@@ -2395,6 +2405,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True
)
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q // 2,
@@ -2507,6 +2519,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
fp8_meta_kwargs = {}
if fp8:
q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True
@@ -2517,6 +2530,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True
)
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -2595,7 +2610,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if fp8:
out_per_step[i - 1] = out_per_step[i - 1].dequantize()
out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
if i == 1:
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
@@ -2697,6 +2712,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif not use_fused_attention:
out = out.view(-1, *out.shape[-2:])
if fp8 and use_fused_attention:
amax_cp_fwd = amax_per_step.amax(dim=1)
S_quantizer.amax = amax_cp_fwd[0]
O_CP_quantizer.amax = amax_cp_fwd[1]
out_fp8 = None
out_f16 = out.to(qkv_dtype)
@@ -2708,7 +2728,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, kv_save, out_save = q, kv, out_fp8._data
elif fp8 and is_input_fp8:
q_save, kv_save, out_save = q, k, out_f16
q_save, kv_save, out_save = q, kv, out_f16
else:
q_f16 = q_f16.view(q.shape)
q_save, kv_save, out_save = q_f16, kv, out_f16
@@ -2737,7 +2757,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer
ctx.qkv_dtype = qkv_dtype
ctx.cp_group_a2a = cp_group_a2a
ctx.cp_size_a2a = cp_size_a2a
@@ -2778,10 +2797,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
saved_tensors = ctx.saved_tensors
q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
restore_from_saved(ctx.tensor_objects, saved_tensors)
restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
)
cu_seqlens_q_per_step = other_tensors[:cp_size]
cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
@@ -2843,17 +2860,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dout_dtype = dout.dtype
fused_attn_backend = None
fused_attn_dqkv_dtype = None
amax_per_step = None
dP_quantizer_per_step = [None for _ in range(cp_size)]
dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)]
if ctx.fp8:
if ctx.use_fused_attention:
fused_attn_backend = FusedAttnBackend["FP8"]
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
dqkv_fp8_torch_dtype = get_fp8_torch_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
dq_fp8 = torch.empty(
(cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device
)
dkv_fp8 = torch.empty(
(cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device
)
dkv_fp8_ = torch.empty_like(dkv_fp8)
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
fused_attn_dqkv_dtype = dout._fp8_dtype
dout = dout._data
ctx.dO_quantizer = dout._quantizer
else:
dout = ctx.dO_quantizer(dout)
fused_attn_dqkv_dtype = dout._fp8_dtype
@@ -2861,21 +2887,32 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
fp8_meta_kwargs = {}
fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer
fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_CP_quantizer
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
for i in range(cp_size):
dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
dP_quantizer_per_step[i].amax = amax_per_step[0][i]
dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy()
dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.is_input_fp8:
if ctx.fp8_meta is not None:
if ctx.is_input_fp8:
q = ctx.QKV_quantizer.create_tensor_from_data(
q, fake_dtype=ctx.qkv_dtype, internal=True
)
kv = ctx.QKV_quantizer.create_tensor_from_data(
kv, fake_dtype=ctx.qkv_dtype, internal=True
)
q, kv = q.dequantize(), kv.dequantize()
q = q.dequantize(dtype=ctx.qkv_dtype)
kv = kv.dequantize(dtype=ctx.qkv_dtype)
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
if cp_size_a2a == 1:
dout = dout.dequantize()
dout = dout.dequantize(dtype=dout_dtype)
else:
ctx.dO_quantizer = dout._quantizer
dout = dout._data
dq = torch.empty_like(q)
p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
@@ -2902,9 +2939,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True,
)
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
dout = ctx.dO_quantizer.create_tensor_from_data(data=dout, internal=True)
dout = dout.dequantize()
dout = dout._data
dout = ctx.dO_quantizer.create_tensor_from_data(
dout, fake_dtype=dout_dtype, internal=True
)
dout = dout.dequantize(dtype=dout_dtype)
out = out.view(*q.shape)
dout = dout.view(*q.shape)
@@ -3020,8 +3058,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True
)
dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True
dout_part, fake_dtype=dout_dtype, internal=True
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
@@ -3133,8 +3173,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True
)
dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True
dout_part, fake_dtype=dout_dtype, internal=True
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
@@ -3250,8 +3292,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True
)
dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True
dout_part, fake_dtype=dout_dtype, internal=True
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
@@ -3282,7 +3326,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dq_ = dq_._data
dk_ = dk_._data
dv_ = dv_._data
else:
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
@@ -3333,20 +3376,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.fp8:
q_part = ctx.QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=ctx.qkv_dtype
q_part, fake_dtype=ctx.qkv_dtype, internal=True
)
k_part = ctx.QKV_quantizer.create_tensor_from_data(
k_part, fake_dtype=ctx.qkv_dtype
k_part, fake_dtype=ctx.qkv_dtype, internal=True
)
v_part = ctx.QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=ctx.qkv_dtype
v_part, fake_dtype=ctx.qkv_dtype, internal=True
)
out_part = ctx.O_quantizer.create_tensor_from_data(
out_part, fake_dtype=ctx.qkv_dtype
out_part, fake_dtype=ctx.qkv_dtype, internal=True
)
dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype
dout_part, fake_dtype=dout_dtype, internal=True
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
@@ -3555,13 +3600,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv.add_(dkv_)
if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.dP_quantizer.amax = amax_cp_bwd[0]
ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1]
if ctx.qkv_format in ["bshd", "sbhd"]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq_fp8)
dkv = ctx.dQKV_quantizer.create_tensor_from_data(dkv_fp8)
dq, dkv = [x.dequantize() for x in [dq, dkv]]
dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dq_fp8, fake_dtype=torch.float32, internal=True
)
dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dkv_fp8, fake_dtype=torch.float32, internal=True
)
dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]]
dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]
if causal:
@@ -3606,9 +3658,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
# converting torch.uint8 to float8tensor
if ctx.fp8 and ctx.is_input_fp8:
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype)
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype)
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype)
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype)
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype)
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
return (
@@ -4227,21 +4279,20 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = False
if fp8:
is_output_fp8 = fp8_meta["recipe"].fp8_mha
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
)
if fp8:
if use_fused_attention:
fused_attn_backend = FusedAttnBackend["FP8"]
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
if is_input_fp8:
QKV_quantizer = q._quantizer
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
@@ -4350,31 +4401,24 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out = out_fp8._data
else:
out_fp8 = O_quantizer.create_tensor_from_data(
out, fake_dtype=qkv_dtype, internal=False
out, fake_dtype=qkv_dtype, internal=True
)
out_f16 = out_fp8.dequantize()
out_f16 = out_fp8.dequantize(dtype=qkv_dtype)
out_ret = out_f16
else:
out_ret = out
if fp8:
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, k_save, v_save, out_save = q, k, v, out
elif is_input_fp8:
q_fp8 = QKV_quantizer.create_tensor_from_data(
q, fake_dtype=qkv_dtype, internal=False
)
k_fp8 = QKV_quantizer.create_tensor_from_data(
k, fake_dtype=qkv_dtype, internal=False
)
v_fp8 = QKV_quantizer.create_tensor_from_data(
v, fake_dtype=qkv_dtype, internal=False
)
q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out
else:
q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16
if is_input_fp8:
q_save, k_save, v_save = q, k, v
else:
q_save, k_save, v_save, out_save = q, k, v, out
q_save, k_save, v_save = q_f16, k_f16, v_f16
if is_output_fp8:
out_save = out
else:
out_save = out_f16
tensors_to_save, tensor_objects = prepare_for_saving(
q_save,
@@ -4397,7 +4441,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer
ctx.qkv_dtype = qkv_dtype
ctx.batch_size = batch_size
ctx.cp_group = cp_group
@@ -4436,27 +4479,24 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv_padded,
*aux_ctx_tensors,
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
dout_dtype = dout.dtype
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
causal = "causal" in ctx.attn_mask_type
seq_dim = ctx.qkv_format.index("s")
dout_dtype = dout.dtype
fused_attn_backend = None
fused_attn_dqkv_dtype = None
if ctx.fp8:
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fused_attn_dqkv_dtype = fp8_dtype_backward
if ctx.use_fused_attention:
fused_attn_backend = FusedAttnBackend["FP8"]
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
dout_fp8 = dout
dout = dout_fp8._data
ctx.dO_quantizer = dout._quantizer
else:
dout_f16 = dout
dout = ctx.dO_quantizer(dout_f16)._data
dout = ctx.dO_quantizer(dout)
fused_attn_dqkv_dtype = dout._fp8_dtype
dout = dout._data
fp8_meta_kwargs = {}
fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer
@@ -4465,12 +4505,25 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.is_output_fp8:
if ctx.fp8_meta is not None:
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
q, k, v, out, dout = [x.dequantize() for x in [q, k, v, out, dout]]
ctx.dO_quantizer = dout._quantizer
dout = dout._data
if ctx.is_input_fp8:
q = ctx.QKV_quantizer.create_tensor_from_data(
q, fake_dtype=ctx.qkv_dtype, internal=True
)
k = ctx.QKV_quantizer.create_tensor_from_data(
k, fake_dtype=ctx.qkv_dtype, internal=True
)
v = ctx.QKV_quantizer.create_tensor_from_data(
v, fake_dtype=ctx.qkv_dtype, internal=True
)
q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]]
if ctx.use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_dqkv_dtype = TE_DType[dout.dtype]
fused_attn_dqkv_dtype = TE_DType[dout_dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if not ctx.use_fused_attention:
@@ -4481,6 +4534,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out, dout = flash_attn_a2a_communicate(
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
)
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
out = ctx.O_quantizer.create_tensor_from_data(
out, fake_dtype=ctx.qkv_dtype, internal=True
)
dout = ctx.dO_quantizer.create_tensor_from_data(
dout, fake_dtype=dout_dtype, internal=True
)
out = out.dequantize(dtype=ctx.qkv_dtype)
dout = dout.dequantize(dtype=dout_dtype)
flash_attn_bwd = None
if not ctx.use_fused_attention:
@@ -4531,7 +4593,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True
)
dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True
dout_part, fake_dtype=dout_dtype, internal=True
)
dq, dk, dv, _ = fused_attn_bwd(
@@ -4602,11 +4664,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
if ctx.fp8:
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype)
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype)
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
dq = ctx.dQKV_quantizer.create_tensor_from_data(
dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
)
dk = ctx.dQKV_quantizer.create_tensor_from_data(
dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
)
dv = ctx.dQKV_quantizer.create_tensor_from_data(
dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
)
if not ctx.is_input_fp8:
dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]]
dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]]
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
return (
......
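A second recurring pattern in the hunks above: raw FP8 payloads (the uint8 _data buffers) are rewrapped with create_tensor_from_data(..., fake_dtype=..., internal=True) before use, and dequantize() now receives an explicit target dtype instead of relying on a default. A rough, self-contained analogue using PyTorch's native float8 dtypes (illustrative only; TE's quantizers additionally carry scale and amax state that a bare dtype cast does not):

import torch

# "Quantize": cast a high-precision gradient to FP8; E5M2 is the typical
# gradient format, matching get_fp8_torch_dtype(..., fprop_tensor=False).
dout = torch.randn(4, 4, dtype=torch.bfloat16)
dout_fp8 = dout.to(torch.float8_e5m2)

# "Dequantize with an explicit dtype": choose the high-precision target up
# front rather than accepting whatever a default would produce.
dout_f32 = dout_fp8.to(torch.float32)
dout_bf16 = dout_fp8.to(dout.dtype)  # back to the saved dout dtype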
@@ -56,7 +56,7 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return torch.float8_e4m3fn
return torch.float8_e5m2fn
return torch.float8_e5m2
def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
......
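The final hunk fixes a dtype name: PyTorch defines torch.float8_e4m3fn and torch.float8_e5m2 (the "fn" suffix marks the finite-only E4M3 encoding; the E5M2 type carries no such suffix), so torch.float8_e5m2fn was an invalid attribute. A quick check, assuming a PyTorch build with FP8 dtypes:

import torch

assert hasattr(torch, "float8_e4m3fn")      # finite-only E4M3
assert hasattr(torch, "float8_e5m2")        # E5M2, no "fn" suffix
assert not hasattr(torch, "float8_e5m2fn")  # the attribute this hunk removes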