"...git@developer.sourcefind.cn:mkm3sc7k5up7/vllm_017.git" did not exist on "fbeb8a6f13f8d47a7a9af6acb968a5cd3810cb24"
Unverified commit e14d1472 authored by Xiaowei Ren, committed by GitHub

Fix issues in fused_attn_bwd (#1574)



* fix dtypes of fused_attn_bwd in CP+A2A
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix dtypes of fused_attn_bwd in CP+P2P
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix amax_per_step
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* clone scaling factors of fwd quantizers
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix fwd quantizers of CP+P2P
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* dequantize fp8 out in CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* delete redundant None in FusedAttnFunc bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 86813893
@@ -286,6 +286,12 @@ def run_dpa_with_cp(
     else:
         out_.backward(dout_)
+    if fp8_mha:
+        assert isinstance(out, Float8Tensor)
+        assert isinstance(out_, Float8Tensor)
+        out = out.dequantize()
+        out_ = out_.dequantize()
     for x in [out_, q_.grad, k_.grad, v_.grad]:
         assert torch.all(~torch.isnan(x))
         assert torch.all(~torch.isinf(x))
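
The test-side change above dequantizes the FP8 attention output before the NaN/Inf checks and the later comparison against the reference run. A minimal sketch of the same pattern, assuming a Float8Tensor-like wrapper that exposes `dequantize()` (the helper below is illustrative, not part of the test):

```python
import torch

def to_plain_tensors(outputs):
    """Dequantize any FP8 wrappers so the checks run on ordinary torch.Tensors."""
    plain = []
    for x in outputs:
        if hasattr(x, "dequantize"):  # e.g. a Float8Tensor holding raw FP8 bytes
            x = x.dequantize()
        plain.append(x)
    for x in plain:
        assert torch.all(torch.isfinite(x)), "found NaN/Inf in output or gradient"
    return plain
```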
@@ -688,9 +688,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 # partial result quantizer
                 for i in range(cp_size):
                     S_quantizer_per_step[i] = S_quantizer.copy()
-                    S_quantizer_per_step[i].amax = amax_per_step[0][i]
+                    S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
                     O_CP_quantizer_per_step[i] = O_CP_quantizer.copy()
-                    O_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
+                    O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
             else:
                 assert False, "FP8 is only supported with Fused Attention!"
         else:
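
The `.reshape((1,))` matters because indexing `amax_per_step[0][i]` yields a 0-d tensor, while each per-step quantizer expects a one-element 1-D amax buffer; the reshape keeps the view aliasing the shared `(2, cp_size)` buffer. A small standalone sketch (values illustrative):

```python
import torch

cp_size = 4
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32)

scalar_view = amax_per_step[0][0]               # 0-d view, shape torch.Size([])
amax_view = amax_per_step[0][0].reshape((1,))   # 1-element view, shape torch.Size([1])
assert scalar_view.dim() == 0 and amax_view.shape == (1,)

# The reshaped view still aliases the shared buffer, so per-step writes are
# visible when the amaxes are reduced after the loop.
amax_view.fill_(3.0)
assert amax_per_step[0][0].item() == 3.0
```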
@@ -1477,8 +1477,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         if fp8 and use_fused_attention:
             amax_cp_fwd = amax_per_step.amax(dim=1)
-            S_quantizer.amax = amax_cp_fwd[0]
-            O_CP_quantizer.amax = amax_cp_fwd[1]
+            S_quantizer.amax.copy_(amax_cp_fwd[0])
+            O_CP_quantizer.amax.copy_(amax_cp_fwd[1])
         out_fp8 = None
         out_f16 = out.to(qkv_dtype)
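
Using `copy_` instead of assignment writes the reduced amax into the existing buffer rather than rebinding the attribute to a fresh tensor, so anything else that aliases the quantizer's amax (for example the FP8 metadata used for scale updates) sees the new value. A toy illustration (the `ToyQuantizer` class is not the Transformer Engine API):

```python
import torch

class ToyQuantizer:
    def __init__(self, amax):
        self.amax = amax  # assumed to alias a buffer shared with other bookkeeping

shared_amax = torch.zeros(1)
q = ToyQuantizer(shared_amax)

q.amax = torch.tensor([2.0])       # rebinding: shared_amax is left untouched
assert shared_amax.item() == 0.0

q.amax = shared_amax
q.amax.copy_(torch.tensor([3.0]))  # in-place copy: every alias observes the update
assert shared_amax.item() == 3.0
```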
@@ -1511,16 +1511,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects
-        ctx.qkv_dtype = qkv_dtype
-        ctx.QKV_quantizer = QKV_quantizer
-        ctx.O_quantizer = O_quantizer
-        ctx.O_CP_quantizer = O_CP_quantizer
-        ctx.S_quantizer = S_quantizer
-        ctx.dQKV_quantizer = dQKV_quantizer
-        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
-        ctx.dO_quantizer = dO_quantizer
-        ctx.dP_quantizer = dP_quantizer
         ctx.cp_group_a2a = cp_group_a2a
         ctx.cp_size_a2a = cp_size_a2a
         ctx.rank_a2a = rank_a2a
@@ -1544,6 +1534,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
         ctx.use_flash_attn_3 = use_flash_attn_3
+        ctx.qkv_dtype = qkv_dtype
+        ctx.dQKV_quantizer = dQKV_quantizer
+        ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
+        ctx.dO_quantizer = dO_quantizer
+        ctx.dP_quantizer = dP_quantizer
+        ctx.QKV_quantizer = QKV_quantizer
+        ctx.O_quantizer = O_quantizer
+        ctx.S_quantizer = S_quantizer
+        if ctx.fp8:
+            ctx.QKV_quantizer = QKV_quantizer.copy()
+            ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
+            ctx.O_quantizer = O_quantizer.copy()
+            ctx.O_quantizer.scale = O_quantizer.scale.clone()
+            ctx.S_quantizer = S_quantizer.copy()
+            ctx.S_quantizer.scale = S_quantizer.scale.clone()
         nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
         return out_ret
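
Stashing copies of the forward quantizers with cloned `scale` tensors snapshots the scaling factors that were actually used in forward, so a later in-place update of the live quantizer cannot change what backward dequantizes with. A rough sketch of the failure mode being avoided (toy class, not the TE quantizer):

```python
import copy
import torch

class ToyQuantizer:
    def __init__(self, scale):
        self.scale = scale
    def copy(self):
        return copy.copy(self)  # shallow copy: tensor attributes still alias

fwd_q = ToyQuantizer(torch.ones(1))

stashed = fwd_q.copy()
stashed.scale = fwd_q.scale.clone()   # snapshot the value used in forward

fwd_q.scale.mul_(4.0)                 # e.g. the next iteration's scale update
assert stashed.scale.item() == 1.0    # backward still sees the forward-time scale
```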
@@ -1630,32 +1636,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             if ctx.use_fused_attention:
                 fused_attn_backend = FusedAttnBackend["FP8"]
-                dqkv_fp8_torch_dtype = get_fp8_torch_dtype(
-                    ctx.fp8_meta["recipe"], fprop_tensor=False
-                )
-                dq_fp8 = torch.empty(
-                    (cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device
-                )
-                dkv_fp8 = torch.empty(
-                    (cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device
-                )
-                dkv_fp8_ = torch.empty_like(dkv_fp8)
                 if ctx.is_output_fp8:
                     assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
                     ctx.dO_quantizer = dout._quantizer
                 else:
                     dout = ctx.dO_quantizer(dout)
-                fused_attn_dqkv_dtype = dout._fp8_dtype
-                dout = dout._data
+                fused_attn_dqkv_dtype = TE_DType[dout._data.dtype]
+                dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device)
+                dkv_fp8 = torch.empty(
+                    (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device
+                )
+                dkv_fp8_ = torch.empty_like(dkv_fp8)
                 p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
+                dout = dout._data
                 fp8_meta_kwargs = {}
                 fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
                 amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
                 for i in range(cp_size):
                     dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
-                    dP_quantizer_per_step[i].amax = amax_per_step[0][i]
+                    dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
                     dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy()
-                    dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i]
+                    dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
             else:
                 assert False, "FP8 is only supported with Fused Attention!"
         else:
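
After this change the per-step gradient buffers take their dtype directly from `dout`'s FP8 payload (`dout._data.dtype`) instead of re-deriving it from the recipe, and `fused_attn_dqkv_dtype` is the matching TE enum via `TE_DType`. A minimal sketch of the allocation pattern (shapes and the `torch.uint8` stand-in for the FP8 payload dtype are illustrative):

```python
import torch

cp_size = 2
q = torch.randn(8, 16)
dout_data = torch.zeros(8, 16, dtype=torch.uint8)  # stand-in for dout._data

# One slice per context-parallel step, carrying raw bytes of the same dtype as dout.
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout_data.dtype, device=q.device)
dkv_fp8 = torch.empty((cp_size, 2, *q.shape), dtype=dout_data.dtype, device=q.device)
dkv_fp8_ = torch.empty_like(dkv_fp8)
assert dq_fp8.dtype == dkv_fp8.dtype == dout_data.dtype
```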
@@ -1836,7 +1837,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     v_part,
                     out_part,
                     dout_part,
-                    ctx.qkv_dtype,
+                    dout_dtype,
                     fused_attn_dqkv_dtype,
                     aux_ctx_tensors,
                     fused_attn_backend,
@@ -1960,7 +1961,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     v_part,
                     out_part,
                     dout_part,
-                    ctx.qkv_dtype,
+                    dout_dtype,
                     fused_attn_dqkv_dtype,
                     aux_ctx_tensors,
                     fused_attn_backend,
@@ -2088,7 +2089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     v_part,
                     out_part,
                     dout_part,
-                    ctx.qkv_dtype,
+                    dout_dtype,
                     fused_attn_dqkv_dtype,
                     aux_ctx_tensors,
                     fused_attn_backend,
@@ -2193,7 +2194,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     v_part,
                     out_part,
                     dout_part,
-                    ctx.qkv_dtype,
+                    dout_dtype,
                     fused_attn_dqkv_dtype,
                     aux_ctx_tensors,
                     fused_attn_backend,
@@ -2393,8 +2394,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         if ctx.fp8 and ctx.use_fused_attention:
             amax_cp_bwd = amax_per_step.amax(dim=1)
-            ctx.dP_quantizer.amax = amax_cp_bwd[0]
-            ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1]
+            ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
+            ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1])
         if ctx.qkv_format in ["bshd", "sbhd"]:
             # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
             # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
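
For reference, the reduction feeding these `copy_` calls: `amax_per_step` holds one column per CP step, row 0 for dP and row 1 for dQKV, and `amax(dim=1)` takes the maximum across steps before it is copied into the shared quantizer buffers in place. A small sketch (shapes illustrative):

```python
import torch

cp_size = 4
amax_per_step = torch.rand(2, cp_size)      # row 0: dP amaxes, row 1: dQKV amaxes

amax_cp_bwd = amax_per_step.amax(dim=1)     # shape (2,): max over all CP steps
dP_amax = torch.zeros(1)                    # stand-in for ctx.dP_quantizer.amax
dP_amax.copy_(amax_cp_bwd[0])               # in-place, so aliases stay in sync

assert dP_amax.item() == amax_per_step[0].max().item()
```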
@@ -3227,14 +3228,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects
-        ctx.qkv_dtype = qkv_dtype
-        ctx.QKV_quantizer = QKV_quantizer
-        ctx.O_quantizer = O_quantizer
-        ctx.S_quantizer = S_quantizer
-        ctx.dQKV_quantizer = dQKV_quantizer
-        ctx.dO_quantizer = dO_quantizer
-        ctx.dP_quantizer = dP_quantizer
         ctx.batch_size = batch_size
         ctx.cp_group = cp_group
         ctx.cp_stream = cp_stream
@@ -3253,6 +3246,21 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
         ctx.use_flash_attn_3 = use_flash_attn_3
+        ctx.qkv_dtype = qkv_dtype
+        ctx.dQKV_quantizer = dQKV_quantizer
+        ctx.dO_quantizer = dO_quantizer
+        ctx.dP_quantizer = dP_quantizer
+        ctx.QKV_quantizer = QKV_quantizer
+        ctx.O_quantizer = O_quantizer
+        ctx.S_quantizer = S_quantizer
+        if ctx.fp8:
+            ctx.QKV_quantizer = QKV_quantizer.copy()
+            ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
+            ctx.O_quantizer = O_quantizer.copy()
+            ctx.O_quantizer.scale = O_quantizer.scale.clone()
+            ctx.S_quantizer = S_quantizer.copy()
+            ctx.S_quantizer.scale = S_quantizer.scale.clone()
         nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
         return out_ret
@@ -3289,7 +3297,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                 ctx.dO_quantizer = dout._quantizer
             else:
                 dout = ctx.dO_quantizer(dout)
-            fused_attn_dqkv_dtype = dout._fp8_dtype
+            fused_attn_dqkv_dtype = TE_DType[dout._data.dtype]
             dout = dout._data
             fp8_meta_kwargs = {}
             fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer
@@ -3399,7 +3407,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                     v_part,
                     out_part,
                     dout_part,
-                    ctx.qkv_dtype,
+                    dout_dtype,
                     fused_attn_dqkv_dtype,
                     aux_ctx_tensors,
                     fused_attn_backend,
@@ -4746,6 +4754,9 @@ class FusedAttnFunc(torch.autograd.Function):
         ctx.dO_quantizer = dO_quantizer
         ctx.dP_quantizer = dP_quantizer
         ctx.S_quantizer = S_quantizer
+        if ctx.fp8:
+            ctx.S_quantizer = S_quantizer.copy()
+            ctx.S_quantizer.scale = S_quantizer.scale.clone()
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -4961,8 +4972,6 @@
                 None,
                 None,
                 None,
-                None,
-                None,
             )
         # else, return (dqkv, dbias)
         return (
@@ -4993,8 +5002,6 @@
             None,
             None,
             None,
-            None,
-            None,
         )
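
The trailing `None`s removed in these two hunks were redundant entries in the returned tuple: `torch.autograd.Function.backward` returns one value per argument of `forward`, with `None` for inputs that need no gradient. A toy example of the convention (not the FusedAttnFunc signature):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, unused_flag):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # Exactly three values: grad for x, None for factor, None for the flag.
        return grad_out * ctx.factor, None, None

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0, True).sum().backward()
assert torch.allclose(x.grad, torch.full((4,), 2.0))
```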