Unverified Commit 510ff45f authored by Max Podkorytov's avatar Max Podkorytov
Browse files

unhardcode score_mod and pass it as a cpp expression from codegen

parent 27a2a0a1
......@@ -84,8 +84,20 @@ using fmha_epilogue_{F_idx} =
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
// Score-modification functor emitted once per generated kernel instance.
// The body of operator() is a C++ expression spliced in by the code generator
// (template field F_score_mod_expr), so the score mod is no longer hard-coded
// inside the kernel.
struct score_mod_def_{F_idx} {{
// Score element type, taken from the SaccDataType of this instance's pipeline.
using TScore = typename fmha_pipeline_{F_idx}::SaccDataType;
// Returns the modified score for one element.
//   s            - incoming score value
//   b, h         - the kernel passes its batch index and head index here
//   q_idx, v_idx - position indices forwarded by the kernel
//                  (presumably query row / key-value column; confirm against the pipeline)
CK_TILE_HOST_DEVICE TScore operator()(TScore s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) const {{
// The spliced expression may ignore any of the arguments; the void casts
// keep unused-parameter warnings quiet in that case.
(void) s; (void) h; (void) b; (void) q_idx; (void) v_idx;
// User-supplied expression; its result is converted to TScore on return.
return {F_score_mod_expr};
}}
}};
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}, score_mod_def_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
......@@ -337,6 +349,7 @@ class FmhaFwdKernel:
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
F_score_mod_expr: str
mask_impl : str
@property
......@@ -378,7 +391,8 @@ class FmhaFwdKernel:
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_score_mod_expr = self.F_score_mod_expr)
@property
def name(self) -> str:
......@@ -433,7 +447,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr : str) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
......@@ -502,6 +516,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
F_score_mod_expr=score_mod_expr,
mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
......@@ -524,15 +539,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write the forward-API source text held by `api_pool` into `autogen_dir`.

    The output file is named by the module-level FMHA_FWD_API_FILENAME constant.
    """
    api_file = autogen_dir / FMHA_FWD_API_FILENAME
    api_file.write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
    """Generate every forward kernel source plus the forward API file into `output_dir`.

    `kernel_filter`, `receipt`, `mask_impl` and `score_mod_expr` are passed
    straight through to `get_fwd_blobs`, which selects the kernels to emit.
    """
    pool, generated = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr)
    # Kernel files are written first, then the API file that dispatches to them.
    for blob in generated:
        write_single_fwd_kernel(blob, output_dir)
    write_fwd_api(pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
......@@ -30,7 +30,7 @@ handlers = dict(
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
......@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
for api in api_list:
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, mask_impl)
handler(output_dir, kernel_filter, receipt, mask_impl, score_mod_expr)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
assert output_file is not None
file_path = Path(output_file)
......@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl)
handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
......@@ -106,9 +106,18 @@ if __name__ == "__main__":
" 2: Only generate instance for Flash attention integration"
)
parser.add_argument(
"--score_mod_expr",
default="s",
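# The default "s" leaves the score unchanged. To exercise a non-trivial
# score mod (the expression that used to be hard-coded in the kernel), use:
# default="s + static_cast<decltype(s)>(q_idx - v_idx)"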
required=False,
help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
)
args = parser.parse_args()
api_list = args.direction.split(',')
if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr)
else:
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr)
......@@ -20,7 +20,7 @@
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
template <typename FmhaPipeline_, typename EpiloguePipeline_, typename ScoreModFunction_>
struct FmhaFwdKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
......@@ -1302,16 +1302,11 @@ struct FmhaFwdKernel
}
}();
auto score_mod_def = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) h; (void) b;
return s + static_cast<decltype(s)>(q_idx - v_idx);
};
// may have state inside
auto score_mod_def = ScoreModFunction_{};
auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def](auto s,
auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def](
typename ScoreModFunction_::TScore s,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
return score_mod_def(s, b, h, q_idx, v_idx);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment