Unverified Commit 257345a5 authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

[PyTorch] Fix CP implementation with FP8 (#1483)



* commit some debug code
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add more debug info
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* debug code commit and typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* a typo fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* remove debug info
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* do not return lse
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* add amax_per_step for quantizers of CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* fix FP8 + CP
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* dtype fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXiaowei Ren <xren@login-preos01.a51.clusters.nvidia.com>
parent b612cdeb
...@@ -1894,11 +1894,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1894,11 +1894,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fused_attn_backend = None fused_attn_backend = None
qkv_dtype = q.dtype qkv_dtype = q.dtype
amax_per_step = None
S_quantizer_per_step = [None for _ in range(cp_size)]
O_CP_quantizer_per_step = [None for _ in range(cp_size)]
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False is_input_fp8 = False
is_output_fp8 = False is_output_fp8 = False
if fp8:
is_output_fp8 = fp8_meta["recipe"].fp8_mha
( (
QKV_quantizer, QKV_quantizer,
...@@ -1919,28 +1920,30 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1919,28 +1920,30 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v, q.__class__ v, q.__class__
), "q, k, and v must have the same type." ), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor) is_input_fp8 = isinstance(q, Float8Tensor)
if not is_input_fp8: is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
if is_input_fp8:
QKV_quantizer = q._quantizer
q, k, v = q._data, k._data, v._data
else:
q_f16, k_f16, v_f16 = q, k, v q_f16, k_f16, v_f16 = q, k, v
if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q = QKV_quantizer(q_f16) q = QKV_quantizer(q_f16)._data
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
k, v = [QKV_quantizer(x) for x in [k_f16, v_f16]] k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]]
fp8_meta_kwargs = {} amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
fp8_meta_kwargs["s_quantizer"] = S_quantizer # partial result quantizer
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer # partial result quantizer for i in range(cp_size):
S_quantizer_per_step[i] = S_quantizer.copy()
S_quantizer_per_step[i].amax = amax_per_step[0][i]
O_CP_quantizer_per_step[i] = O_CP_quantizer.copy()
O_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
else: else:
assert False, "FP8 is only supported with Fused Attention!" assert False, "FP8 is only supported with Fused Attention!"
else: else:
q_f16 = q q_f16 = q
if use_fused_attention: if use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if fp8:
q = q._data
k = k._data
v = v._data
if cp_size_a2a > 1: if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True)
...@@ -2067,7 +2070,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2067,7 +2070,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_inputs[i % 2] = p2p_comm_buffers[i] kv_inputs[i % 2] = p2p_comm_buffers[i]
else: else:
# KV exchange is in BF16/FP16, cast received KV in each step # KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i]) kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
if causal: if causal:
if i == 0: if i == 0:
if pad_between_seqs_q: if pad_between_seqs_q:
...@@ -2120,6 +2123,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2120,6 +2123,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
) )
fp8_meta_kwargs = {}
if fp8: if fp8:
q_part = QKV_quantizer.create_tensor_from_data( q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True q_part, fake_dtype=qkv_dtype, internal=True
...@@ -2130,6 +2134,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2130,6 +2134,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data( v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True v_part, fake_dtype=qkv_dtype, internal=True
) )
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
...@@ -2243,6 +2249,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2243,6 +2249,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
) )
fp8_meta_kwargs = {}
if fp8: if fp8:
q_part = QKV_quantizer.create_tensor_from_data( q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True q_part, fake_dtype=qkv_dtype, internal=True
...@@ -2253,6 +2260,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2253,6 +2260,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data( v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True v_part, fake_dtype=qkv_dtype, internal=True
) )
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
...@@ -2385,6 +2394,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2385,6 +2394,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
) )
fp8_meta_kwargs = {}
if fp8: if fp8:
q_part = QKV_quantizer.create_tensor_from_data( q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True q_part, fake_dtype=qkv_dtype, internal=True
...@@ -2395,6 +2405,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2395,6 +2405,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data( v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True v_part, fake_dtype=qkv_dtype, internal=True
) )
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q // 2, max_seqlen_q // 2,
...@@ -2507,6 +2519,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2507,6 +2519,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
) )
fp8_meta_kwargs = {}
if fp8: if fp8:
q_part = QKV_quantizer.create_tensor_from_data( q_part = QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=qkv_dtype, internal=True q_part, fake_dtype=qkv_dtype, internal=True
...@@ -2517,6 +2530,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2517,6 +2530,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part = QKV_quantizer.create_tensor_from_data( v_part = QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=qkv_dtype, internal=True v_part, fake_dtype=qkv_dtype, internal=True
) )
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i]
fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i]
out_per_step[i], aux_ctx_tensors = fused_attn_fwd( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
...@@ -2595,7 +2610,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2595,7 +2610,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if fp8: if fp8:
out_per_step[i - 1] = out_per_step[i - 1].dequantize() out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32)
if i == 1: if i == 1:
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
...@@ -2697,6 +2712,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2697,6 +2712,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif not use_fused_attention: elif not use_fused_attention:
out = out.view(-1, *out.shape[-2:]) out = out.view(-1, *out.shape[-2:])
if fp8 and use_fused_attention:
amax_cp_fwd = amax_per_step.amax(dim=1)
S_quantizer.amax = amax_cp_fwd[0]
O_CP_quantizer.amax = amax_cp_fwd[1]
out_fp8 = None out_fp8 = None
out_f16 = out.to(qkv_dtype) out_f16 = out.to(qkv_dtype)
...@@ -2708,7 +2728,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2708,7 +2728,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, kv_save, out_save = q, kv, out_fp8._data q_save, kv_save, out_save = q, kv, out_fp8._data
elif fp8 and is_input_fp8: elif fp8 and is_input_fp8:
q_save, kv_save, out_save = q, k, out_f16 q_save, kv_save, out_save = q, kv, out_f16
else: else:
q_f16 = q_f16.view(q.shape) q_f16 = q_f16.view(q.shape)
q_save, kv_save, out_save = q_f16, kv, out_f16 q_save, kv_save, out_save = q_f16, kv, out_f16
...@@ -2737,7 +2757,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2737,7 +2757,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
ctx.dO_quantizer = dO_quantizer ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer ctx.dP_quantizer = dP_quantizer
ctx.qkv_dtype = qkv_dtype
ctx.cp_group_a2a = cp_group_a2a ctx.cp_group_a2a = cp_group_a2a
ctx.cp_size_a2a = cp_size_a2a ctx.cp_size_a2a = cp_size_a2a
...@@ -2778,10 +2797,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2778,10 +2797,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
saved_tensors = ctx.saved_tensors
q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
restore_from_saved(ctx.tensor_objects, saved_tensors) restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
) )
cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_q_per_step = other_tensors[:cp_size]
cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
...@@ -2843,17 +2860,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2843,17 +2860,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dout_dtype = dout.dtype dout_dtype = dout.dtype
fused_attn_backend = None fused_attn_backend = None
fused_attn_dqkv_dtype = None fused_attn_dqkv_dtype = None
amax_per_step = None
dP_quantizer_per_step = [None for _ in range(cp_size)]
dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)]
if ctx.fp8: if ctx.fp8:
if ctx.use_fused_attention: if ctx.use_fused_attention:
fused_attn_backend = FusedAttnBackend["FP8"] fused_attn_backend = FusedAttnBackend["FP8"]
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) dqkv_fp8_torch_dtype = get_fp8_torch_dtype(
dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) ctx.fp8_meta["recipe"], fprop_tensor=False
)
dq_fp8 = torch.empty(
(cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device
)
dkv_fp8 = torch.empty(
(cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device
)
dkv_fp8_ = torch.empty_like(dkv_fp8) dkv_fp8_ = torch.empty_like(dkv_fp8)
if ctx.is_output_fp8: if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
fused_attn_dqkv_dtype = dout._fp8_dtype ctx.dO_quantizer = dout._quantizer
dout = dout._data
else: else:
dout = ctx.dO_quantizer(dout) dout = ctx.dO_quantizer(dout)
fused_attn_dqkv_dtype = dout._fp8_dtype fused_attn_dqkv_dtype = dout._fp8_dtype
...@@ -2861,21 +2887,32 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2861,21 +2887,32 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
fp8_meta_kwargs = {} fp8_meta_kwargs = {}
fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_CP_quantizer for i in range(cp_size):
dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
dP_quantizer_per_step[i].amax = amax_per_step[0][i]
dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy()
dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
else: else:
assert False, "FP8 is only supported with Fused Attention!" assert False, "FP8 is only supported with Fused Attention!"
else: else:
if ctx.fp8_meta is not None and ctx.is_input_fp8: if ctx.fp8_meta is not None:
if ctx.is_input_fp8:
q = ctx.QKV_quantizer.create_tensor_from_data( q = ctx.QKV_quantizer.create_tensor_from_data(
q, fake_dtype=ctx.qkv_dtype, internal=True q, fake_dtype=ctx.qkv_dtype, internal=True
) )
kv = ctx.QKV_quantizer.create_tensor_from_data( kv = ctx.QKV_quantizer.create_tensor_from_data(
kv, fake_dtype=ctx.qkv_dtype, internal=True kv, fake_dtype=ctx.qkv_dtype, internal=True
) )
q, kv = q.dequantize(), kv.dequantize() q = q.dequantize(dtype=ctx.qkv_dtype)
kv = kv.dequantize(dtype=ctx.qkv_dtype)
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
if cp_size_a2a == 1: if cp_size_a2a == 1:
dout = dout.dequantize() dout = dout.dequantize(dtype=dout_dtype)
else:
ctx.dO_quantizer = dout._quantizer
dout = dout._data
dq = torch.empty_like(q) dq = torch.empty_like(q)
p2p_comm_buffers = [ p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
...@@ -2902,9 +2939,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2902,9 +2939,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
True, True,
) )
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
dout = ctx.dO_quantizer.create_tensor_from_data(data=dout, internal=True) dout = ctx.dO_quantizer.create_tensor_from_data(
dout = dout.dequantize() dout, fake_dtype=dout_dtype, internal=True
dout = dout._data )
dout = dout.dequantize(dtype=dout_dtype)
out = out.view(*q.shape) out = out.view(*q.shape)
dout = dout.view(*q.shape) dout = dout.view(*q.shape)
...@@ -3020,8 +3058,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3020,8 +3058,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True out_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
dout_part = ctx.dO_quantizer.create_tensor_from_data( dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True dout_part, fake_dtype=dout_dtype, internal=True
) )
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
...@@ -3133,8 +3173,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3133,8 +3173,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True out_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
dout_part = ctx.dO_quantizer.create_tensor_from_data( dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True dout_part, fake_dtype=dout_dtype, internal=True
) )
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2, ctx.max_seqlen_kv // 2,
...@@ -3250,8 +3292,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3250,8 +3292,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True out_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
dout_part = ctx.dO_quantizer.create_tensor_from_data( dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True dout_part, fake_dtype=dout_dtype, internal=True
) )
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q // 2, ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
...@@ -3282,7 +3326,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3282,7 +3326,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dq_ = dq_._data dq_ = dq_._data
dk_ = dk_._data dk_ = dk_._data
dv_ = dv_._data dv_ = dv_._data
else: else:
dq_ = torch.empty_like(q_) dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_) dkv_ = torch.empty_like(kv_)
...@@ -3333,20 +3376,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3333,20 +3376,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.fp8: if ctx.fp8:
q_part = ctx.QKV_quantizer.create_tensor_from_data( q_part = ctx.QKV_quantizer.create_tensor_from_data(
q_part, fake_dtype=ctx.qkv_dtype q_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
k_part = ctx.QKV_quantizer.create_tensor_from_data( k_part = ctx.QKV_quantizer.create_tensor_from_data(
k_part, fake_dtype=ctx.qkv_dtype k_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
v_part = ctx.QKV_quantizer.create_tensor_from_data( v_part = ctx.QKV_quantizer.create_tensor_from_data(
v_part, fake_dtype=ctx.qkv_dtype v_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
out_part = ctx.O_quantizer.create_tensor_from_data( out_part = ctx.O_quantizer.create_tensor_from_data(
out_part, fake_dtype=ctx.qkv_dtype out_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
dout_part = ctx.dO_quantizer.create_tensor_from_data( dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype dout_part, fake_dtype=dout_dtype, internal=True
) )
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd( dq_, dk_, dv_, dbias_ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
...@@ -3555,13 +3600,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3555,13 +3600,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv.add_(dkv_) dkv.add_(dkv_)
if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.dP_quantizer.amax = amax_cp_bwd[0]
ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1]
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.qkv_format in ["bshd", "sbhd"]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq_fp8) dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dkv = ctx.dQKV_quantizer.create_tensor_from_data(dkv_fp8) dq_fp8, fake_dtype=torch.float32, internal=True
dq, dkv = [x.dequantize() for x in [dq, dkv]] )
dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dkv_fp8, fake_dtype=torch.float32, internal=True
)
dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]]
dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]
if causal: if causal:
...@@ -3606,9 +3658,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -3606,9 +3658,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
# converting torch.uint8 to float8tensor # converting torch.uint8 to float8tensor
if ctx.fp8 and ctx.is_input_fp8: if ctx.fp8 and ctx.is_input_fp8:
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype) dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype)
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype) dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype)
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype) dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
return ( return (
...@@ -4227,21 +4279,20 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4227,21 +4279,20 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False is_input_fp8 = False
is_output_fp8 = False is_output_fp8 = False
if fp8:
is_output_fp8 = fp8_meta["recipe"].fp8_mha
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
) )
if fp8: if fp8:
if use_fused_attention: if use_fused_attention:
fused_attn_backend = FusedAttnBackend["FP8"] fused_attn_backend = FusedAttnBackend["FP8"]
assert isinstance(k, q.__class__) and isinstance( assert isinstance(k, q.__class__) and isinstance(
v, q.__class__ v, q.__class__
), "q, k, and v must have the same type." ), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor) is_input_fp8 = isinstance(q, Float8Tensor)
is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
if is_input_fp8: if is_input_fp8:
QKV_quantizer = q._quantizer
q_fp8, k_fp8, v_fp8 = q, k, v q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
...@@ -4350,31 +4401,24 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4350,31 +4401,24 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out = out_fp8._data out = out_fp8._data
else: else:
out_fp8 = O_quantizer.create_tensor_from_data( out_fp8 = O_quantizer.create_tensor_from_data(
out, fake_dtype=qkv_dtype, internal=False out, fake_dtype=qkv_dtype, internal=True
) )
out_f16 = out_fp8.dequantize() out_f16 = out_fp8.dequantize(dtype=qkv_dtype)
out_ret = out_f16 out_ret = out_f16
else: else:
out_ret = out out_ret = out
if fp8: if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, k_save, v_save, out_save = q, k, v, out q_save, k_save, v_save, out_save = q, k, v, out
elif is_input_fp8:
q_fp8 = QKV_quantizer.create_tensor_from_data(
q, fake_dtype=qkv_dtype, internal=False
)
k_fp8 = QKV_quantizer.create_tensor_from_data(
k, fake_dtype=qkv_dtype, internal=False
)
v_fp8 = QKV_quantizer.create_tensor_from_data(
v, fake_dtype=qkv_dtype, internal=False
)
q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out
else: else:
q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 if is_input_fp8:
q_save, k_save, v_save = q, k, v
else: else:
q_save, k_save, v_save, out_save = q, k, v, out q_save, k_save, v_save = q_f16, k_f16, v_f16
if is_output_fp8:
out_save = out
else:
out_save = out_f16
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
q_save, q_save,
...@@ -4397,7 +4441,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4397,7 +4441,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.dQKV_quantizer = dQKV_quantizer ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer ctx.dP_quantizer = dP_quantizer
ctx.qkv_dtype = qkv_dtype
ctx.batch_size = batch_size ctx.batch_size = batch_size
ctx.cp_group = cp_group ctx.cp_group = cp_group
...@@ -4436,27 +4479,24 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4436,27 +4479,24 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
*aux_ctx_tensors, *aux_ctx_tensors,
) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
dout_dtype = dout.dtype
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
causal = "causal" in ctx.attn_mask_type causal = "causal" in ctx.attn_mask_type
seq_dim = ctx.qkv_format.index("s") seq_dim = ctx.qkv_format.index("s")
dout_dtype = dout.dtype
fused_attn_backend = None fused_attn_backend = None
fused_attn_dqkv_dtype = None fused_attn_dqkv_dtype = None
if ctx.fp8: if ctx.fp8:
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fused_attn_dqkv_dtype = fp8_dtype_backward
if ctx.use_fused_attention: if ctx.use_fused_attention:
fused_attn_backend = FusedAttnBackend["FP8"] fused_attn_backend = FusedAttnBackend["FP8"]
if ctx.is_output_fp8: if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
dout_fp8 = dout ctx.dO_quantizer = dout._quantizer
dout = dout_fp8._data
else: else:
dout_f16 = dout dout = ctx.dO_quantizer(dout)
dout = ctx.dO_quantizer(dout_f16)._data fused_attn_dqkv_dtype = dout._fp8_dtype
dout = dout._data
fp8_meta_kwargs = {} fp8_meta_kwargs = {}
fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer
...@@ -4465,12 +4505,25 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4465,12 +4505,25 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else: else:
assert False, "FP8 is only supported with Fused Attention!" assert False, "FP8 is only supported with Fused Attention!"
else: else:
if ctx.fp8_meta is not None and ctx.is_output_fp8: if ctx.fp8_meta is not None:
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
q, k, v, out, dout = [x.dequantize() for x in [q, k, v, out, dout]] ctx.dO_quantizer = dout._quantizer
dout = dout._data
if ctx.is_input_fp8:
q = ctx.QKV_quantizer.create_tensor_from_data(
q, fake_dtype=ctx.qkv_dtype, internal=True
)
k = ctx.QKV_quantizer.create_tensor_from_data(
k, fake_dtype=ctx.qkv_dtype, internal=True
)
v = ctx.QKV_quantizer.create_tensor_from_data(
v, fake_dtype=ctx.qkv_dtype, internal=True
)
q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]]
if ctx.use_fused_attention: if ctx.use_fused_attention:
fp8_meta_kwargs = {} fp8_meta_kwargs = {}
fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_dqkv_dtype = TE_DType[dout_dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if not ctx.use_fused_attention: if not ctx.use_fused_attention:
...@@ -4481,6 +4534,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4481,6 +4534,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out, dout = flash_attn_a2a_communicate( out, dout = flash_attn_a2a_communicate(
[out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
) )
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
out = ctx.O_quantizer.create_tensor_from_data(
out, fake_dtype=ctx.qkv_dtype, internal=True
)
dout = ctx.dO_quantizer.create_tensor_from_data(
dout, fake_dtype=dout_dtype, internal=True
)
out = out.dequantize(dtype=ctx.qkv_dtype)
dout = dout.dequantize(dtype=dout_dtype)
flash_attn_bwd = None flash_attn_bwd = None
if not ctx.use_fused_attention: if not ctx.use_fused_attention:
...@@ -4531,7 +4593,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4531,7 +4593,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out_part, fake_dtype=ctx.qkv_dtype, internal=True out_part, fake_dtype=ctx.qkv_dtype, internal=True
) )
dout_part = ctx.dO_quantizer.create_tensor_from_data( dout_part = ctx.dO_quantizer.create_tensor_from_data(
dout_part, fake_dtype=ctx.qkv_dtype, internal=True dout_part, fake_dtype=dout_dtype, internal=True
) )
dq, dk, dv, _ = fused_attn_bwd( dq, dk, dv, _ = fused_attn_bwd(
...@@ -4602,11 +4664,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): ...@@ -4602,11 +4664,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
if ctx.fp8: if ctx.fp8:
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) dq = ctx.dQKV_quantizer.create_tensor_from_data(
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) )
dk = ctx.dQKV_quantizer.create_tensor_from_data(
dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
)
dv = ctx.dQKV_quantizer.create_tensor_from_data(
dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
)
if not ctx.is_input_fp8: if not ctx.is_input_fp8:
dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]] dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]]
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
return ( return (
......
...@@ -56,7 +56,7 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch. ...@@ -56,7 +56,7 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
): ):
return torch.float8_e4m3fn return torch.float8_e4m3fn
return torch.float8_e5m2fn return torch.float8_e5m2
def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment