Commit a4522ae3 authored by illsilin's avatar illsilin
Browse files

sync from public repo

parents 1f127242 e0594d08
...@@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx950 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
endif() endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) add_executable(${EXAMPLE_NAME} ${FILE_NAME})
...@@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
endif() endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) add_executable(${EXAMPLE_NAME} ${FILE_NAME})
......
...@@ -21,6 +21,14 @@ DTYPE_BITS = { ...@@ -21,6 +21,14 @@ DTYPE_BITS = {
"bf8" : 8 "bf8" : 8
} }
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
TILE_PARTITIONER_MAP = { TILE_PARTITIONER_MAP = {
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
...@@ -35,14 +43,13 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT ...@@ -35,14 +43,13 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
FMHA_FWD_KERNEL_BODY=""" FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype}; using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx}, using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx}, ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile_{F_idx}, fmha_warp_tile_{F_idx},
fmha_block_warps_{F_idx}, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile_{F_idx}, fmha_warp_tile_{F_idx},
{F_vlayout}>; {F_vlayout}>;
...@@ -88,7 +95,7 @@ using fmha_kernel_{F_idx} = ...@@ -88,7 +95,7 @@ using fmha_kernel_{F_idx} =
fmha_pipeline_{F_idx}, fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>; fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream> #include <iostream>
...@@ -126,7 +133,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < ...@@ -126,7 +133,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a); return fmha_fwd_<trait_>(s, a);
}} }}
""" """
...@@ -143,7 +150,7 @@ class FmhaFwdApiTrait: ...@@ -143,7 +150,7 @@ class FmhaFwdApiTrait:
bk0 : int # tile size along qk gemm unroll bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll bk1 : int # tile size along kv gemm unroll
bk0blen : int bk0max : int
vlayout : str vlayout : str
mask : str mask : str
bias : str # bias : str #
...@@ -157,7 +164,7 @@ class FmhaFwdApiTrait: ...@@ -157,7 +164,7 @@ class FmhaFwdApiTrait:
@property @property
def name(self) -> str: def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
@property @property
...@@ -189,8 +196,9 @@ class FmhaFwdApiTrait: ...@@ -189,8 +196,9 @@ class FmhaFwdApiTrait:
if self.dpad == 't': return f'a.hdim_q % {vec} == 0' if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_q % {self.bk0blen} == 0' if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False else: assert False
@property @property
...@@ -200,8 +208,9 @@ class FmhaFwdApiTrait: ...@@ -200,8 +208,9 @@ class FmhaFwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_v % {self.bk0blen} == 0' if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False else: assert False
@dataclass @dataclass
...@@ -272,7 +281,7 @@ class FmhaFwdApiPool: ...@@ -272,7 +281,7 @@ class FmhaFwdApiPool:
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
...@@ -290,19 +299,22 @@ class FmhaFwdTileSize: ...@@ -290,19 +299,22 @@ class FmhaFwdTileSize:
F_bk0 : int # tile size along qk gemm unroll F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll F_bk1 : int # tile size along kv gemm unroll
F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm : int # number of warps along q seqlen (block warps) F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn : int # number of warps along k seqlen(not used) F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk : int # number of warps along gemm-k(not used) F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm : int # warp size along m (warp size) F_wm : int # warp size along m (warp size)
F_wn : int # warp size along n F_wn : int # warp size along n
F_wk : int # warp size along k F_wk : int # warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property @property
def name(self) -> str: def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass @dataclass
class FmhaFwdKernel: class FmhaFwdKernel:
...@@ -333,10 +345,13 @@ class FmhaFwdKernel: ...@@ -333,10 +345,13 @@ class FmhaFwdKernel:
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1, F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1, F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen, F_bk0max = self.F_tile.F_bk0max,
F_rm = self.F_tile.F_rm, F_rm0 = self.F_tile.F_rm0,
F_rn = self.F_tile.F_rn, F_rn0 = self.F_tile.F_rn0,
F_rk = self.F_tile.F_rk, F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm = self.F_tile.F_wm, F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn, F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk, F_wk = self.F_tile.F_wk,
...@@ -377,7 +392,7 @@ class FmhaFwdKernel: ...@@ -377,7 +392,7 @@ class FmhaFwdKernel:
bk0=self.F_tile.F_bk0, bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1, bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1, bk1=self.F_tile.F_bk1,
bk0blen=self.F_tile.F_bk0blen, bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout, vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask, mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias, bias=self.F_pipeline.F_bias,
...@@ -394,16 +409,17 @@ class FmhaFwdKernel: ...@@ -394,16 +409,17 @@ class FmhaFwdKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
} }
else: else:
return None return None
...@@ -505,4 +521,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im ...@@ -505,4 +521,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels: for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
\ No newline at end of file
...@@ -29,6 +29,14 @@ DTYPE_BITS = { ...@@ -29,6 +29,14 @@ DTYPE_BITS = {
"bf8" : 8 "bf8" : 8
} }
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
FMHA_FWD_SPLITKV_PIPELINE_MAP = { FMHA_FWD_SPLITKV_PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
...@@ -41,14 +49,13 @@ using fmha_mask_{F_idx} = {F_mask}; ...@@ -41,14 +49,13 @@ using fmha_mask_{F_idx} = {F_mask};
namespace {{ namespace {{
template <bool kHasUnevenSplits> template <bool kHasUnevenSplits>
struct kernel_runner {{ struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile, using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
fmha_block_warps, ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile, fmha_warp_tile,
fmha_block_warps, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile, fmha_warp_tile,
{F_vlayout}>; {F_vlayout}>;
...@@ -104,7 +111,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) ...@@ -104,7 +111,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}}; }};
}} }}
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>; {F_dvpad}>;
...@@ -162,10 +169,12 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< ...@@ -162,10 +169,12 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
fmha_pipeline_problem>; fmha_pipeline_problem>;
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
/// store_tile_raw() data corruption issue
using fmha_epilogue = using fmha_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>; false, false>>;
using fmha_kernel = using fmha_kernel =
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>, ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>,
...@@ -191,7 +200,9 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m ...@@ -191,7 +200,9 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
template<> template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
if (a.num_splits <= 16) {{ if (a.num_splits <= 8) {{
kernel_runner<3>::run(s, a);
}} else if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a); kernel_runner<4>::run(s, a);
}} else if (a.num_splits <= 32) {{ }} else if (a.num_splits <= 32) {{
kernel_runner<5>::run(s, a); kernel_runner<5>::run(s, a);
...@@ -238,8 +249,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ...@@ -238,8 +249,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}} }}
...@@ -257,7 +268,7 @@ class FmhaFwdSplitKVApiTrait: ...@@ -257,7 +268,7 @@ class FmhaFwdSplitKVApiTrait:
bk0 : int # tile size along qk gemm unroll bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll bk1 : int # tile size along kv gemm unroll
bk0blen : int bk0max : int
vlayout : str vlayout : str
mask : str mask : str
bias : str # bias : str #
...@@ -267,11 +278,11 @@ class FmhaFwdSplitKVApiTrait: ...@@ -267,11 +278,11 @@ class FmhaFwdSplitKVApiTrait:
skpad : str skpad : str
dpad : str dpad : str
dvpad : str dvpad : str
pagedkv : str pagedkv : str
@property @property
def name(self) -> str: def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
f'{self.dvpad}-{self.pagedkv}' f'{self.dvpad}-{self.pagedkv}'
...@@ -304,8 +315,9 @@ class FmhaFwdSplitKVApiTrait: ...@@ -304,8 +315,9 @@ class FmhaFwdSplitKVApiTrait:
if self.dpad == 't': return f'a.hdim_q % {vec} == 0' if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_q % {self.bk0blen} == 0' if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False else: assert False
@property @property
...@@ -315,8 +327,9 @@ class FmhaFwdSplitKVApiTrait: ...@@ -315,8 +327,9 @@ class FmhaFwdSplitKVApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False else : assert False
elif self.pipeline_tag in ['qr']: elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
else : return f'a.hdim_v % {self.bk0blen} == 0' if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False else: assert False
@dataclass @dataclass
...@@ -411,7 +424,7 @@ class FmhaFwdSplitKVApiPool: ...@@ -411,7 +424,7 @@ class FmhaFwdSplitKVApiPool:
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
...@@ -455,10 +468,13 @@ class FmhaFwdSplitKVKernel: ...@@ -455,10 +468,13 @@ class FmhaFwdSplitKVKernel:
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1, F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1, F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen, F_bk0max = self.F_tile.F_bk0max,
F_rm = self.F_tile.F_rm, F_rm0 = self.F_tile.F_rm0,
F_rn = self.F_tile.F_rn, F_rn0 = self.F_tile.F_rn0,
F_rk = self.F_tile.F_rk, F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm = self.F_tile.F_wm, F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn, F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk, F_wk = self.F_tile.F_wk,
...@@ -498,7 +514,7 @@ class FmhaFwdSplitKVKernel: ...@@ -498,7 +514,7 @@ class FmhaFwdSplitKVKernel:
bk0=self.F_tile.F_bk0, bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1, bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1, bk1=self.F_tile.F_bk1,
bk0blen=self.F_tile.F_bk0blen, bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout, vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask, mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias, bias=self.F_pipeline.F_bias,
...@@ -551,16 +567,17 @@ class FmhaFwdSplitKVCombineKernel: ...@@ -551,16 +567,17 @@ class FmhaFwdSplitKVCombineKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), ## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
} }
else: else:
return None return None
...@@ -568,16 +585,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -568,16 +585,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1),
'64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), ## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1),
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), '64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
} }
else: else:
return None return None
...@@ -598,7 +616,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -598,7 +616,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True: # if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
...@@ -737,4 +755,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im ...@@ -737,4 +755,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels: for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
\ No newline at end of file
...@@ -557,33 +557,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -557,33 +557,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
#endif #endif
struct static const auto get_lengths = [](bool permute,
{ ck_tile::index_t b /*batch*/,
auto operator()(bool permute, ck_tile::index_t h /*nhead*/,
ck_tile::index_t b /*batch*/, ck_tile::index_t s /*seqlen*/,
ck_tile::index_t h /*nhead*/, ck_tile::index_t d /*hdim*/) {
ck_tile::index_t s /*seqlen*/, if(permute)
ck_tile::index_t d /*hdim*/) return std::array<ck_tile::index_t, 4>{b, h, s, d};
{ else
if(permute) return std::array<ck_tile::index_t, 4>{b, s, h, d};
return std::array<ck_tile::index_t, 4>{b, h, s, d}; };
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
}
auto operator()(bool permute,
ck_tile::index_t ns /*num_splits*/,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/)
{
if(permute)
return std::array<ck_tile::index_t, 5>{ns, b, h, s, d};
else
return std::array<ck_tile::index_t, 5>{ns, b, s, h, d};
}
} get_lengths;
bool is_v_rowmajor = vlayout == std::string("r"); bool is_v_rowmajor = vlayout == std::string("r");
...@@ -635,12 +618,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -635,12 +618,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<LSEDataType> lse_acc_host( ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || use_kvcache 1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q} ? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1}); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host( ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_kvcache 1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch,
? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v) nhead,
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1}); num_splits,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q] // batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q]
...@@ -880,7 +866,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -880,7 +866,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}(); }();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o_acc = (hdim_v);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
...@@ -906,8 +892,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -906,8 +892,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments // setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
...@@ -922,13 +908,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -922,13 +908,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
// setup split_stride_* arguments (only used in split-kv kernel) // setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
args.q_ptr = q_buf.GetDeviceBuffer(); args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer();
......
...@@ -47,6 +47,9 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter ...@@ -47,6 +47,9 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
assert output_file is not None assert output_file is not None
file_path = Path(output_file) file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api in api_list: for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS] handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl) handler(file_path, kernel_filter, receipt, mask_impl)
......
# not using add_example_executable() to add this target, since we don't want this to have set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
# to be included in "make all/install/check" set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD) if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all")
\ No newline at end of file set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS})
endif()
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
)
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
# Layernorm2D forward # Layernorm2D forward
This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation. This folder contains example for Layernorm2D forward using `ck_tile` tile-programming implementation.
# Implementation and feature support
## welford online algorithm
We use welfold algorithm to update `mean`/`variance` block by block. For `N <=4096` case we can compute `mean`/`var`/`normalization` within one loop, we call it `one-pass`. For large N case, it is hard to keep `mean`/`var` inside register/LDS and then computation `normalization`, so we need to load input twice, first time to compute `mean`/`var` block-by-block, then load input another time to compute the `normalization`. We call it `two-pass`.
## mean/variance save
In training case the mean/variance need to store out (TBD, not supported yet)
## prenorm/postnorm
![](misc/pnorm.png)
since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out (not codegen by default).
## smooth-quant/dynamic-quant
we support smooth/dynamic quantization for `int8` output, by setting `-fquant=1` and `-prec_o=int8`. In this case the output will doing a rowwise dynamic quantization like below. Note that smooth-quant require input a `(1*N)` size per-channel scale(in fp32 in our example, though this is customizable), then elememt-wise multiply the tensor for each row, then compute the rowwise dynamic quant. if set `-fquant=2` will have the input per-channel scale stage, only the dynamic quant. This case is supported in our kernel but by default not generated (TBD: add some filter in generate.py support on-demand codegen)
![](misc/dquant.png)
```
# assume output int8, hidden_states is [m, n] shape and in fp16/bf16
# [m, 1]
per_token_amax, _ = torch.max(
input=torch.abs(hidden_states),
dim=-1,
keepdim=True
)
per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0
# quant hidden_states
hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8)
return hidden_states, per_token_scale
# hidden_states now is int8 will feed to next layer as intput
# per_token_scale will be used as dequant factor later layer
```
## build ## build
``` ```
# in the root of ck_tile # in the root of ck_tile
mkdir build && cd build mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_layernorm2d_fwd -j make tile_example_layernorm2d_fwd -j
``` ```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd` This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
...@@ -16,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd` ...@@ -16,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
``` ```
args: args:
-m m dimension (default:3328) -m m dimension (default:3328)
-n m dimension (default:4096) -n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5) -e epsilon (default:1e-5)
-save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0)
-v cpu validation or not (default:1) -v cpu validation or not (default:1)
-prec precision (default:fp16) -kname print kernel name or not (default:1)
``` -prec_i input precision (default:fp16)
\ No newline at end of file -prec_o output precision, set auto will be the same as input (default:auto)
-prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
```
## limitations
Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose.
```
# some case
# standard fp16 layernorm 2d, m=10. n=1024
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1
```
This diff is collapsed.
...@@ -8,23 +8,57 @@ ...@@ -8,23 +8,57 @@
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/layernorm2d.hpp"
#include <string> #include <string>
struct layernorm2d_fwd_traits template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig;
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_>
{
using XDataType = ck_tile::half_t;
using YDataType = OutType;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
using ComputeDataType = float;
using XScaleDataType = XScaleDataType_;
using YScaleDataType = YScaleDataType_;
};
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_>
{ {
std::string data_type; using XDataType = ck_tile::bf16_t;
using YDataType = OutType;
using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float;
using XScaleDataType = XScaleDataType_;
using YScaleDataType = YScaleDataType_;
}; };
struct layernorm2d_fwd_args // runtime args
struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
{ {
const void* p_x;
const void* p_gamma;
const void* p_beta;
void* p_y;
void* p_mean;
void* p_invStd;
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
}; };
// host API // This is the public API, will be generated by script
struct layernorm2d_fwd_traits
{
std::string prec_i; // input precision
std::string prec_o; // output precision
// if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
// can set arbitrary(will skip check)
std::string prec_sx; // x-scale, used for [1*N] input smooth quant
std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_mean_var; //
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
#!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
\ No newline at end of file
#!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
set(CMAKE_BUILD_TYPE Debug) add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
\ No newline at end of file
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_basic.hpp"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <cstring> #include <cstring>
...@@ -11,51 +9,48 @@ ...@@ -11,51 +9,48 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
auto create_args(int argc, char* argv[]) #include "ck_tile/ops/epilogue.hpp"
{ #include "ck_tile/ops/gemm.hpp"
ck_tile::ArgParser arg_parser; #include "ck_tile/host.hpp"
arg_parser.insert("b", "1", "batch size") #include "gemm_basic.hpp"
.insert("m", "1024", "m dimension")
.insert("n", "2048", "n dimension")
.insert("k", "64", "k dimension")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("e", "1e-5", "Absolute error tolerance")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename LayoutA, template <typename ALayout, typename BLayout, typename CLayout>
typename LayoutB,
typename LayoutC,
typename PipelineProblem,
typename GemmPipeline,
typename GemmShape>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadA = true;
constexpr bool kPadB = true; constexpr bool kPadB = true;
constexpr bool kPadC = true;
constexpr bool kTilePermute = false; constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; // This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
// The rank and permutation will also be generate out by the CodeGen part. constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t kOutputRank = 2; constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output // Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout. // layout.
constexpr bool CShuffleEpilogue = constexpr bool CShuffleEpilogue =
std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>; std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t< using GemmEpilogue = std::conditional_t<
CShuffleEpilogue, CShuffleEpilogue,
...@@ -71,14 +66,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -71,14 +66,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
TilePartitioner::kN>>, TilePartitioner::kN>>,
ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>;
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b, args.p_b,
args.p_c, args.p_c,
args.epsilon,
args.M, args.M,
args.N, args.N,
args.K, args.K,
...@@ -89,295 +91,20 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -89,295 +91,20 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel( if(s.log_level_ > 0)
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
template <typename DataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename PipelineProblem,
typename GemmPipeline,
typename GemmShape>
float invoke_gemm(ck_tile::DeviceMem& a_buf,
ck_tile::DeviceMem& b_buf,
ck_tile::DeviceMem& c_buf,
const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
if(data_type != DataTypeTraits<DataType>::name)
{ {
std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got " std::cout << "Launching kernel with args:"
<< data_type << std::endl; << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
return -1; // Or handle the error appropriately << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
} }
float epsilon = arg_parser.get_float("e"); float ave_time = ck_tile::launch_kernel(
ck_tile::index_t batch_size = arg_parser.get_int("b"); s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
gemm_basic_args args;
args.p_a = a_buf.GetDeviceBuffer();
args.p_b = b_buf.GetDeviceBuffer();
args.p_c = c_buf.GetDeviceBuffer();
args.epsilon = epsilon;
args.kbatch = batch_size;
args.M = M;
args.N = N;
args.K = K;
// Only set stride_M and stride_N if they are non-zero and not equal to K.
if(stride_a != 0)
{
args.stride_A = stride_a;
}
else
{
args.stride_A = [&]() {
if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
return M;
}
else
{
return K;
}
}();
}
if(stride_b != 0)
{
args.stride_B = stride_b;
}
else
{
args.stride_B = [&]() {
if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>)
{
return N;
}
else
{
return K;
}
}();
}
if(stride_c != 0)
{
args.stride_C = stride_c;
}
else
{
args.stride_C = [&]() {
if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
return M;
}
else
{
return N;
}
}();
}
float ave_time = gemm_calc<LayoutA, LayoutB, LayoutC, PipelineProblem, GemmPipeline, GemmShape>(
args, ck_tile::stream_config{nullptr, true});
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "The overall perfomance of the GEMM with "
<< "[" << data_type << "]"
<< "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
<< " is: \n";
std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
<< std::flush;
return ave_time; return ave_time;
} }
int main(int argc, char* argv[]) #include "run_gemm_example.inc"
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
// The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
// host verify
std::vector<int> a_dimensions =
(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
? std::vector<int>{M, K}
: std::vector<int>{K, M};
std::vector<int> b_dimensions =
(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
? std::vector<int>{N, K}
: std::vector<int>{K, N};
std::vector<int> c_dimensions =
(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
? std::vector<int>{M, N}
: std::vector<int>{N, M};
ck_tile::HostTensor<ADataType> a_host(a_dimensions);
ck_tile::HostTensor<BDataType> b_host(b_dimensions);
ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
a_buf.ToDevice(a_host.data());
b_buf.ToDevice(b_host.data());
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = true;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using CodegenGemmTraits = ck_tile::
TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
invoke_gemm<ck_tile::half_t,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout,
CodegenPipelineProblem,
CodegenGemmPipeline,
CodegenGemmShape>(a_buf, b_buf, c_buf, arg_parser);
c_buf.FromDevice(c_host_dev.data());
bool pass_cpu = true;
if(arg_parser.get_int("v") == 1)
{
// ToDo: Will Add the Element Op (bias) verification in the future.
ck_tile::reference_gemm<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(a_host, b_host, c_host_ref);
pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref);
std::cout << "The CPU veification result is:" << (pass_cpu ? "correct" : "fail")
<< std::flush;
}
bool pass_gpu = true; int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
if(arg_parser.get_int("v") == 2)
{
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
if(stride_a == 0)
{
if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_a = M;
}
else
{
stride_a = K;
}
}
if(stride_b == 0)
{
if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_b = N;
}
else
{
stride_b = K;
}
}
if(stride_c == 0)
{
if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_c = M;
}
else
{
stride_c = N;
}
}
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
c_buf.FromDevice(c_host_gpu_ref.data());
pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
<< std::flush;
}
std::cout << std::endl << std::flush;
return !pass_gpu;
}
...@@ -4,12 +4,10 @@ ...@@ -4,12 +4,10 @@
#pragma once #pragma once
#include <string>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include <string>
template <typename DataType> template <typename DataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
...@@ -20,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t> ...@@ -20,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
using ADataType = ck_tile::half_t; using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t; using BDataType = ck_tile::half_t;
using AccDataType = float; using AccDataType = float;
using CDataType = ck_tile::half_t; // type convert using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM. // ToDo: Add more bias config to support different categories of GEMM.
}; };
...@@ -58,7 +56,6 @@ struct gemm_basic_args ...@@ -58,7 +56,6 @@ struct gemm_basic_args
const void* p_a; const void* p_a;
const void* p_b; const void* p_b;
void* p_c; void* p_c;
float epsilon;
ck_tile::index_t kbatch; ck_tile::index_t kbatch;
ck_tile::index_t M; ck_tile::index_t M;
ck_tile::index_t N; ck_tile::index_t N;
...@@ -68,5 +65,28 @@ struct gemm_basic_args ...@@ -68,5 +65,28 @@ struct gemm_basic_args
ck_tile::index_t stride_C; ck_tile::index_t stride_C;
}; };
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API // host API
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
// ToDo: This will be modified by the codegen code later.
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = true;
constexpr int kBlockPerCu = 1;
// ===============================================
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
if(has_hot_loop)
{
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
}
else
{
// Tail number always Full - #PrefetchStages
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
return ave_time;
}
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
This diff is collapsed.
set(EXAMPLE_REDUCE "tile_example_reduce")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_REDUCE}")
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment