Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9d8d4c61
Commit
9d8d4c61
authored
Dec 17, 2024
by
Po Yen Chen
Browse files
Use larger tile size for chunk prefill
parent
d46196f2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
27 deletions
+38
-27
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+8
-8
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+30
-19
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
9d8d4c61
...
@@ -409,17 +409,17 @@ class FmhaFwdKernel:
...
@@ -409,17 +409,17 @@ class FmhaFwdKernel:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
#
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32,
96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
)
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
32
,
-
1
)
}
}
else
:
else
:
return
None
return
None
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
9d8d4c61
...
@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union
...
@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union
from
codegen.cmake_config
import
*
from
codegen.cmake_config
import
*
from
codegen.cpp_symbol_map
import
*
from
codegen.cpp_symbol_map
import
*
import
codegen.ops.fmha_fwd
from
codegen.ops.fmha_fwd
import
(
from
codegen.ops.fmha_fwd
import
(
FmhaFwdTileSize
,
FmhaFwdTileSize
,
FmhaFwdApiTrait
,
FMHA_FWD_KERNEL_HEADER
,
FMHA_FWD_KERNEL_HEADER
,
FMHA_FWD_API_PER_DTYPE
,
FMHA_FWD_API_PER_DTYPE
,
FMHA_FWD_API_PER_HDIM_CASE
,
FMHA_FWD_API_PER_HDIM_CASE
,
...
@@ -574,14 +574,14 @@ class FmhaFwdSplitKVCombineKernel:
...
@@ -574,14 +574,14 @@ class FmhaFwdSplitKVCombineKernel:
# TODO: design a more practical way to do it
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
# this is current supported tile size per hdim
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
[
str
,
FmhaFwdTileSize
]
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
FmhaFwdTileSize
(
32
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'32'
:
FmhaFwdTileSize
(
32
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
64
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
64
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
#
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'128'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'128'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
1
),
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
...
@@ -592,20 +592,21 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
...
@@ -592,20 +592,21 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
else
:
else
:
return
None
return
None
def
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
[
str
,
List
[
FmhaFwdSplitKVCombineTileSize
]]
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
FmhaFwdSplitKVCombineTileSize
(
16
,
16
,
-
1
),
# tile size for decode tile size for prefill
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
32
,
-
1
),
'32'
:
[
FmhaFwdSplitKVCombineTileSize
(
16
,
16
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
16
,
-
1
)],
## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'64'
:
[
FmhaFwdSplitKVCombineTileSize
(
32
,
32
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
)],
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
64
,
-
1
),
### '96' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
128
,
-
1
),
'128'
:
[
FmhaFwdSplitKVCombineTileSize
(
32
,
64
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
)],
'256'
:
[
FmhaFwdSplitKVCombineTileSize
(
32
,
128
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
)],
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
'64'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
),
'64'
:
[
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
)
]
,
'128'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
),
'128'
:
[
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
)
]
,
'256'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
),
'256'
:
[
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
)
]
,
}
}
else
:
else
:
return
None
return
None
...
@@ -655,18 +656,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -655,18 +656,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
api_pool
=
FmhaFwdSplitKVApiPool
(
mask_impl
)
api_pool
=
FmhaFwdSplitKVApiPool
(
mask_impl
)
for
dtype
in
FWD_DTYPE_MAP
.
keys
():
for
dtype
in
FWD_DTYPE_MAP
.
keys
():
d
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
prefill_tiles
=
codegen
.
ops
.
fmha_fwd
.
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
if
d
==
None
:
decode_tiles
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
if
decode_tiles
==
None
:
continue
continue
# make sure if all the hdim str keys in decode_tiles are also available in prefill_tiles
assert
all
(
tile
in
prefill_tiles
.
keys
()
for
tile
in
decode_tiles
.
keys
())
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
for
hdim_str
,
mode
in
itertools
.
product
(
decode_tiles
.
keys
(),
MODE_MAP
.
keys
()):
tile
=
d
[
hdim_str
]
prefill_tile
=
prefill_tiles
[
hdim_str
]
decode_tile
=
decode_tiles
[
hdim_str
]
hdim
=
int
(
hdim_str
)
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
if
mode
==
"group"
:
if
mode
==
"group"
:
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
continue
is_prefill
=
(
mode
==
"group"
and
pipeline
.
F_pagedkv
==
't'
)
tile
=
prefill_tile
if
is_prefill
else
decode_tile
k
=
Kernel
(
F_idx
=
0
,
k
=
Kernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_dtype
=
dtype
,
...
@@ -720,9 +731,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
...
@@ -720,9 +731,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
continue
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
tile
=
d
[
hdim_str
]
tile
s
=
d
[
hdim_str
]
hdim
=
int
(
hdim_str
)
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
for
tile
,
pipeline
in
itertools
.
product
(
tiles
,
get_pipelines
(
dtype
,
hdim
)
)
:
if
mode
==
"group"
:
if
mode
==
"group"
:
if
pipeline
.
F_spad
!=
't'
:
if
pipeline
.
F_spad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment