Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c1e2fef7
Commit
c1e2fef7
authored
Feb 10, 2025
by
rocking
Browse files
Add receipt for aiter integration
parent
052a7265
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
4 deletions
+54
-4
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+18
-1
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+16
-0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+15
-1
example/ck_tile/01_fmha/generate.py
example/ck_tile/01_fmha/generate.py
+5
-2
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
c1e2fef7
...
@@ -499,13 +499,30 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -499,13 +499,30 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond
&=
dpad
==
dvpad
cond
&=
dpad
==
dvpad
if
not
cond
:
if
not
cond
:
continue
continue
if
receipt
==
3
:
el
if
receipt
==
3
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
dpad
==
dvpad
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
cond
&=
deterministic
==
"f"
if
not
cond
:
if
not
cond
:
continue
continue
elif
receipt
==
10
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"batch"
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
dropout
in
[
'no'
,
'dropout_wg32'
,
'dropout_wg16'
]
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
if
not
cond
:
continue
elif
receipt
==
11
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"group"
cond
&=
dropout
in
[
'no'
,
'dropout_wg32'
,
'dropout_wg16'
]
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
if
not
cond
:
continue
api_pool
.
register_dq_dk_dv_traits
(
k
.
api_trait
())
api_pool
.
register_dq_dk_dv_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
gen
.
append
(
k
)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
c1e2fef7
...
@@ -494,6 +494,22 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
...
@@ -494,6 +494,22 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond
&=
pipeline
.
F_squant
==
'f'
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
if
not
cond
:
continue
continue
elif
receipt
==
10
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"batch"
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
elif
receipt
==
11
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"group"
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
api_pool
.
register_traits
(
k
.
api_trait
())
api_pool
.
register_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
gen
.
append
(
k
)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
c1e2fef7
...
@@ -268,7 +268,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
...
@@ -268,7 +268,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
...
@@ -712,6 +712,20 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -712,6 +712,20 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond
&=
pipeline
.
F_squant
==
'f'
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
if
not
cond
:
continue
continue
elif
receipt
==
11
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"group"
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
elif
receipt
==
14
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"batch"
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
api_pool
.
register_traits
(
k
.
api_trait
())
api_pool
.
register_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
gen
.
append
(
k
)
...
...
example/ck_tile/01_fmha/generate.py
View file @
c1e2fef7
...
@@ -17,7 +17,7 @@ class HandlerId(IntEnum):
...
@@ -17,7 +17,7 @@ class HandlerId(IntEnum):
LIST_BLOBS
=
0
LIST_BLOBS
=
0
WRITE_BLOBS
=
1
WRITE_BLOBS
=
1
# inspect all modules under 'codegen.ops' and register API handlers
# inspect all modules under 'codegen.ops' and register API handlers
ops
=
[]
ops
=
[]
for
importer
,
module_name
,
_
in
pkgutil
.
iter_modules
(
codegen
.
ops
.
__path__
):
for
importer
,
module_name
,
_
in
pkgutil
.
iter_modules
(
codegen
.
ops
.
__path__
):
full_module_name
=
'%s.%s'
%
(
codegen
.
ops
.
__name__
,
module_name
)
full_module_name
=
'%s.%s'
%
(
codegen
.
ops
.
__name__
,
module_name
)
...
@@ -103,7 +103,10 @@ if __name__ == "__main__":
...
@@ -103,7 +103,10 @@ if __name__ == "__main__":
required
=
False
,
required
=
False
,
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
" 2: Only generate instance for Flash attention integration"
" 2: Only generate instance for Flash attention integration"
+
\
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration"
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration"
" 12: Only generate instance for Aiter(mha_fwd_kvcache) integration"
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment