Unverified Commit 86322218 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

[CK_TILE] Add fmha fwd headdim96 support (#1608)



* Add ceil_to_qualified_tile_length()

* Rename kK0BlockLength to kQKHeaddim

* Add kSubQKHeaddim concept to support headdim96

* Fix in math.hpp to avoid using __half interfaces

* Add LdsBufferSequence instance for headdim96

* Update in fmha_fwd/fmha_fwd_splitkv codegen to support hd96 testing

* Disable hd96 instance generation in codegen fmha_fwd and fmha_fwd_splitkv to save compiling time

* Reformat one file

* Fix text alignment in fmha_fwd_splitkv.py

---------
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
parent 4d7e063a
...@@ -21,6 +21,14 @@ DTYPE_BITS = { ...@@ -21,6 +21,14 @@ DTYPE_BITS = {
"bf8" : 8 "bf8" : 8
} }
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
TILE_PARTITIONER_MAP = { TILE_PARTITIONER_MAP = {
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
...@@ -35,7 +43,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT ...@@ -35,7 +43,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
FMHA_FWD_KERNEL_BODY=""" FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype}; using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx}, using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
...@@ -87,7 +95,7 @@ using fmha_kernel_{F_idx} = ...@@ -87,7 +95,7 @@ using fmha_kernel_{F_idx} =
fmha_pipeline_{F_idx}, fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>; fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream> #include <iostream>
...@@ -125,7 +133,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < ...@@ -125,7 +133,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a); return fmha_fwd_<trait_>(s, a);
}} }}
""" """
...@@ -142,7 +150,7 @@ class FmhaFwdApiTrait: ...@@ -142,7 +150,7 @@ class FmhaFwdApiTrait:
bk0 : int # tile size along qk gemm unroll bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll bk1 : int # tile size along kv gemm unroll
bk0blen : int bk0max : int
vlayout : str vlayout : str
mask : str mask : str
bias : str # bias : str #
...@@ -156,7 +164,7 @@ class FmhaFwdApiTrait: ...@@ -156,7 +164,7 @@ class FmhaFwdApiTrait:
@property @property
def name(self) -> str: def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
@property @property
...@@ -188,8 +196,9 @@ class FmhaFwdApiTrait: ...@@ -188,8 +196,9 @@ class FmhaFwdApiTrait:
if self.dpad == 't': return f'a.hdim_q % {vec} == 0' if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_q % {self.bk0blen} == 0' if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False else: assert False
@property @property
...@@ -199,8 +208,9 @@ class FmhaFwdApiTrait: ...@@ -199,8 +208,9 @@ class FmhaFwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_v % {self.bk0blen} == 0' if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False else: assert False
@dataclass @dataclass
...@@ -271,7 +281,7 @@ class FmhaFwdApiPool: ...@@ -271,7 +281,7 @@ class FmhaFwdApiPool:
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
...@@ -289,7 +299,7 @@ class FmhaFwdTileSize: ...@@ -289,7 +299,7 @@ class FmhaFwdTileSize:
F_bk0 : int # tile size along qk gemm unroll F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll F_bk1 : int # tile size along kv gemm unroll
F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used) F_rk0 : int # number of warps for gemm0 along head dim q (not used)
...@@ -302,7 +312,7 @@ class FmhaFwdTileSize: ...@@ -302,7 +312,7 @@ class FmhaFwdTileSize:
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property @property
def name(self) -> str: def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
...@@ -335,7 +345,7 @@ class FmhaFwdKernel: ...@@ -335,7 +345,7 @@ class FmhaFwdKernel:
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1, F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1, F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen, F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0, F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0, F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0, F_rk0 = self.F_tile.F_rk0,
...@@ -382,7 +392,7 @@ class FmhaFwdKernel: ...@@ -382,7 +392,7 @@ class FmhaFwdKernel:
bk0=self.F_tile.F_bk0, bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1, bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1, bk1=self.F_tile.F_bk1,
bk0blen=self.F_tile.F_bk0blen, bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout, vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask, mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias, bias=self.F_pipeline.F_bias,
...@@ -401,6 +411,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -401,6 +411,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
return { return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
} }
......
...@@ -29,6 +29,14 @@ DTYPE_BITS = { ...@@ -29,6 +29,14 @@ DTYPE_BITS = {
"bf8" : 8 "bf8" : 8
} }
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
FMHA_FWD_SPLITKV_PIPELINE_MAP = { FMHA_FWD_SPLITKV_PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
...@@ -41,7 +49,7 @@ using fmha_mask_{F_idx} = {F_mask}; ...@@ -41,7 +49,7 @@ using fmha_mask_{F_idx} = {F_mask};
namespace {{ namespace {{
template <bool kHasUnevenSplits> template <bool kHasUnevenSplits>
struct kernel_runner {{ struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile, using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
...@@ -103,7 +111,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) ...@@ -103,7 +111,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}}; }};
}} }}
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>; {F_dvpad}>;
...@@ -241,7 +249,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ...@@ -241,7 +249,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
...@@ -260,7 +268,7 @@ class FmhaFwdSplitKVApiTrait: ...@@ -260,7 +268,7 @@ class FmhaFwdSplitKVApiTrait:
bk0 : int # tile size along qk gemm unroll bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll bk1 : int # tile size along kv gemm unroll
bk0blen : int bk0max : int
vlayout : str vlayout : str
mask : str mask : str
bias : str # bias : str #
...@@ -274,7 +282,7 @@ class FmhaFwdSplitKVApiTrait: ...@@ -274,7 +282,7 @@ class FmhaFwdSplitKVApiTrait:
@property @property
def name(self) -> str: def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
f'{self.dvpad}-{self.pagedkv}' f'{self.dvpad}-{self.pagedkv}'
...@@ -307,8 +315,9 @@ class FmhaFwdSplitKVApiTrait: ...@@ -307,8 +315,9 @@ class FmhaFwdSplitKVApiTrait:
if self.dpad == 't': return f'a.hdim_q % {vec} == 0' if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_q % {self.bk0blen} == 0' if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False else: assert False
@property @property
...@@ -318,8 +327,9 @@ class FmhaFwdSplitKVApiTrait: ...@@ -318,8 +327,9 @@ class FmhaFwdSplitKVApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_v % {self.bk0blen} == 0' if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False else: assert False
@dataclass @dataclass
...@@ -414,7 +424,7 @@ class FmhaFwdSplitKVApiPool: ...@@ -414,7 +424,7 @@ class FmhaFwdSplitKVApiPool:
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
...@@ -458,7 +468,7 @@ class FmhaFwdSplitKVKernel: ...@@ -458,7 +468,7 @@ class FmhaFwdSplitKVKernel:
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1, F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1, F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen, F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0, F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0, F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0, F_rk0 = self.F_tile.F_rk0,
...@@ -504,7 +514,7 @@ class FmhaFwdSplitKVKernel: ...@@ -504,7 +514,7 @@ class FmhaFwdSplitKVKernel:
bk0=self.F_tile.F_bk0, bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1, bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1, bk1=self.F_tile.F_bk1,
bk0blen=self.F_tile.F_bk0blen, bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout, vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask, mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias, bias=self.F_pipeline.F_bias,
...@@ -559,6 +569,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -559,6 +569,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
return { return {
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
} }
...@@ -576,6 +587,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d ...@@ -576,6 +587,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
return { return {
'32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1), '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1),
'64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1),
## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1),
} }
...@@ -604,7 +616,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -604,7 +616,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True: # if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
......
...@@ -1126,7 +1126,7 @@ CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x) ...@@ -1126,7 +1126,7 @@ CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x)
template <> template <>
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x) CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x)
{ {
return __hneg(x); return -x;
}; };
template <typename T> template <typename T>
...@@ -1168,7 +1168,7 @@ CK_TILE_DEVICE double sin<double>(double x) ...@@ -1168,7 +1168,7 @@ CK_TILE_DEVICE double sin<double>(double x)
template <> template <>
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x) CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x)
{ {
return ::hsin(x); return __ocml_sin_f16(x);
}; };
template <typename T> template <typename T>
...@@ -1300,7 +1300,7 @@ CK_TILE_DEVICE double ceil<double>(double x) ...@@ -1300,7 +1300,7 @@ CK_TILE_DEVICE double ceil<double>(double x)
template <> template <>
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x) CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
{ {
return ::hceil(x); return __ocml_ceil_f16(x);
}; };
template <typename T> template <typename T>
...@@ -1342,7 +1342,7 @@ CK_TILE_DEVICE double floor<double>(double x) ...@@ -1342,7 +1342,7 @@ CK_TILE_DEVICE double floor<double>(double x)
template <> template <>
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x) CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
{ {
return ::hfloor(x); return __ocml_floor_f16(x);
}; };
template <typename T> template <typename T>
...@@ -1365,7 +1365,7 @@ CK_TILE_DEVICE T exp(T x) ...@@ -1365,7 +1365,7 @@ CK_TILE_DEVICE T exp(T x)
template <> template <>
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x) CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
{ {
return hexp(x); return __ocml_exp_f16(x);
}; };
template <> template <>
...@@ -1389,7 +1389,7 @@ CK_TILE_DEVICE T log(T x) ...@@ -1389,7 +1389,7 @@ CK_TILE_DEVICE T log(T x)
template <> template <>
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x) CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
{ {
return hlog(x); return __ocml_log_f16(x);
}; };
template <> template <>
......
...@@ -82,10 +82,10 @@ struct FmhaFwdKernel ...@@ -82,10 +82,10 @@ struct FmhaFwdKernel
if (kPadHeadDimV) n += "dv"; if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }(); return n.empty() ? n : std::string("p") + n; }();
return return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
...@@ -657,7 +657,7 @@ struct FmhaFwdKernel ...@@ -657,7 +657,7 @@ struct FmhaFwdKernel
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
else else
...@@ -724,7 +724,7 @@ struct FmhaFwdKernel ...@@ -724,7 +724,7 @@ struct FmhaFwdKernel
[&]() { [&]() {
if constexpr(FmhaPipeline::kQLoadOnce) if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{}, return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{}); number<FmhaPipeline::kSubQKHeaddim>{});
else else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}); return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(), }(),
......
...@@ -78,10 +78,10 @@ struct FmhaFwdSplitKVKernel ...@@ -78,10 +78,10 @@ struct FmhaFwdSplitKVKernel
if (kPadHeadDimV) n += "dv"; if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }(); return n.empty() ? n : std::string("p") + n; }();
return return
_SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) + _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" "_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
...@@ -586,7 +586,7 @@ struct FmhaFwdSplitKVKernel ...@@ -586,7 +586,7 @@ struct FmhaFwdSplitKVKernel
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
else else
...@@ -735,7 +735,7 @@ struct FmhaFwdSplitKVKernel ...@@ -735,7 +735,7 @@ struct FmhaFwdSplitKVKernel
[&]() { [&]() {
if constexpr(FmhaPipeline::kQLoadOnce) if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{}, return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{}); number<FmhaPipeline::kSubQKHeaddim>{});
else else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}); return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(), }(),
......
...@@ -39,7 +39,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -39,7 +39,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr index_t kK0 = BlockFmhaShape::kK0; static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
...@@ -75,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -75,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return Problem::kBlockPerCu; return Problem::kBlockPerCu;
else else
{ {
if constexpr(kK0BlockLength <= 32) if constexpr(kQKHeaddim <= 32)
{ {
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 64) else if constexpr(kQKHeaddim <= 64)
{ {
return 3; return 3;
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kQKHeaddim <= 128)
{ {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 256) else if constexpr(kQKHeaddim <= 256)
{ {
return 1; return 1;
} }
...@@ -270,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -270,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
// prefetch K tile // prefetch K tile
index_t i_total_loops = 0; index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops); static_assert(2 <= k0_loops);
......
...@@ -42,7 +42,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -42,7 +42,8 @@ struct BlockFmhaPipelineQRKSVS
static constexpr index_t kK0 = BlockFmhaShape::kK0; static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
...@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS ...@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS
return Problem::kBlockPerCu; return Problem::kBlockPerCu;
else else
{ {
if constexpr(kK0BlockLength <= 32) if constexpr(kQKHeaddim <= 32)
{ {
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 64) else if constexpr(kQKHeaddim <= 64)
{ {
return 3; return 3;
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kQKHeaddim <= 128)
{ {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 256) else if constexpr(kQKHeaddim <= 256)
{ {
return 1; return 1;
} }
...@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS
// prefetch K tile // prefetch K tile
index_t i_total_loops = 0; index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops); static_assert(2 <= k0_loops);
......
...@@ -43,7 +43,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -43,7 +43,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr index_t kK0 = BlockFmhaShape::kK0; static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
...@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync
return 1; return 1;
} }
if constexpr(kK0BlockLength <= 32) if constexpr(kQKHeaddim <= 32)
{ {
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking) FmhaMask::IsMasking)
...@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync
else else
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 64) else if constexpr(kQKHeaddim <= 64)
{ {
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2; return 2;
else else
return 3; return 3;
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kQKHeaddim <= 128)
{ {
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 256) else if constexpr(kQKHeaddim <= 256)
{ {
return 1; return 1;
} }
...@@ -339,7 +340,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -339,7 +340,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// auto q_tile = q; // tile_elementwise_in(q_element_func, q); // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0; index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops); static_assert(1 <= k0_loops);
......
...@@ -41,7 +41,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -41,7 +41,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static constexpr index_t kK0 = BlockFmhaShape::kK0; static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
...@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
return Problem::kBlockPerCu; return Problem::kBlockPerCu;
else else
{ {
if constexpr(kK0BlockLength <= 32) if constexpr(kQKHeaddim <= 32)
{ {
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 64) else if constexpr(kQKHeaddim <= 64)
{ {
return 3; return 3;
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kQKHeaddim <= 128)
{ {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 256) else if constexpr(kQKHeaddim <= 256)
{ {
return 1; return 1;
} }
...@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
// prefetch K tile // prefetch K tile
index_t i_total_loops = 0; index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops); static_assert(2 <= k0_loops);
......
...@@ -41,7 +41,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -41,7 +41,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static constexpr index_t kK0 = BlockFmhaShape::kK0; static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
...@@ -56,22 +57,22 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -56,22 +57,22 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
return Problem::kBlockPerCu; return Problem::kBlockPerCu;
else else
{ {
if constexpr(kK0BlockLength <= 32) if constexpr(kQKHeaddim <= 32)
{ {
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 64) else if constexpr(kQKHeaddim <= 64)
{ {
return 3; return 3;
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kQKHeaddim <= 128)
{ {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 256) else if constexpr(kQKHeaddim <= 256)
{ {
return 1; return 1;
} }
...@@ -235,7 +236,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS ...@@ -235,7 +236,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
// prefetch K tile // prefetch K tile
index_t i_total_loops = 0; index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops); static_assert(2 <= k0_loops);
......
...@@ -55,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -55,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr index_t MWarp = config.template at<1>(); constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
...@@ -323,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -323,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template<> struct template<> struct
LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; }; LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
template<> struct template<> struct
LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;}; LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
// clang-format on // clang-format on
...@@ -335,9 +338,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -335,9 +338,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kN0 = BlockFmhaShape::kN0; constexpr index_t kN0 = BlockFmhaShape::kN0;
constexpr index_t kK0 = BlockFmhaShape::kK0; constexpr index_t kK0 = BlockFmhaShape::kK0;
constexpr index_t kK1 = BlockFmhaShape::kK1; constexpr index_t kK1 = BlockFmhaShape::kK1;
constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{}; return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
......
...@@ -7,6 +7,20 @@ ...@@ -7,6 +7,20 @@
namespace ck_tile { namespace ck_tile {
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
{
if(len == 96)
return 128;
if(len == 160)
return 256;
// only length of 96, 160 and power-of-two is supported
if(!(len & (len - 1)))
return len;
return 0;
};
template <typename BlockTile_, // sequence<... template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_, typename Gemm0BlockWarps_,
typename Gemm0WarpTile_, typename Gemm0WarpTile_,
...@@ -36,10 +50,12 @@ struct TileFmhaShape ...@@ -36,10 +50,12 @@ struct TileFmhaShape
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kK0BlockLength = static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile) // once (or repeately load Q as a whole tile)
static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0"); static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment