Commit c881136b authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'develop' into ck_tile/support-vllm-kcache-layout

parents c5e8e14f 4e076909
...@@ -15,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd` ...@@ -15,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd`
## kernel ## kernel
The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
There are 3 template parameters for this kernel template. There are 2 template parameters for this kernel template.
* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose.
* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)).
* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. * `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support.
......
...@@ -119,6 +119,7 @@ PIPELINE_MAP = { ...@@ -119,6 +119,7 @@ PIPELINE_MAP = {
PIPELINE_ENUM_MAP = { PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
} }
BOOL_MAP = { BOOL_MAP = {
......
...@@ -29,11 +29,6 @@ K0_MAX_SUBMAX_MAP = { ...@@ -29,11 +29,6 @@ K0_MAX_SUBMAX_MAP = {
256: 256 256: 256
} }
TILE_PARTITIONER_MAP = {
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py // auto generated by generate.py
...@@ -44,13 +39,12 @@ FMHA_FWD_KERNEL_BODY=""" ...@@ -44,13 +39,12 @@ FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype}; using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx}, using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>, ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile_{F_idx}, ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile_{F_idx}, ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>; {F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
...@@ -91,9 +85,7 @@ using fmha_epilogue_{F_idx} = ...@@ -91,9 +85,7 @@ using fmha_epilogue_{F_idx} =
{F_spad}, {F_dvpad}>>; {F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} = using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>, ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
...@@ -306,15 +298,19 @@ class FmhaFwdTileSize: ...@@ -306,15 +298,19 @@ class FmhaFwdTileSize:
F_rm1 : int # number of warps for gemm1 along q seqlen F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used) F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm : int # warp size along m (warp size) F_wm0 : int # gemm0 warp size along m
F_wn : int # warp size along n F_wn0 : int # gemm0 warp size along n
F_wk : int # warp size along k F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property @property
def name(self) -> str: def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass @dataclass
class FmhaFwdKernel: class FmhaFwdKernel:
...@@ -326,12 +322,6 @@ class FmhaFwdKernel: ...@@ -326,12 +322,6 @@ class FmhaFwdKernel:
F_pipeline : FmhaFwdPipeline F_pipeline : FmhaFwdPipeline
mask_impl : str mask_impl : str
def get_tp(self) -> str:
if self.F_mode == 'group':
return 'hbs'
else:
return 'shb'
@property @property
def template(self) -> str: def template(self) -> str:
kernel_body = str() kernel_body = str()
...@@ -352,9 +342,12 @@ class FmhaFwdKernel: ...@@ -352,9 +342,12 @@ class FmhaFwdKernel:
F_rm1 = self.F_tile.F_rm1, F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1, F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1, F_rk1 = self.F_tile.F_rk1,
F_wm = self.F_tile.F_wm, F_wm0 = self.F_tile.F_wm0,
F_wn = self.F_tile.F_wn, F_wn0 = self.F_tile.F_wn0,
F_wk = self.F_tile.F_wk, F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
...@@ -368,13 +361,12 @@ class FmhaFwdKernel: ...@@ -368,13 +361,12 @@ class FmhaFwdKernel:
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])
@property @property
def name(self) -> str: def name(self) -> str:
# TODO: we don't encode idx here # TODO: we don't encode idx here
return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name self.F_tile.name + '_' + self.F_pipeline.name
@property @property
...@@ -409,17 +401,17 @@ class FmhaFwdKernel: ...@@ -409,17 +401,17 @@ class FmhaFwdKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} }
else: else:
return None return None
......
...@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl ...@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>; fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} = using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
......
...@@ -39,6 +39,7 @@ K0_MAX_SUBMAX_MAP = { ...@@ -39,6 +39,7 @@ K0_MAX_SUBMAX_MAP = {
FMHA_FWD_SPLITKV_PIPELINE_MAP = { FMHA_FWD_SPLITKV_PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
"qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
} }
...@@ -50,13 +51,12 @@ namespace {{ ...@@ -50,13 +51,12 @@ namespace {{
template <bool kIsMultipleSplits, bool kHasUnevenSplits = kIsMultipleSplits> template <bool kIsMultipleSplits, bool kHasUnevenSplits = kIsMultipleSplits>
struct kernel_runner {{ struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile, using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>, ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile, ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile, ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>; {F_vlayout}>;
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
...@@ -175,9 +175,8 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< ...@@ -175,9 +175,8 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType, typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType, typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
{F_hdim}, {F_hdim},
{F_bm0},
{F_bn1},
{F_mode}, {F_mode},
{F_bn1},
fmha_trait>; fmha_trait>;
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
...@@ -191,9 +190,7 @@ using fmha_epilogue = ...@@ -191,9 +190,7 @@ using fmha_epilogue =
false, false>>; false, false>>;
using fmha_kernel = using fmha_kernel =
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>, ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
...@@ -206,7 +203,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) ...@@ -206,7 +203,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}}; }};
}} }}
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
#include <iostream> #include <iostream>
...@@ -279,16 +276,25 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F ...@@ -279,16 +276,25 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F
if (1 < a.num_splits) {{ if (1 < a.num_splits) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % 32 == 0);
if (t.has_lse) {{ if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1; return -1;
}} else {{ }} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}} }}
}} else {{ }} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}} }}
...@@ -346,7 +352,7 @@ class FmhaFwdSplitKVApiTrait: ...@@ -346,7 +352,7 @@ class FmhaFwdSplitKVApiTrait:
if self.pipeline_tag == 'qr_async': if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support if self.spad == 't' : return 'true' # always support
else : return 'true' else : return 'true'
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0' else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False else: assert False
...@@ -357,7 +363,7 @@ class FmhaFwdSplitKVApiTrait: ...@@ -357,7 +363,7 @@ class FmhaFwdSplitKVApiTrait:
if self.pipeline_tag == 'qr_async': if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']: elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0' else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False else: assert False
...@@ -368,7 +374,7 @@ class FmhaFwdSplitKVApiTrait: ...@@ -368,7 +374,7 @@ class FmhaFwdSplitKVApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype]) vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0' if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0' else : return f'a.hdim_q % {bk0submax} == 0'
...@@ -380,7 +386,7 @@ class FmhaFwdSplitKVApiTrait: ...@@ -380,7 +386,7 @@ class FmhaFwdSplitKVApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype]) vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0' else : return f'a.hdim_v % {bk0submax} == 0'
...@@ -491,12 +497,11 @@ class FmhaFwdSplitKVApiPool: ...@@ -491,12 +497,11 @@ class FmhaFwdSplitKVApiPool:
@dataclass @dataclass
class FmhaFwdSplitKVCombineTileSize: class FmhaFwdSplitKVCombineTileSize:
F_bm0 : int # tile size along q seqlen
F_bn1 : int # tile size along v head_dim F_bn1 : int # tile size along v head_dim
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property @property
def name(self) -> str: def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn1}" +\ return f"b{self.F_bn1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass @dataclass
...@@ -529,9 +534,12 @@ class FmhaFwdSplitKVKernel: ...@@ -529,9 +534,12 @@ class FmhaFwdSplitKVKernel:
F_rm1 = self.F_tile.F_rm1, F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1, F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1, F_rk1 = self.F_tile.F_rk1,
F_wm = self.F_tile.F_wm, F_wm0 = self.F_tile.F_wm0,
F_wn = self.F_tile.F_wn, F_wn0 = self.F_tile.F_wn0,
F_wk = self.F_tile.F_wk, F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
...@@ -597,7 +605,6 @@ class FmhaFwdSplitKVCombineKernel: ...@@ -597,7 +605,6 @@ class FmhaFwdSplitKVCombineKernel:
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn1 = self.F_tile.F_bn1, F_bn1 = self.F_tile.F_bn1,
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
...@@ -621,17 +628,17 @@ class FmhaFwdSplitKVCombineKernel: ...@@ -621,17 +628,17 @@ class FmhaFwdSplitKVCombineKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} }
else: else:
return None return None
...@@ -639,18 +646,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -639,18 +646,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
# tile size for decode tile size for prefill '32' : FmhaFwdSplitKVCombineTileSize(32, -1),
'32' : [FmhaFwdSplitKVCombineTileSize(16, 16, -1), FmhaFwdSplitKVCombineTileSize(64, 16, -1)], '64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'64' : [FmhaFwdSplitKVCombineTileSize(32, 32, -1), FmhaFwdSplitKVCombineTileSize(64, 32, -1)], ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
### '96' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)], '128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : [FmhaFwdSplitKVCombineTileSize(32, 64, -1), FmhaFwdSplitKVCombineTileSize(64, 64, -1)], '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : [FmhaFwdSplitKVCombineTileSize(32, 128, -1), FmhaFwdSplitKVCombineTileSize(64, 128, -1)],
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : [FmhaFwdSplitKVCombineTileSize(64, 32, -1)], '64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : [FmhaFwdSplitKVCombineTileSize(64, 64, -1)], '128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : [FmhaFwdSplitKVCombineTileSize(64, 128, -1)], '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
} }
else: else:
return None return None
...@@ -710,7 +716,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -710,7 +716,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# make sure if all the hdim str keys in decode_tiles are also available in prefill_tiles # make sure if all the hdim str keys in decode_tiles are also available in prefill_tiles
assert all(tile in prefill_tiles.keys() for tile in decode_tiles.keys()) assert all(tile in prefill_tiles.keys() for tile in decode_tiles.keys())
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(decode_tiles.keys(), MODE_MAP.keys()): for hdim_str, mode in itertools.product(decode_tiles.keys(), MODE_MAP.keys()):
prefill_tile = prefill_tiles[hdim_str] prefill_tile = prefill_tiles[hdim_str]
decode_tile = decode_tiles[hdim_str] decode_tile = decode_tiles[hdim_str]
...@@ -778,12 +783,10 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis ...@@ -778,12 +783,10 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
# include prefill tile size if in group mode tile = d[hdim_str]
tiles = d[hdim_str][0 : 2 if mode == 'group' else 1]
hdim = int(hdim_str) hdim = int(hdim_str)
for tile, pipeline in itertools.product(tiles, get_pipelines(dtype, hdim)): for pipeline in get_pipelines(dtype, hdim):
if mode == 'group': if mode == 'group':
if pipeline.F_spad != 't': if pipeline.F_spad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
......
...@@ -401,8 +401,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -401,8 +401,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
} }
}(); }();
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); if constexpr(FmhaKernel::kIsGroupMode)
{
dim3 grids = FmhaKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
return ck_tile::make_tuple(kargs, grids);
}
else
{
dim3 grids =
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
}
} }
template <typename Kernel> template <typename Kernel>
...@@ -712,7 +722,6 @@ std::string fmha_fwd_splitkv_get_name_(); ...@@ -712,7 +722,6 @@ std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
bool kIsGroupMode_, bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN1_, ck_tile::index_t kN1_,
bool kStoreLse_, bool kStoreLse_,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
...@@ -723,7 +732,6 @@ struct fmha_fwd_splitkv_combine_traits_ ...@@ -723,7 +732,6 @@ struct fmha_fwd_splitkv_combine_traits_
static constexpr ck_tile::index_t HDim = HDim_; static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>; using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_; static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
......
...@@ -54,8 +54,7 @@ using CDataType = Types::CDataType; ...@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size") arg_parser.insert("m", "3840", "m dimension")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension") .insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default") .insert("a_layout", "R", "A tensor data layout - Row by default")
...@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[]) ...@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
...@@ -64,7 +64,7 @@ int run_gemm_example_with_layouts(int argc, ...@@ -64,7 +64,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_size = arg_parser.get_int("b"); ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
...@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc, ...@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
stride_A, stride_A,
stride_B, stride_B,
stride_C, stride_C,
batch_size, kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#endif #endif
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler // Memory friendly for Interwave scheduler
...@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
#endif #endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
...@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKernelArgs(args);
args.p_b,
args.p_c, const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs)) if(!Kernel::IsSupportedArgument(kargs))
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
using ms_problem = \
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \ auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \ const dim3 grids = kernel::GridSize(a); \
...@@ -15,6 +17,28 @@ ...@@ -15,6 +17,28 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi ...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
case(6): { case(6): {
MOE_SORTING_DISPATCH(6); MOE_SORTING_DISPATCH(6);
} }
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): { case(8): {
MOE_SORTING_DISPATCH(8); MOE_SORTING_DISPATCH(8);
} }
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): { case(10): {
MOE_SORTING_DISPATCH(10); MOE_SORTING_DISPATCH(10);
} }
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: { default: {
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
......
...@@ -17,3 +17,4 @@ $EXE -t=71 -e=11 -k=11 ...@@ -17,3 +17,4 @@ $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1 $EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1 $EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13 $EXE -t=333 -e=99 -k=13
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
#include "fused_moesorting.hpp" #include "fused_moesorting.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
using ms_problem = \
ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \ auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \ const dim3 grids = kernel::GridSize(a); \
...@@ -15,6 +17,28 @@ ...@@ -15,6 +17,28 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til ...@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
case(6): { case(6): {
MOE_SORTING_DISPATCH(6); MOE_SORTING_DISPATCH(6);
} }
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): { case(8): {
MOE_SORTING_DISPATCH(8); MOE_SORTING_DISPATCH(8);
} }
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): { case(10): {
MOE_SORTING_DISPATCH(10); MOE_SORTING_DISPATCH(10);
} }
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: { default: {
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
......
...@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre ...@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[]) ...@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
...@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t batch_stride_B, ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C, ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count, ck_tile::index_t batch_count,
ck_tile::index_t kbatch,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
...@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M; args.M = M;
args.N = N; args.N = N;
args.K = K; args.K = K;
...@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b"); ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c"); ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count"); ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
...@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_B, batch_stride_B,
batch_stride_C, batch_stride_C,
batch_count, batch_count,
kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
......
...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; ...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default") arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default") .insert("Ns", "", "N dimensions - empty by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default") .insert("Ks", "", "K dimensions - empty by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU") .insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("warmup", "10", "number of iterations before benchmark the kernel") .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("group_count", "16", "group count"); .insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
...@@ -53,17 +53,24 @@ int run_grouped_gemm_example_with_layouts(int argc, ...@@ -53,17 +53,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
return -1; return -1;
}; };
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
};
const int group_count = arg_parser.get_int("group_count"); const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat"); const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup"); const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms; std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns; std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks; std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As; std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs; std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs; std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
Ms.push_back(256 + 256 * i); Ms.push_back(256 + 256 * i);
...@@ -74,6 +81,7 @@ int run_grouped_gemm_example_with_layouts(int argc, ...@@ -74,6 +81,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_Bs.push_back(Ks[i]); stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]); stride_Cs.push_back(Ns[i]);
} }
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors; std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors; std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
......
...@@ -115,8 +115,8 @@ ...@@ -115,8 +115,8 @@
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ #cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif #endif
#ifndef DCK_USE_OCP_FP8 #ifndef CK_USE_OCP_FP8
#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@ #cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
#endif #endif
#ifndef CK_USE_FNUZ_FP8 #ifndef CK_USE_FNUZ_FP8
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe ...@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert( static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment