Unverified commit 88649838, authored by Shijie and committed by GitHub

[Paddle] Fix issues (#515)



* fix cudnn FA softmax shape
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* set inplace rng_state
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent ea43b18e
@@ -472,7 +472,12 @@ def fused_attn_fwd_qkvpacked(
     out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype)
     if is_training:
-        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
+        if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
+            softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
+        elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+            softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype='float32')
+        else:
+            raise ValueError("Unsupported fused attention backend.")
     else:
         softmax_aux = None
@@ -631,7 +636,12 @@ def fused_attn_fwd_kvpacked(
     out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)
     if is_training:
-        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
+        if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
+            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
+        elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
+        else:
+            raise ValueError("Unsupported fused attention backend.")
     else:
         softmax_aux = None
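The two hunks above change what `softmax_aux` holds per backend: the max-512 backend keeps the full [b, h, seqlen_q, seqlen_kv] softmax matrix in the QKV dtype, while the arbitrary-seqlen (flash-style) backend only needs per-row softmax statistics of shape [b, h, seqlen_q, 1] in float32. The sketch below is not part of the patch; it mirrors that allocation logic with hypothetical helper names (softmax_aux_shape, nbytes) purely to compare the auxiliary-buffer sizes of the two backends.

# Illustrative sketch only -- mirrors the allocation logic in the hunks above.
# Helper names are hypothetical; backend keys follow FusedAttnBackend in the diff.
def softmax_aux_shape(backend, b, h, max_seqlen, qkv_dtype="bfloat16"):
    """Return (shape, dtype) of the softmax auxiliary tensor for a backend."""
    if backend == "F16_max512_seqlen":
        # Full softmax matrix, stored in the QKV dtype.
        return [b, h, max_seqlen, max_seqlen], qkv_dtype
    if backend == "F16_arbitrary_seqlen":
        # Flash-style backend keeps only per-row softmax statistics, in float32.
        return [b, h, max_seqlen, 1], "float32"
    raise ValueError("Unsupported fused attention backend.")


def nbytes(shape, dtype):
    """Rough buffer size in bytes for the given shape and dtype."""
    itemsize = {"bfloat16": 2, "float16": 2, "float32": 4}[dtype]
    size = itemsize
    for dim in shape:
        size *= dim
    return size


if __name__ == "__main__":
    b, h = 2, 16
    for backend, seqlen in (("F16_max512_seqlen", 512), ("F16_arbitrary_seqlen", 2048)):
        shape, dtype = softmax_aux_shape(backend, b, h, seqlen)
        print(f"{backend}: shape={shape} dtype={dtype} "
              f"size={nbytes(shape, dtype) / 2**20:.2f} MiB")
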
@@ -1271,14 +1271,15 @@ PD_BUILD_OP(te_rmsnorm_bwd)
 PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
     .Inputs({"QKV", "cu_seqlens", paddle::Optional("Bias"), "_O", paddle::Optional("_softmax_aux"),
-             "rng_state"})
-    .Outputs({"O", paddle::Optional("softmax_aux")})
+             "_rng_state"})
+    .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
     .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
             "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
             "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
             "rng_elts_per_thread: int64_t"})
     .SetInplaceMap({{"_O", "O"},
-                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}})
+                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
+                    {"_rng_state", "rng_state"}})
     .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked));

 PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
@@ -1293,15 +1294,16 @@ PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
 PD_BUILD_OP(te_fused_attn_fwd_kvpacked)
     .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
-             paddle::Optional("_softmax_aux"), "rng_state"})
-    .Outputs({"O", paddle::Optional("softmax_aux")})
+             paddle::Optional("_softmax_aux"), "_rng_state"})
+    .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
     .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
             "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
             "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
             "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
             "rng_elts_per_thread: int64_t"})
     .SetInplaceMap({{"_O", "O"},
-                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}})
+                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
+                    {"_rng_state", "rng_state"}})
     .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked));

 PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
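The C++ hunks above register "_rng_state" as an input aliased to a new "rng_state" output via SetInplaceMap, following the same pattern already used for the "_O"/"O" and "_softmax_aux"/"softmax_aux" pairs. With that mapping in place, the seed/offset the forward kernel writes into the RNG-state tensor is tracked by Paddle and returned to the caller, so the backward pass can regenerate the same dropout mask. The toy sketch below is an assumption-labeled illustration of that seed/offset bookkeeping, not the real kernel logic; the actual state is advanced on the GPU.

# Toy illustration only -- assumes rng_state is a Philox-style (seed, offset)
# pair, as suggested by the rng_elts_per_thread attribute above. The real
# bookkeeping happens inside the CUDA kernel; names here are hypothetical.
def advance_rng_state(rng_state, rng_elts_per_thread, num_threads):
    """Return (state_used_by_this_call, next_state)."""
    seed, offset = rng_state
    used = (seed, offset)
    next_state = (seed, offset + rng_elts_per_thread * num_threads)
    return used, next_state


if __name__ == "__main__":
    state = (1234, 0)
    for step in range(3):
        # The forward op consumes `used` for dropout and must surface the
        # updated state so later calls do not reuse the same random offsets.
        used, state = advance_rng_state(state, rng_elts_per_thread=16, num_threads=1024)
        print(f"step {step}: dropout draws from {used}, next rng_state {state}")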