Unverified commit 81e00bce, authored by Max Podkorytov
Browse files

re-add group-mode kernels

parent c657f72b
...@@ -136,7 +136,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < ...@@ -136,7 +136,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}} }}
""" """
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a); return fmha_fwd_<trait_>(s, a);
...@@ -284,8 +284,6 @@ class FmhaFwdApiPool: ...@@ -284,8 +284,6 @@ class FmhaFwdApiPool:
inners=str() inners=str()
first_k = True first_k = True
for k, trait in enumerate(traits): for k, trait in enumerate(traits):
if trait.mode != "batch":
continue
if trait.dropout == "t": if trait.dropout == "t":
continue continue
if trait.lse == "t": if trait.lse == "t":
...@@ -501,8 +499,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e ...@@ -501,8 +499,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e
if d == None: if d == None:
continue continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
# for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
for hdim_str, mode in itertools.product(d.keys(), ["batch"]):
tile = d[hdim_str] tile = d[hdim_str]
hdim = int(hdim_str) hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim): for pipeline in get_pipelines(dtype, hdim):
......
...@@ -849,7 +849,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -849,7 +849,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else // fmha_fwd_traits or fmha_splitkv_traits else // fmha_fwd_traits or fmha_splitkv_traits
{ {
// traits.is_group_mode = (mode == mode_enum::group); traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type; traits.mask_type = mask.type;
traits.bias_type = bias.type; traits.bias_type = bias.type;
// traits.has_lse = lse; // traits.has_lse = lse;
......
...@@ -782,7 +782,7 @@ struct fmha_fwd_traits ...@@ -782,7 +782,7 @@ struct fmha_fwd_traits
int hdim_q; int hdim_q;
int hdim_v; int hdim_v;
std::string data_type; std::string data_type;
// bool is_group_mode; bool is_group_mode;
bool is_v_rowmajor; bool is_v_rowmajor;
mask_enum mask_type; mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
......
...@@ -40,7 +40,7 @@ run_fp16_bf16_tests() { ...@@ -40,7 +40,7 @@ run_fp16_bf16_tests() {
fi fi
for prec in "bf16" ; do for prec in "bf16" ; do
for mode in 0 ; do for mode in 0 1; do
for perm in 0 ; do for perm in 0 ; do
for vlayout in "r" ; do for vlayout in "r" ; do
for hdim in 32 64 128 256 ; do for hdim in 32 64 128 256 ; do
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment