Unverified commit 88649838, authored by Shijie, committed by GitHub
Browse files

[Paddle] Fix issues (#515)



* fix cudnn FA softmax shape
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* set inplace rng_state
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent ea43b18e
......@@ -472,7 +472,12 @@ def fused_attn_fwd_qkvpacked(
out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype)
if is_training:
    # Shape of the softmax auxiliary buffer depends on the cuDNN backend:
    # the max-512 backend stores the full softmax matrix in the QKV dtype,
    # while the arbitrary-seqlen backend stores per-row stats in float32.
    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype='float32')
    else:
        raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
......@@ -631,7 +636,12 @@ def fused_attn_fwd_kvpacked(
out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)
if is_training:
    # Same backend-dependent softmax_aux shapes as the qkvpacked variant,
    # but with separate query/key-value sequence lengths.
    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
    else:
        raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
......
......@@ -1271,14 +1271,15 @@ PD_BUILD_OP(te_rmsnorm_bwd)
PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
    .Inputs({"QKV", "cu_seqlens", paddle::Optional("Bias"), "_O", paddle::Optional("_softmax_aux"),
             "_rng_state"})
    .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
            "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
            "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
            "rng_elts_per_thread: int64_t"})
    .SetInplaceMap({{"_O", "O"},
                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
                    {"_rng_state", "rng_state"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked));
PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
......@@ -1293,15 +1294,16 @@ PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
PD_BUILD_OP(te_fused_attn_fwd_kvpacked)
    .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
             paddle::Optional("_softmax_aux"), "_rng_state"})
    .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
            "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
            "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
            "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
            "rng_elts_per_thread: int64_t"})
    .SetInplaceMap({{"_O", "O"},
                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
                    {"_rng_state", "rng_state"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked));
PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.