Commit 39ad271b authored by danyao12's avatar danyao12
Browse files

codegen update

parent 39fc3d4b
...@@ -14,15 +14,11 @@ from codegen.cpp_symbol_map import * ...@@ -14,15 +14,11 @@ from codegen.cpp_symbol_map import *
BWD_DQDKDV_PIPELINE_MAP = { BWD_DQDKDV_PIPELINE_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", "kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
} }
BWD_DQDKDV_PIPELINE_ENUM_MAP = { BWD_DQDKDV_PIPELINE_ENUM_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", "kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
} }
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
...@@ -34,39 +30,41 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT ...@@ -34,39 +30,41 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype}; using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; using fmha_block_tile_{F_idx} = ck_tile::
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape // TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP // G0&G2 -> GSdP
// G1&G3 -> GdKV // G1&G3 -> GdKV
// G4 -> GdQ // G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}, using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx}, fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx}, fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx}, fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx}, fmha_warp_tile1_{F_idx},
fmha_block_warps0_{F_idx}, fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx}, fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx}, fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx}, fmha_warp_tile1_{F_idx},
fmha_block_warps2_{F_idx}, fmha_block_warps2_{F_idx},
fmha_warp_tile_{F_idx}>; fmha_warp_tile0_{F_idx}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad}, {F_skpad},
{F_dpad}, {F_dpad},
{F_dvpad}, {F_dvpad},
{F_bias}, {F_bias},
{F_dbias}, {F_dbias},
false, false,
{F_dropout}, false,
false, {F_occupancy}>;
{F_occupancy}>; using fmha_mask_{F_idx} = {F_mask};
using fmha_mask_{F_idx} = {F_mask}; using fmha_dropout_{F_idx} = {F_dropout};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType, typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
...@@ -86,55 +84,73 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< ...@@ -86,55 +84,73 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType, typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx}, fmha_bwd_shape_{F_idx},
{F_mode}, {F_mode},
{F_deterministic},
fmha_mask_{F_idx}, fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_bwd_trait_{F_idx}>; fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_dk_epilogue_{F_idx} = using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType, ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false, false>>; false,
false>>;
using fmha_bwd_dv_epilogue_{F_idx} = using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType, ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false, false>>; false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} = using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>, ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdKTilePartitioner<{F_bn0}>,
fmha_bwd_pipeline_{F_idx}, fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx}, fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>; fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_pipeline_enum},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_bias},
{F_dbias},
{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_deterministic}>;
#include <iostream> #include <iostream>
template<> template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush; std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a); auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize(); constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)); return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}} }}
template<> template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a) void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{ {{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a); auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize(); constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}} }}
template<> template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>() std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{ {{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
...@@ -146,14 +162,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" ...@@ -146,14 +162,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API=""" FMHA_BWD_API="""
#include <iostream> #include <iostream>
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_> template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(s, return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }}, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }} [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
); );
}} }}
...@@ -173,38 +190,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < ...@@ -173,38 +190,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}} }}
""" """
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a); using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
""" """
@dataclass @dataclass
class FmhaBwdDQDKDVApiTrait: class FmhaBwdDQDKDVApiTrait:
pipeline : str pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls # sync with fmha_bwd_traits<>, to generate fallback calls
hdim : str hdim : str
dtype : str # data type dtype : str # data type
mode : str # value from MODE_MAP mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size) bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen bn0 : int # tile size along k seqlen
bhdq : int # q head_dim bhdq : int # q head_dim
bhdv : int # v head_dim bhdv : int # v head_dim
mask : str mask : str
bias : str bias : str
dbias : str dbias : str
dropout : str dropout : str
spad : str spad : str
skpad : str skpad : str
dpad : str dpad : str
dvpad : str dvpad : str
deterministic : str
@property
def name(self) -> str:
return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
def scheck(self, spad1 : str) -> str: def scheck(self, spad1 : str) -> str:
if self.mode == 'group': if self.mode == 'group':
...@@ -212,9 +227,9 @@ class FmhaBwdDQDKDVApiTrait: ...@@ -212,9 +227,9 @@ class FmhaBwdDQDKDVApiTrait:
elif self.spad == 't' and spad1 == 't': elif self.spad == 't' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} != 0' return f'a.seqlen_q % {self.bm0} != 0'
elif self.spad == 'f' and spad1 == 't': elif self.spad == 'f' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
else: # self.skpad == 'f' and skpad1 == 'f' else: # self.skpad == 'f' and skpad1 == 'f'
return f'a.seqlen_q % 256 == 0' # BlockSize return f'a.seqlen_q % 64 == 0'
@property @property
def skcheck(self) -> str: def skcheck(self) -> str:
...@@ -256,16 +271,21 @@ class FmhaBwdApiPool: ...@@ -256,16 +271,21 @@ class FmhaBwdApiPool:
per_hdim_case=str() per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits=self.dq_dk_dv_pool[dtype][hdim] traits=self.dq_dk_dv_pool[dtype][hdim]
hdim_int = int(hdim)
inners=str() inners=str()
for k, trait in enumerate(traits): for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if' if_k = 'if' if k == 0 else 'else if'
for spad1 in ["t", "f"]: for spad1 in ["t", "f"]:
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
continue
if (spad1 == "t" and trait.spad == "f" and hdim_int <= 64):
continue continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
...@@ -300,74 +320,82 @@ class FmhaBwdDQDKDVTileSize: ...@@ -300,74 +320,82 @@ class FmhaBwdDQDKDVTileSize:
F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
F_rk2 : int # number of warps along gemm-k (not used) in gemm4 F_rk2 : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m (warp size) F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
F_wn : int # warp size along n F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
F_wk : int # warp size along k F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
F_wm1 : int # warp size along m in gemm1/gemm3
F_wn1 : int # warp size along n in gemm1/gemm3
F_wk1 : int # warp size along k in gemm1/gemm3
F_occupancy : int # occupancy F_occupancy : int # occupancy
@property @property
def name(self) -> str: def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
@dataclass @dataclass
class FmhaBwdDQDKDVKernel: class FmhaBwdDQDKDVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim F_hdim : int # hdim
F_dtype : str # data type F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false F_spad : str # true/false
F_skpad : str # F_skpad : str #
F_dpad : str # F_dpad : str #
F_dvpad : str # F_dvpad : str #
F_bias : str # F_bias : str #
F_dbias : str # F_dbias : str #
F_dropout : str # F_dropout : str #
F_mask : str # value from MASK_MAP F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP F_mode : str # value from MODE_MAP
F_pipeline : str F_deterministic : str #
mask_impl : str F_pipeline : str #
mask_impl : str #
@property @property
def template(self) -> str: def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \ return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1, F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2, F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3, F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4, F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq, F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv, F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0, F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0, F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0, F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1, F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1, F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1, F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2, F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2, F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2, F_rk2 = self.F_tile.F_rk2,
F_wm = self.F_tile.F_wm, F_wm0 = self.F_tile.F_wm0,
F_wn = self.F_tile.F_wn, F_wn0 = self.F_tile.F_wn0,
F_wk = self.F_tile.F_wk, F_wk0 = self.F_tile.F_wk0,
F_spad = BOOL_MAP[self.F_spad], F_wm1 = self.F_tile.F_wm1,
F_skpad = BOOL_MAP[self.F_skpad], F_wn1 = self.F_tile.F_wn1,
F_dpad = BOOL_MAP[self.F_dpad], F_wk1 = self.F_tile.F_wk1,
F_dvpad = BOOL_MAP[self.F_dvpad], F_spad = BOOL_MAP[self.F_spad],
F_bias = BIAS_MAP[self.F_bias], F_skpad = BOOL_MAP[self.F_skpad],
F_dbias = BOOL_MAP[self.F_dbias], F_dpad = BOOL_MAP[self.F_dpad],
F_dropout = BOOL_MAP[self.F_dropout], F_dvpad = BOOL_MAP[self.F_dvpad],
F_occupancy = self.F_tile.F_occupancy, F_bias = BIAS_MAP[self.F_bias],
F_mask = get_mask_map(self.mask_impl)[self.F_mask], F_dbias = BOOL_MAP[self.F_dbias],
F_mode = MODE_MAP[self.F_mode], F_dropout = DROPOUT_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
@property @property
def name(self) -> str: def name(self) -> str:
...@@ -388,7 +416,8 @@ class FmhaBwdDQDKDVKernel: ...@@ -388,7 +416,8 @@ class FmhaBwdDQDKDVKernel:
if self.F_mask == 's_mask': n += f'_mask' if self.F_mask == 's_mask': n += f'_mask'
else: else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_dropout == 't' : n += '_dropout' if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
if self.F_deterministic == 't' : n += '_deterministic'
return n return n
@property @property
...@@ -411,19 +440,23 @@ class FmhaBwdDQDKDVKernel: ...@@ -411,19 +440,23 @@ class FmhaBwdDQDKDVKernel:
spad=self.F_spad, spad=self.F_spad,
skpad=self.F_skpad, skpad=self.F_skpad,
dpad=self.F_dpad, dpad=self.F_dpad,
dvpad=self.F_dvpad) dvpad=self.F_dvpad,
deterministic=self.F_deterministic
)
# TODO: design a more practical way to do it # TODO: design a more practical way to do it
# this is current supported tile size & pipeline. # this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
'32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), # '32' : [FmhaBwdDQDKDVTileSize( 64, 64, 32, 64, 32, 64, 64, 32, 32, 1, 2, 1, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, 1),
"qs_ks_vr_dos"], # "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1),
"qs_ks_vr_dos"], "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), # '128' : [FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 32, 32, 16, 32, 32, 16, 1),
"ks_vr"] # "kr_ktr_vr"],
# '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr"]
} }
else: else:
return None return None
...@@ -438,7 +471,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -438,7 +471,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
tile = d[hdim_str][0] tile = d[hdim_str][0]
ppl = d[hdim_str][1] ppl = d[hdim_str][1]
hdim = int(hdim_str) hdim = int(hdim_str)
...@@ -446,10 +479,12 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -446,10 +479,12 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue continue
if ((bias == "no" or bias == "alibi") and dbias == "t"): if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue continue
if ((hdim <= 128 and ("wg16" in dropout)) or (hdim == 256 and ("wg32" in dropout))):
continue
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
F_pipeline=ppl, mask_impl=mask_impl) F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
if kernel_filter != None: if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter): if not fnmatch.fnmatch(k.name, kernel_filter):
continue continue
...@@ -466,53 +501,55 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -466,53 +501,55 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype}; using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, using fmha_bwd_dot_do_o_trait_{F_idx} =
{F_dvpad}, ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
{F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType, typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType, typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType, typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = */ 256, /* BlockSize = */ 64,
{F_hdim}, {F_hdim},
{F_mode}, {F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>; fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< using fmha_bwd_dot_do_o_{F_idx} =
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} = using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>, ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdQTilePartitioner</* BlockSize = */ 64>,
fmha_bwd_dot_do_o_{F_idx}>; fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream> #include <iostream>
template<> template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush; std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a); auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize(); constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)); return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}} }}
template<> template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a) void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a); auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize(); constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}} }}
template<> template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>() std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{ {{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
...@@ -582,12 +619,171 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: ...@@ -582,12 +619,171 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
return gen return gen
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_hdim}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_bwd_convert_dq_shape_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx}>;
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
fmha_bwd_convert_dq_shape_{F_idx},
fmha_bwd_convert_dq_trait_{F_idx},
{F_mode},
{F_deterministic}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<ck_tile::FmhaBwdQTilePartitioner<{F_bm0}>,
fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass
class FmhaBwdConvertQGradKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_rm : int # number of warps along k seqlen (block warps) in gemm4
F_rn : int # number of warps along q seqlen (block warps) in gemm4
F_rk : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m in gemm4
F_wn : int # warp size along n in gemm4
F_wk : int # warp size along k in gemm4
F_spad : str # true/false
F_dpad : str #
F_mode : str # value from MODE_MAP
F_occupancy : int #
F_deterministic : str #
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_rm = self.F_rm,
F_rn = self.F_rn,
F_rk = self.F_rk,
F_wm = self.F_wm,
F_wn = self.F_wn,
F_wk = self.F_wk,
F_spad = BOOL_MAP[self.F_spad],
F_dpad = BOOL_MAP[self.F_dpad],
F_mode = MODE_MAP[self.F_mode],
F_occupancy = self.F_occupancy,
F_deterministic = BOOL_MAP[self.F_deterministic])
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_dpad == 't' : n += 'd'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_r{self.F_rm}x{self.F_rn}x{self.F_rk}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
if self.F_deterministic == 't' : n += f'_deterministic'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
gen = list()
for dtype in DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
hdim = int(hdim_str)
tile = d[hdim_str][0]
if (mode == "group" and spad == "f"):
continue
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
F_rm=tile.F_rm2, F_rn=tile.F_rn2, F_rk=tile.F_rk2, F_wm=tile.F_wm0, F_wn=tile.F_wn0, F_wk=tile.F_wk0,
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
gen.append(k)
return gen
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template) (autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template) (autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
...@@ -595,6 +791,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_ ...@@ -595,6 +791,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
kernels = get_bwd_dot_do_o_blobs() kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels: for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir) write_single_bwd_dot_do_o_kernel(kernel, output_dir)
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels:
write_single_bwd_convert_dq_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels: for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
...@@ -603,6 +802,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_ ...@@ -603,6 +802,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
with file_path.open('a') as f: with file_path.open('a') as f:
kernels = get_bwd_dot_do_o_blobs() kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels: for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment