Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b9dc91cc
Commit
b9dc91cc
authored
Dec 17, 2024
by
Po Yen Chen
Browse files
Merge branch 'feature/use-larger-tile-size-for-chunk-prefill' into feature/add-splitkv-instance
parents
ed634ea4
ff8d3c96
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
122 additions
and
88 deletions
+122
-88
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+8
-8
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+30
-18
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+48
-62
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+36
-0
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
b9dc91cc
...
...
@@ -411,7 +411,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
return
{
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
#
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32,
96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
}
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
b9dc91cc
...
...
@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union
from
codegen.cmake_config
import
*
from
codegen.cpp_symbol_map
import
*
import
codegen.ops.fmha_fwd
from
codegen.ops.fmha_fwd
import
(
FmhaFwdTileSize
,
FmhaFwdApiTrait
,
FMHA_FWD_KERNEL_HEADER
,
FMHA_FWD_API_PER_DTYPE
,
FMHA_FWD_API_PER_HDIM_CASE
,
...
...
@@ -610,9 +610,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
return
{
'32'
:
FmhaFwdTileSize
(
32
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
2
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
64
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
#
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'128'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
4
,
1
,
1
,
16
,
16
,
16
,
1
),
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
...
...
@@ -626,17 +626,18 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
'32'
:
FmhaFwdSplitKVCombineTileSize
(
16
,
16
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
32
,
-
1
),
## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
64
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
128
,
-
1
),
# tile size for decode tile size for prefill
'32'
:
[
FmhaFwdSplitKVCombineTileSize
(
16
,
16
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
16
,
-
1
)],
'64'
:
[
FmhaFwdSplitKVCombineTileSize
(
32
,
32
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
)],
### '96' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)],
'128'
:
[
FmhaFwdSplitKVCombineTileSize
(
32
,
64
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
)],
'256'
:
[
FmhaFwdSplitKVCombineTileSize
(
32
,
128
,
-
1
),
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
)],
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
'64'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
),
'64'
:
[
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
)
]
,
'128'
:
[
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
)
]
,
'256'
:
[
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
)
]
,
}
else
:
return
None
...
...
@@ -689,18 +690,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
api_pool
=
FmhaFwdSplitKVApiPool
(
mask_impl
)
for
dtype
in
FWD_DTYPE_MAP
.
keys
():
d
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
if
d
==
None
:
prefill_tiles
=
codegen
.
ops
.
fmha_fwd
.
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
decode_tiles
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
if
decode_tiles
==
None
:
continue
# make sure if all the hdim str keys in decode_tiles are also available in prefill_tiles
assert
all
(
tile
in
prefill_tiles
.
keys
()
for
tile
in
decode_tiles
.
keys
())
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
tile
=
d
[
hdim_str
]
for
hdim_str
,
mode
in
itertools
.
product
(
decode_tiles
.
keys
(),
MODE_MAP
.
keys
()):
prefill_tile
=
prefill_tiles
[
hdim_str
]
decode_tile
=
decode_tiles
[
hdim_str
]
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
if
mode
==
"group"
:
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
is_prefill
=
(
mode
==
"group"
and
pipeline
.
F_pagedkv
==
't'
)
tile
=
prefill_tile
if
is_prefill
else
decode_tile
k
=
Kernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
...
...
@@ -754,10 +765,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
tile
=
d
[
hdim_str
]
# include prefill tile size if in group mode
tiles
=
d
[
hdim_str
][
0
:
2
if
mode
==
'group'
else
1
]
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
if
mode
==
"
group
"
:
for
tile
,
pipeline
in
itertools
.
product
(
tiles
,
get_pipelines
(
dtype
,
hdim
)
)
:
if
mode
==
'
group
'
:
if
pipeline
.
F_spad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
b9dc91cc
...
...
@@ -11,6 +11,7 @@
#include <array>
#include <cstring>
#include <functional>
#include <map>
#include <numeric>
#include <ostream>
#include <string>
...
...
@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
}
}
int
num_splits_heuristic
(
int
batch_nhead_mblocks
,
int
num_SMs
,
int
num_n_blocks
,
int
max_splits
)
{
// If we have enough to almost fill the SMs, then just use 1 split
if
(
batch_nhead_mblocks
>=
0.8
f
*
num_SMs
)
{
return
1
;
}
max_splits
=
std
::
min
({
max_splits
,
num_SMs
,
num_n_blocks
});
float
max_efficiency
=
0.
f
;
std
::
vector
<
float
>
efficiency
;
efficiency
.
reserve
(
max_splits
);
auto
ceildiv
=
[](
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
};
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto
is_split_eligible
=
[
&
ceildiv
,
&
num_n_blocks
](
int
num_splits
)
{
return
num_splits
==
1
||
ceildiv
(
num_n_blocks
,
num_splits
)
!=
ceildiv
(
num_n_blocks
,
num_splits
-
1
);
};
for
(
int
num_splits
=
1
;
num_splits
<=
max_splits
;
num_splits
++
)
{
if
(
!
is_split_eligible
(
num_splits
))
{
efficiency
.
push_back
(
0.
f
);
}
else
{
float
n_waves
=
float
(
batch_nhead_mblocks
*
num_splits
)
/
num_SMs
;
float
eff
=
n_waves
/
ceil
(
n_waves
);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if
(
eff
>
max_efficiency
)
{
max_efficiency
=
eff
;
}
efficiency
.
push_back
(
eff
);
}
}
for
(
int
num_splits
=
1
;
num_splits
<=
max_splits
;
num_splits
++
)
{
if
(
!
is_split_eligible
(
num_splits
))
{
continue
;
}
if
(
efficiency
[
num_splits
-
1
]
>=
0.85
*
max_efficiency
)
{
// printf("num_splits chosen = %d\n", num_splits);
return
num_splits
;
}
}
return
1
;
}
int
override_num_splits_if_necessary
(
int
batch
,
int
nhead
,
int
max_seqlen_q
,
int
hdim_v
,
float
p_drop
,
int
num_splits
)
int
override_num_splits_if_necessary
(
int
batch
,
int
nhead
,
int
max_seqlen_q
,
int
hdim_q
,
int
hdim_v
,
float
p_drop
,
bool
is_prefill
,
int
num_splits
)
{
int
device
;
auto
status
=
hipGetDevice
(
&
device
);
...
...
@@ -246,17 +200,42 @@ int override_num_splits_if_necessary(
return
num_splits
;
}
// tile size should match the generate.py
const
int
kM0
=
64
;
const
int
kN1
=
hdim_v
;
const
int
kM0
=
[
&
]
{
// get kM0 for prefill phase
if
(
is_prefill
)
{
return
128
;
}
// get kM0 for decode phase
/// TODO: take dtype=fp8/bf8 into consideration
const
std
::
map
<
int
,
int
>
hdim_to_m0
=
{
{
32
,
32
},
{
64
,
64
},
// {96, 64},
{
128
,
64
},
{
256
,
64
},
};
for
(
auto
[
hdim
,
m0
]
:
hdim_to_m0
)
{
if
(
hdim_q
<=
hdim
&&
hdim_v
<=
hdim
)
{
return
m0
;
}
}
return
64
;
// meet unsupported hdim_q/hdim_v
}();
// const int kN1 = hdim_v;
const
int
num_m_blocks
=
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
);
const
int
num_n_blocks
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
//
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
// always 1
if
(
num_splits
<
1
&&
p_drop
==
0.0
f
)
{
return
num_splits_heuristic
(
batch
*
nhead
*
num_m_blocks
,
props
.
multiProcessorCount
*
2
,
num_n_blocks
,
128
);
batch
*
nhead
*
num_m_blocks
,
props
.
multiProcessorCount
*
2
,
16
);
}
return
num_splits
;
...
...
@@ -556,8 +535,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
// legalize num_splits according to other options
if
(
num_splits
<
1
)
{
num_splits
=
override_num_splits_if_necessary
(
batch
,
nhead
,
max_seqlen_q
,
hdim_v
,
p_drop
,
num_splits
);
num_splits
=
override_num_splits_if_necessary
(
batch
,
nhead
,
max_seqlen_q
,
hdim_q
,
hdim_v
,
p_drop
,
/*is_prefill=*/
mode
==
mode_enum
::
group
&&
0
<
page_block_size
,
num_splits
);
}
if
(
128
<
num_splits
)
{
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
b9dc91cc
...
...
@@ -813,3 +813,39 @@ struct fmha_fwd_appendkv_traits
float
fmha_fwd_appendkv
(
fmha_fwd_appendkv_traits
,
fmha_fwd_appendkv_args
,
const
ck_tile
::
stream_config
&
);
template
<
typename
Int
=
int
>
Int
num_splits_heuristic
(
Int
batch_nhead_mblocks
,
Int
num_SMs
,
Int
max_splits
)
{
// If we have enough to almost fill the SMs, then just use 1 split
if
(
batch_nhead_mblocks
>=
0.8
f
*
num_SMs
)
{
return
1
;
}
max_splits
=
std
::
min
({
max_splits
,
num_SMs
});
float
max_efficiency
=
0.
f
;
std
::
vector
<
float
>
efficiency
;
efficiency
.
reserve
(
max_splits
);
for
(
Int
num_splits
=
1
;
num_splits
<=
max_splits
;
num_splits
*=
2
)
{
float
n_blocks
=
float
(
batch_nhead_mblocks
*
num_splits
)
/
num_SMs
;
float
eff
=
n_blocks
/
std
::
ceil
(
n_blocks
);
if
(
eff
>
max_efficiency
)
{
max_efficiency
=
eff
;
}
efficiency
.
push_back
(
eff
);
}
for
(
Int
num_splits
=
1
;
num_splits
<=
max_splits
;
num_splits
++
)
{
if
(
efficiency
[
num_splits
-
1
]
>=
0.85
*
max_efficiency
)
{
return
num_splits
;
}
}
return
1
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment