Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
510ff45f
Unverified
Commit
510ff45f
authored
Jan 29, 2025
by
Max Podkorytov
Browse files
unhardcode score_mod and pass it as a cpp expression from codegen
parent
27a2a0a1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
23 deletions
+42
-23
example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py
example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py
+22
-7
example/ck_tile/18_flexattn/generate.py
example/ck_tile/18_flexattn/generate.py
+15
-6
include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp
+5
-10
No files found.
example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py
View file @
510ff45f
...
@@ -84,8 +84,20 @@ using fmha_epilogue_{F_idx} =
...
@@ -84,8 +84,20 @@ using fmha_epilogue_{F_idx} =
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
{F_spad}, {F_dvpad}>>;
struct score_mod_def_{F_idx} {{
using TScore = typename fmha_pipeline_{F_idx}::SaccDataType;
CK_TILE_HOST_DEVICE TScore operator()(TScore s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) const {{
(void) s; (void) h; (void) b; (void) q_idx; (void) v_idx;
return {F_score_mod_expr};
}}
}};
using fmha_kernel_{F_idx} =
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}
, score_mod_def_{F_idx}
>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
...
@@ -337,6 +349,7 @@ class FmhaFwdKernel:
...
@@ -337,6 +349,7 @@ class FmhaFwdKernel:
F_mode
:
str
# value from MODE_MAP
F_mode
:
str
# value from MODE_MAP
F_tile
:
FmhaFwdTileSize
F_tile
:
FmhaFwdTileSize
F_pipeline
:
FmhaFwdPipeline
F_pipeline
:
FmhaFwdPipeline
F_score_mod_expr
:
str
mask_impl
:
str
mask_impl
:
str
@
property
@
property
...
@@ -378,7 +391,8 @@ class FmhaFwdKernel:
...
@@ -378,7 +391,8 @@ class FmhaFwdKernel:
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
self
.
F_pipeline
.
F_mask
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
self
.
F_pipeline
.
F_mask
],
F_mode
=
MODE_MAP
[
self
.
F_mode
],
F_mode
=
MODE_MAP
[
self
.
F_mode
],
F_pipeline
=
PIPELINE_MAP
[
self
.
F_pipeline
.
tag
])
F_pipeline
=
PIPELINE_MAP
[
self
.
F_pipeline
.
tag
],
F_score_mod_expr
=
self
.
F_score_mod_expr
)
@
property
@
property
def
name
(
self
)
->
str
:
def
name
(
self
)
->
str
:
...
@@ -433,7 +447,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
...
@@ -433,7 +447,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
else
:
else
:
return
None
return
None
def
get_fwd_blobs
(
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
Tuple
[
FmhaFwdApiPool
,
List
[
FmhaFwdKernel
]]:
def
get_fwd_blobs
(
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
,
score_mod_expr
:
str
)
->
Tuple
[
FmhaFwdApiPool
,
List
[
FmhaFwdKernel
]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
# support this in future
def
get_pipelines
(
dtype
,
hdim
)
->
List
[
FmhaFwdPipeline
]:
def
get_pipelines
(
dtype
,
hdim
)
->
List
[
FmhaFwdPipeline
]:
...
@@ -502,6 +516,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
...
@@ -502,6 +516,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
F_mode
=
mode
,
F_mode
=
mode
,
F_tile
=
tile
,
F_tile
=
tile
,
F_pipeline
=
pipeline
,
F_pipeline
=
pipeline
,
F_score_mod_expr
=
score_mod_expr
,
mask_impl
=
mask_impl
)
mask_impl
=
mask_impl
)
if
kernel_filter
!=
None
:
if
kernel_filter
!=
None
:
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
...
@@ -524,15 +539,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
...
@@ -524,15 +539,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def
write_fwd_api
(
api_pool
:
FmhaFwdApiPool
,
autogen_dir
:
Path
)
->
None
:
def
write_fwd_api
(
api_pool
:
FmhaFwdApiPool
,
autogen_dir
:
Path
)
->
None
:
(
autogen_dir
/
FMHA_FWD_API_FILENAME
).
write_text
(
api_pool
.
api
)
(
autogen_dir
/
FMHA_FWD_API_FILENAME
).
write_text
(
api_pool
.
api
)
def
write_blobs
(
output_dir
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
def
write_blobs
(
output_dir
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
,
score_mod_expr
)
->
None
:
api_pool
,
kernels
=
get_fwd_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
api_pool
,
kernels
=
get_fwd_blobs
(
kernel_filter
,
receipt
,
mask_impl
,
score_mod_expr
)
for
kernel
in
kernels
:
for
kernel
in
kernels
:
write_single_fwd_kernel
(
kernel
,
output_dir
)
write_single_fwd_kernel
(
kernel
,
output_dir
)
write_fwd_api
(
api_pool
,
output_dir
)
write_fwd_api
(
api_pool
,
output_dir
)
def
list_blobs
(
file_path
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
def
list_blobs
(
file_path
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
,
score_mod_expr
)
->
None
:
with
file_path
.
open
(
'a'
)
as
f
:
with
file_path
.
open
(
'a'
)
as
f
:
_
,
kernels
=
get_fwd_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
_
,
kernels
=
get_fwd_blobs
(
kernel_filter
,
receipt
,
mask_impl
,
score_mod_expr
)
for
kernel
in
kernels
:
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_FWD_API_FILENAME
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_FWD_API_FILENAME
)
+
"
\n
"
)
example/ck_tile/18_flexattn/generate.py
View file @
510ff45f
...
@@ -30,7 +30,7 @@ handlers = dict(
...
@@ -30,7 +30,7 @@ handlers = dict(
)
)
assert
0
<
len
(
handlers
)
assert
0
<
len
(
handlers
)
def
write_blobs
(
output_dir
:
Optional
[
str
],
api_list
:
List
[
str
],
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
def
write_blobs
(
output_dir
:
Optional
[
str
],
api_list
:
List
[
str
],
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
,
score_mod_expr
)
->
None
:
if
output_dir
is
None
:
if
output_dir
is
None
:
output_dir
=
Path
(
__file__
).
parent
output_dir
=
Path
(
__file__
).
parent
else
:
else
:
...
@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
...
@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
for
api
in
api_list
:
for
api
in
api_list
:
handler
=
handlers
[
api
][
HandlerId
.
WRITE_BLOBS
]
handler
=
handlers
[
api
][
HandlerId
.
WRITE_BLOBS
]
handler
(
output_dir
,
kernel_filter
,
receipt
,
mask_impl
)
handler
(
output_dir
,
kernel_filter
,
receipt
,
mask_impl
,
score_mod_expr
)
# list all the files that will be generated
# list all the files that will be generated
def
list_blobs
(
output_file
:
Optional
[
str
],
api_list
:
List
[
str
],
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
def
list_blobs
(
output_file
:
Optional
[
str
],
api_list
:
List
[
str
],
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
,
score_mod_expr
)
->
None
:
assert
output_file
is
not
None
assert
output_file
is
not
None
file_path
=
Path
(
output_file
)
file_path
=
Path
(
output_file
)
...
@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
...
@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
for
api
in
api_list
:
for
api
in
api_list
:
handler
=
handlers
[
api
][
HandlerId
.
LIST_BLOBS
]
handler
=
handlers
[
api
][
HandlerId
.
LIST_BLOBS
]
handler
(
file_path
,
kernel_filter
,
receipt
,
mask_impl
)
handler
(
file_path
,
kernel_filter
,
receipt
,
mask_impl
,
score_mod_expr
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
...
@@ -106,9 +106,18 @@ if __name__ == "__main__":
...
@@ -106,9 +106,18 @@ if __name__ == "__main__":
" 2: Only generate instance for Flash attention integration"
" 2: Only generate instance for Flash attention integration"
)
)
parser
.
add_argument
(
"--score_mod_expr"
,
default
=
"s"
,
# test with
# default="s + static_cast<decltype(s)>(q_idx - v_idx)"
required
=
False
,
help
=
"flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
api_list
=
args
.
direction
.
split
(
','
)
api_list
=
args
.
direction
.
split
(
','
)
if
args
.
list_blobs
is
not
None
:
if
args
.
list_blobs
is
not
None
:
list_blobs
(
args
.
list_blobs
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
list_blobs
(
args
.
list_blobs
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
,
score_mod_expr
=
args
.
score_mod_expr
)
else
:
else
:
write_blobs
(
args
.
output_dir
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
write_blobs
(
args
.
output_dir
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
,
score_mod_expr
=
args
.
score_mod_expr
)
include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp
View file @
510ff45f
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
,
typename
ScoreModFunction_
>
struct
FmhaFwdKernel
struct
FmhaFwdKernel
{
{
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
...
@@ -1302,16 +1302,11 @@ struct FmhaFwdKernel
...
@@ -1302,16 +1302,11 @@ struct FmhaFwdKernel
}
}
}();
}();
auto
score_mod_def
=
[](
auto
s
,
// may have state inside
ck_tile
::
index_t
b
,
auto
score_mod_def
=
ScoreModFunction_
{};
ck_tile
::
index_t
h
,
ck_tile
::
index_t
q_idx
,
ck_tile
::
index_t
v_idx
)
{
(
void
)
h
;
(
void
)
b
;
return
s
+
static_cast
<
decltype
(
s
)
>
(
q_idx
-
v_idx
);
};
auto
score_mod_arg
=
[
b
=
i_batch
,
h
=
i_nhead
,
score_mod_def
](
auto
s
,
auto
score_mod_arg
=
[
b
=
i_batch
,
h
=
i_nhead
,
score_mod_def
](
typename
ScoreModFunction_
::
TScore
s
,
ck_tile
::
index_t
q_idx
,
ck_tile
::
index_t
q_idx
,
ck_tile
::
index_t
v_idx
)
{
ck_tile
::
index_t
v_idx
)
{
return
score_mod_def
(
s
,
b
,
h
,
q_idx
,
v_idx
);
return
score_mod_def
(
s
,
b
,
h
,
q_idx
,
v_idx
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment