"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "262c184eb8331dcb50037477881900e46bd5c5f2"
Unverified commit 9437ceb2, authored by Xiaowei Ren, committed by GitHub
Browse files

Fix QKV dtype in the bwd of FP8+CP (#1134)



* fix qkv_dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* config cp correction dtype of FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* code style change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* always do FP8 CP correction in FP32
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent aecd5a8f
......@@ -2261,8 +2261,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
-fused_attn_qkv_dtype = fp8_dtype_backward
+fused_attn_qkv_dtype = fp8_dtype_forward
+fused_attn_dqkv_dtype = fp8_dtype_backward
fused_attn_backend = FusedAttnBackend["FP8"]
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
......@@ -2304,7 +2305,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
-fused_attn_dqkv_dtype = TE_DType[q.dtype]
+fused_attn_dqkv_dtype = TE_DType[dout.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
out = out.view(*q.shape)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment