Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
81e00bce
Unverified
Commit
81e00bce
authored
Jan 31, 2025
by
Max Podkorytov
Browse files
re-add group-mode kernels
parent
c657f72b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
8 deletions
+5
-8
example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py
example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py
+2
-5
example/ck_tile/18_flexattn/fmha_fwd.cpp
example/ck_tile/18_flexattn/fmha_fwd.cpp
+1
-1
example/ck_tile/18_flexattn/fmha_fwd.hpp
example/ck_tile/18_flexattn/fmha_fwd.hpp
+1
-1
example/ck_tile/18_flexattn/script/smoke_test_fwd.sh
example/ck_tile/18_flexattn/script/smoke_test_fwd.sh
+1
-1
No files found.
example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py
View file @
81e00bce
...
@@ -136,7 +136,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
...
@@ -136,7 +136,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
}}
"""
"""
FMHA_FWD_API_INNER_DISPATCH
=
""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) &&
FMHA_FWD_API_INNER_DISPATCH
=
""" {F_if}(
(t.is_group_mode == {F_mode}) &&
(t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
return fmha_fwd_<trait_>(s, a);
...
@@ -284,8 +284,6 @@ class FmhaFwdApiPool:
...
@@ -284,8 +284,6 @@ class FmhaFwdApiPool:
inners
=
str
()
inners
=
str
()
first_k
=
True
first_k
=
True
for
k
,
trait
in
enumerate
(
traits
):
for
k
,
trait
in
enumerate
(
traits
):
if
trait
.
mode
!=
"batch"
:
continue
if
trait
.
dropout
==
"t"
:
if
trait
.
dropout
==
"t"
:
continue
continue
if
trait
.
lse
==
"t"
:
if
trait
.
lse
==
"t"
:
...
@@ -501,8 +499,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e
...
@@ -501,8 +499,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e
if
d
==
None
:
if
d
==
None
:
continue
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
# for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
[
"batch"
]):
tile
=
d
[
hdim_str
]
tile
=
d
[
hdim_str
]
hdim
=
int
(
hdim_str
)
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
...
...
example/ck_tile/18_flexattn/fmha_fwd.cpp
View file @
81e00bce
...
@@ -849,7 +849,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -849,7 +849,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
// fmha_fwd_traits or fmha_splitkv_traits
else
// fmha_fwd_traits or fmha_splitkv_traits
{
{
//
traits.is_group_mode = (mode == mode_enum::group);
traits
.
is_group_mode
=
(
mode
==
mode_enum
::
group
);
traits
.
mask_type
=
mask
.
type
;
traits
.
mask_type
=
mask
.
type
;
traits
.
bias_type
=
bias
.
type
;
traits
.
bias_type
=
bias
.
type
;
// traits.has_lse = lse;
// traits.has_lse = lse;
...
...
example/ck_tile/18_flexattn/fmha_fwd.hpp
View file @
81e00bce
...
@@ -782,7 +782,7 @@ struct fmha_fwd_traits
...
@@ -782,7 +782,7 @@ struct fmha_fwd_traits
int
hdim_q
;
int
hdim_q
;
int
hdim_v
;
int
hdim_v
;
std
::
string
data_type
;
std
::
string
data_type
;
//
bool is_group_mode;
bool
is_group_mode
;
bool
is_v_rowmajor
;
bool
is_v_rowmajor
;
mask_enum
mask_type
;
mask_enum
mask_type
;
bias_enum
bias_type
;
// 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bias_enum
bias_type
;
// 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
...
...
example/ck_tile/18_flexattn/script/smoke_test_fwd.sh
View file @
81e00bce
...
@@ -40,7 +40,7 @@ run_fp16_bf16_tests() {
...
@@ -40,7 +40,7 @@ run_fp16_bf16_tests() {
fi
fi
for
prec
in
"bf16"
;
do
for
prec
in
"bf16"
;
do
for
mode
in
0
;
do
for
mode
in
0
1
;
do
for
perm
in
0
;
do
for
perm
in
0
;
do
for
vlayout
in
"r"
;
do
for
vlayout
in
"r"
;
do
for
hdim
in
32 64 128 256
;
do
for
hdim
in
32 64 128 256
;
do
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment