Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9d772b9a
"python/vscode:/vscode.git/clone" did not exist on "7b9d9fc0527ca9798f0a380ef0443fa550a03f60"
Commit
9d772b9a
authored
Feb 12, 2025
by
Jim
Browse files
fix padding case
parent
14b4d6bb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
4 deletions
+5
-4
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+5
-4
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
9d772b9a
...
...
@@ -628,12 +628,12 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH
=
""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3
_hdp
_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3
{F_padding_suffix}
_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;"""
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH
=
""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3
{F_padding_suffix}
_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;"""
FMHA_BWD_V3_PER_DTYPE_CASE
=
""" {F_if} (t.data_type.compare(
\"
{F_dtype}
\"
) == 0) {{
...
...
@@ -811,10 +811,11 @@ class FmhaBwdApiPool:
if_m
=
'if'
if
m
==
0
else
'else if'
inners
=
str
()
bf16_cvt_tmp
=
0
if
dtype
==
"fp16"
else
bf16_cvt
padding_suffix
=
"_hdp"
if
BWD_V3_PADDING_CHECK_MAP
[
m
]
==
"true"
else
""
if
is_atomic
==
"t"
:
inners
=
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt_tmp
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
inners
=
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt_tmp
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
]
,
F_padding_suffix
=
padding_suffix
)
else
:
inners
=
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt_tmp
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
])
inners
=
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH
.
format
(
F_hdim
=
hdim
,
F_dtype
=
BWD_DTYPE_MAP
[
dtype
],
F_is_causal
=
BOOL_MAP
[
is_causal
],
F_is_atomic32
=
BOOL_MAP
[
is_atomic
],
F_how_v3_bf16_cvt
=
bf16_cvt_tmp
,
F_padding
=
BWD_V3_PADDING_CHECK_MAP
[
m
]
,
F_padding_suffix
=
padding_suffix
)
per_hdim
=
per_hdim
+
FMHA_BWD_V3_PER_HDIM_CASE
.
format
(
F_if
=
if_m
,
F_hdim_expression
=
BWD_V3_HDIM_CASE_MAP
[
m
],
inner_dispatch
=
inners
)
if_l
=
'if'
if
l
==
0
else
'else if'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment