"vscode:/vscode.git/clone" did not exist on "8ef8285c40542c8c3724f9b3eadbb006793958f0"
Unverified commit 9437ceb2, authored by Xiaowei Ren, committed by GitHub
Browse files

Fix QKV dtype in the bwd of FP8+CP (#1134)



* fix qkv_dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* config cp correction dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* code style change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* always do FP8 CP correction in FP32
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent aecd5a8f
@@ -2261,8 +2261,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         if ctx.fp8:
             if ctx.use_fused_attention:
+                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                 fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
-                fused_attn_qkv_dtype = fp8_dtype_backward
+                fused_attn_qkv_dtype = fp8_dtype_forward
                 fused_attn_dqkv_dtype = fp8_dtype_backward
                 fused_attn_backend = FusedAttnBackend["FP8"]
                 dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
@@ -2304,7 +2305,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             if ctx.use_fused_attention:
                 fp8_meta_kwargs = {}
                 fused_attn_qkv_dtype = TE_DType[q.dtype]
-                fused_attn_dqkv_dtype = TE_DType[q.dtype]
+                fused_attn_dqkv_dtype = TE_DType[dout.dtype]
                 fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
                 out = out.view(*q.shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment