Unverified Commit 81e00bce authored by Max Podkorytov's avatar Max Podkorytov
Browse files

re-add group-mode kernels

parent c657f72b
......@@ -136,7 +136,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) &&
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
......@@ -284,8 +284,6 @@ class FmhaFwdApiPool:
inners=str()
first_k = True
for k, trait in enumerate(traits):
if trait.mode != "batch":
continue
if trait.dropout == "t":
continue
if trait.lse == "t":
......@@ -501,8 +499,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
# for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
for hdim_str, mode in itertools.product(d.keys(), ["batch"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
......
......@@ -849,7 +849,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else // fmha_fwd_traits or fmha_splitkv_traits
{
// traits.is_group_mode = (mode == mode_enum::group);
traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type;
traits.bias_type = bias.type;
// traits.has_lse = lse;
......
......@@ -782,7 +782,7 @@ struct fmha_fwd_traits
int hdim_q;
int hdim_v;
std::string data_type;
// bool is_group_mode;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
......
......@@ -40,7 +40,7 @@ run_fp16_bf16_tests() {
fi
for prec in "bf16" ; do
for mode in 0 ; do
for mode in 0 1; do
for perm in 0 ; do
for vlayout in "r" ; do
for hdim in 32 64 128 256 ; do
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment