Commit 86923c19 authored by Jim's avatar Jim
Browse files

update

parent 8b745f2c
...@@ -127,8 +127,41 @@ BOOL_MAP = { ...@@ -127,8 +127,41 @@ BOOL_MAP = {
"f" : "false" "f" : "false"
} }
BWD_V3_HDIM_MAP = {
"64": "64",
"128": "128"
}
BF16_CVT_MAP = { BF16_CVT_MAP = {
0 : "rtne", 0 : "rtne",
1 : "rtna", 1 : "rtna",
2 : "rtz", 2 : "rtz",
}
BWD_V3_MASK_MAP = {
"t": "((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0)))",
"f": "(t.mask_type == mask_enum::no_mask)"
}
BWD_V3_ATOMIC32_MAP = {
"t": "((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/))",
"f": "(t.is_v3_atomic_fp32 == false)"
}
BWD_V3_HDIM_CASE_MAP = {
0: "(a.hdim_q == 128)",
1: "(a.hdim_q == 64)",
2: "((a.hdim_q > 64) && (a.hdim_q < 128))"
}
BWD_V3_HDIM_CASE_CHECK_MAP = {
0: 128,
1: 64,
2: 128
}
BWD_V3_PADDING_CHECK_MAP = {
0: "false",
1: "false",
2: "true"
} }
\ No newline at end of file
...@@ -162,9 +162,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>() ...@@ -162,9 +162,8 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_V3_TEMPLATE=""" FMHA_BWD_V3_TEMPLATE="""template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_{F_hdim_name}_{F_dtype_name}_{F_causal_name}_{F_atomic_name}_{F_bf16_cvt_name}"; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_{F_hdim_name}_{F_dtype_name}_{F_causal_name}_{F_atomic_name}_{F_bf16_cvt_name}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
""" """
...@@ -589,7 +588,7 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -589,7 +588,7 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1; float r = -1;
if ((t.uses_bwd_v3 == true){{ if (t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) && if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) &&
(a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && (a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
(a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
...@@ -623,38 +622,40 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) ...@@ -623,38 +622,40 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
}} }}
""" """
FMHA_V3_DISPATCH=""" FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
if (t.mask_type == mask_enum::no_mask){{ using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; return r;"""
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
return r; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
}} r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
else if (t.is_v3_atomic_fp32 == false){{ return r;"""
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; FMHA_BWD_V3_PER_DTYPE_CASE=""" {F_if} (t.data_type.compare(\"{F_dtype}\") == 0) {{
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); {per_bf16_cvt_dispatch}
return r;
}}
}} }}
else if ((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ """
if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; FMHA_BWD_V3_PER_BF16_CVT_CASE=""" {F_if} (t.how_v3_bf16_cvt == {F_bf16_cvt}) {{
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; {per_mask_dispatch}
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if (t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}} }}
}} """
}}
FMHA_BWD_V3_PER_MASK_CASE=""" {F_if} {F_mask_expression}{{
{per_atomic_dispatch}
}}
"""
FMHA_BWD_V3_PER_ATOMIC_CASE=""" {F_if} {F_atomic_expression}{{
{per_hdim_dispatch}
}}
"""
FMHA_BWD_V3_PER_HDIM_CASE=""" {F_if} {F_hdim_expression}{{
{inner_dispatch}
}}
""" """
@dataclass @dataclass
...@@ -714,11 +715,18 @@ class FmhaBwdV3DQDKDVApiTrait: ...@@ -714,11 +715,18 @@ class FmhaBwdV3DQDKDVApiTrait:
is_causal : str is_causal : str
is_atomic : str is_atomic : str
bf16_cvt : int bf16_cvt : int
hd_pad : str is_hdpad : str
def remap_hdim(self):
hdim_int = int(self.hdim)
if hdim_int > 64:
self.hdim = 128
hdim_int = (hdim_int + 64 - 1) / 64 * 64
class FmhaBwdApiPool: class FmhaBwdApiPool:
def __init__(self, mask_impl): def __init__(self, mask_impl):
self.dq_dk_dv_pool = dict() self.dq_dk_dv_pool = dict()
self.dq_dk_dv_v3_pool = dict()
self.mask_impl = mask_impl self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
...@@ -730,6 +738,15 @@ class FmhaBwdApiPool: ...@@ -730,6 +738,15 @@ class FmhaBwdApiPool:
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
def register_dq_dk_dv_v3_traits(self, trait : FmhaBwdV3DQDKDVApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.dq_dk_dv_v3_pool.keys():
self.dq_dk_dv_v3_pool[trait.dtype] = dict()
if trait.hdim not in self.dq_dk_dv_v3_pool[trait.dtype].keys():
self.dq_dk_dv_v3_pool[trait.dtype][trait.hdim] = list()
self.dq_dk_dv_v3_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property @property
def api(self) -> str: def api(self) -> str:
per_dtypes=str() per_dtypes=str()
...@@ -758,28 +775,51 @@ class FmhaBwdApiPool: ...@@ -758,28 +775,51 @@ class FmhaBwdApiPool:
# empty string we add some ignore to suppress warning in api # empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;' per_dtypes += ' (void)t ; (void)s ; (void)a;'
v3_code = str()
gen_template = str() gen_template = str()
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()):
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): for j, hdim in enumerate(BWD_V3_HDIM_MAP.keys()):
traits = self.dq_dk_dv_pool[dtype][hdim] traits = self.dq_dk_dv_v3_pool[dtype][hdim]
hdim_int = int(hdim) hdim = int(hdim)
hdim_int = (hdim_int + 64 - 1) / 64 * 64
Ts_qo = 32 if hdim == 64 else 16 Ts_qo = 32 if hdim == 64 else 16
for k, trait in enumerate(traits): for k, trait in enumerate(traits):
padding = "t" if hdim_int % 64 == 0 else "f" hdim_name = "_hd64" if hdim == 64 else ""
padding_suffix = "_hdp" if padding == "t" else "" dtype_name = "_{}".format(dtype)
v3_code = v3_code + FMHA_V3_DISPATCH.format(F_hdim=hdim_int, F_dtype=BWD_DTYPE_MAP[dtype], F_padding=BOOL_MAP[padding], causal_name = "_causal" if trait.is_causal == "t" else ""
F_is_atomic32=BOOL_MAP[trait.is_atomic], F_how_v3_bf16_cvt=trait.bf16_cvt, F_padding_suffix=padding_suffix) atomic_name = "_a32" if trait.is_atomic == "t" else "_a16"
bf16_cvt_name = "_{}".format(BF16_CVT_MAP[trait.bf16_cvt])
hdim_name = "hd64" if hdim_int == 64 else "" gen_template = gen_template + FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_atomic=BOOL_MAP[trait.is_atomic],
dtype_name = dtype F_is_causal=BOOL_MAP[trait.is_causal], F_bf16_cvt=trait.bf16_cvt, F_hdpad=BOOL_MAP[trait.is_hdpad], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name,
causal_name = "causal" if trait.is_causal == "t" else ""
atomic_name = "a32" if trait.is_atomic == "t" else "a16"
bf16_cvt_name = BF16_CVT_MAP[trait.bf16_cvt]
gen_template = gen_template + FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim_int, F_dtype=BWD_DTYPE_MAP[dtype], F_is_atomic=BOOL_MAP[trait.is_atomic],
F_is_causal=BOOL_MAP[trait.is_causal], F_bf16_cvt=trait.bf16_cvt, F_hdpad=BOOL_MAP[padding], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name,
F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name) F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name)
v3_code = str()
for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()):
per_bf16 = str()
for j, bf16_cvt in enumerate([0, 1, 2]):
per_mask = str()
if (dtype == "fp16") and (bf16_cvt in [1, 2]):
continue
for k, is_causal in enumerate(["t", "f"]):
per_atomic = str()
for l, is_atomic in enumerate(["t", "f"]):
per_hdim = str()
for m, hdim in enumerate(BWD_V3_HDIM_CASE_CHECK_MAP.values()):
if_m = 'if' if m == 0 else 'else if'
inners = str()
if is_atomic == "t":
inners = FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt, F_padding=BWD_V3_PADDING_CHECK_MAP[m])
else:
inners = FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt, F_padding=BWD_V3_PADDING_CHECK_MAP[m])
per_hdim = per_hdim + FMHA_BWD_V3_PER_HDIM_CASE.format(F_if=if_m, F_hdim_expression=BWD_V3_HDIM_CASE_MAP[m], inner_dispatch=inners)
if_l = 'if' if l == 0 else 'else if'
per_atomic = per_atomic + FMHA_BWD_V3_PER_ATOMIC_CASE.format(F_if=if_l, F_atomic_expression=BWD_V3_ATOMIC32_MAP[is_atomic], per_hdim_dispatch=per_hdim)
if_k = 'if' if k == 0 else 'else if'
per_mask = per_mask + FMHA_BWD_V3_PER_MASK_CASE.format(F_if=if_k, F_mask_expression=BWD_V3_MASK_MAP[is_causal], per_atomic_dispatch=per_atomic)
if_j = 'if' if j == 0 else 'else if'
per_bf16 = per_bf16 + FMHA_BWD_V3_PER_BF16_CVT_CASE.format(F_if=if_j, F_bf16_cvt=bf16_cvt, per_mask_dispatch=per_mask)
if_i = 'if' if i == 0 else 'else if'
v3_code = v3_code + FMHA_BWD_V3_PER_DTYPE_CASE.format(F_if=if_i, F_dtype=dtype, per_bf16_cvt_dispatch=per_bf16)
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_template = gen_template, F_v3_dispatch = v3_code) return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_template = gen_template, F_v3_dispatch = v3_code)
# GEMM0: Q@K=S^T # GEMM0: Q@K=S^T
...@@ -931,6 +971,40 @@ class FmhaBwdDQDKDVKernel: ...@@ -931,6 +971,40 @@ class FmhaBwdDQDKDVKernel:
dvpad=self.F_dvpad, dvpad=self.F_dvpad,
deterministic=self.F_deterministic deterministic=self.F_deterministic
) )
@dataclass
class FmhaBwdV3DQDKDVKernel:
F_hdim : int # hdim
F_dtype : str # data type
F_is_causal : str
F_is_atomic : str
F_bf16_cvt : int
F_is_hdpad : str
# @property
# def gen_bwd_v3_template(self) -> str:
# hdim_template = 64 if self.F_hdim == 64 else 128
# Ts_qo = 32 if hdim_template == 64 else 16
# padding = "t" if self.F_hdim % 64 == 0 else "f"
# padding_suffix = "_hdp" if padding == "t" else ""
# hdim_name = "_hd64" if hdim_template == 64 else ""
# dtype_name = "_{}".format(self.F_dtype)
# causal_name = "_causal" if self.F_is_causal == "t" else ""
# atomic_name = "_a32" if self.F_is_atomic == "t" else "_a16"
# bf16_cvt_name = "_{}".format(BF16_CVT_MAP[self.F_bf16_cvt])
# return FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim_template, F_dtype=BWD_DTYPE_MAP[self.F_dtype], F_is_atomic=BOOL_MAP[self.F_is_atomic],
# F_is_causal=BOOL_MAP[self.F_is_causal], F_bf16_cvt=self.F_bf16_cvt, F_hdpad=BOOL_MAP[padding], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name,
# F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name)
def v3_api_trait(self) -> FmhaBwdV3DQDKDVApiTrait:
return FmhaBwdV3DQDKDVApiTrait(hdim=str(self.F_hdim),
dtype=self.F_dtype,
is_causal=self.F_is_causal,
is_atomic=self.F_is_atomic,
bf16_cvt=self.F_bf16_cvt,
is_hdpad=self.F_is_hdpad
)
# TODO: design a more practical way to do it # TODO: design a more practical way to do it
# this is current supported tile size & pipeline. # this is current supported tile size & pipeline.
...@@ -995,6 +1069,17 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -995,6 +1069,17 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
api_pool.register_dq_dk_dv_traits(k.api_trait()) api_pool.register_dq_dk_dv_traits(k.api_trait())
gen.append(k) gen.append(k)
for hdim_str, is_causal, is_atomic, bf16_cvt, is_hdpad in itertools.product(d.keys(), ["t", "f"], ["t", "f"], [0, 1, 2], ["t", "f"]):
hdim = int(hdim_str)
k = FmhaBwdV3DQDKDVKernel(F_hdim=hdim, F_dtype=dtype, F_is_causal=is_causal, F_is_atomic=is_atomic, F_bf16_cvt=bf16_cvt, F_is_hdpad=is_hdpad)
if receipt == 3:
cond = (dtype == 'fp16') and (bf16_cvt in [1, 2])
if cond:
# print(dtype, bf16_cvt)
continue
api_pool.register_dq_dk_dv_v3_traits(k.v3_api_trait())
# gen.append(k)
return (api_pool, gen) return (api_pool, gen)
FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
......
This diff is collapsed.
...@@ -99,7 +99,7 @@ if __name__ == "__main__": ...@@ -99,7 +99,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"-r", "-r",
"--receipt", "--receipt",
default=0, default=3,
required=False, required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim\n" + \ " 1: generate more instance to cover all hdim\n" + \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment