Unverified Commit 510ff45f authored by Max Podkorytov's avatar Max Podkorytov
Browse files

unhardcode score_mod and pass it as a cpp expression from codegen

parent 27a2a0a1
...@@ -84,8 +84,20 @@ using fmha_epilogue_{F_idx} = ...@@ -84,8 +84,20 @@ using fmha_epilogue_{F_idx} =
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>; {F_spad}, {F_dvpad}>>;
struct score_mod_def_{F_idx} {{
using TScore = typename fmha_pipeline_{F_idx}::SaccDataType;
CK_TILE_HOST_DEVICE TScore operator()(TScore s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) const {{
(void) s; (void) h; (void) b; (void) q_idx; (void) v_idx;
return {F_score_mod_expr};
}}
}};
using fmha_kernel_{F_idx} = using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}, score_mod_def_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
...@@ -337,6 +349,7 @@ class FmhaFwdKernel: ...@@ -337,6 +349,7 @@ class FmhaFwdKernel:
F_mode : str # value from MODE_MAP F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline F_pipeline : FmhaFwdPipeline
F_score_mod_expr: str
mask_impl : str mask_impl : str
@property @property
...@@ -378,7 +391,8 @@ class FmhaFwdKernel: ...@@ -378,7 +391,8 @@ class FmhaFwdKernel:
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_score_mod_expr = self.F_score_mod_expr)
@property @property
def name(self) -> str: def name(self) -> str:
...@@ -433,7 +447,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -433,7 +447,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
else: else:
return None return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr : str) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future # support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
...@@ -502,6 +516,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -502,6 +516,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
F_mode=mode, F_mode=mode,
F_tile=tile, F_tile=tile,
F_pipeline=pipeline, F_pipeline=pipeline,
F_score_mod_expr=score_mod_expr,
mask_impl=mask_impl) mask_impl=mask_impl)
if kernel_filter != None: if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter): if not fnmatch.fnmatch(k.name, kernel_filter):
...@@ -524,15 +539,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: ...@@ -524,15 +539,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr)
for kernel in kernels: for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir) write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir) write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
with file_path.open('a') as f: with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr)
for kernel in kernels: for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
...@@ -30,7 +30,7 @@ handlers = dict( ...@@ -30,7 +30,7 @@ handlers = dict(
) )
assert 0 < len(handlers) assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
if output_dir is None: if output_dir is None:
output_dir = Path(__file__).parent output_dir = Path(__file__).parent
else: else:
...@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : ...@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
for api in api_list: for api in api_list:
handler = handlers[api][HandlerId.WRITE_BLOBS] handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, mask_impl) handler(output_dir, kernel_filter, receipt, mask_impl, score_mod_expr)
# list all the files that will be generated # list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
assert output_file is not None assert output_file is not None
file_path = Path(output_file) file_path = Path(output_file)
...@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter ...@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
for api in api_list: for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS] handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl) handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -106,9 +106,18 @@ if __name__ == "__main__": ...@@ -106,9 +106,18 @@ if __name__ == "__main__":
" 2: Only generate instance for Flash attention integration" " 2: Only generate instance for Flash attention integration"
) )
parser.add_argument(
"--score_mod_expr",
default="s",
# test with
# default="s + static_cast<decltype(s)>(q_idx - v_idx)"
required=False,
help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
)
args = parser.parse_args() args = parser.parse_args()
api_list = args.direction.split(',') api_list = args.direction.split(',')
if args.list_blobs is not None: if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr)
else: else:
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr)
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace ck_tile { namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_> template <typename FmhaPipeline_, typename EpiloguePipeline_, typename ScoreModFunction_>
struct FmhaFwdKernel struct FmhaFwdKernel
{ {
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
...@@ -1302,16 +1302,11 @@ struct FmhaFwdKernel ...@@ -1302,16 +1302,11 @@ struct FmhaFwdKernel
} }
}(); }();
auto score_mod_def = [](auto s, // may have state inside
ck_tile::index_t b, auto score_mod_def = ScoreModFunction_{};
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) h; (void) b;
return s + static_cast<decltype(s)>(q_idx - v_idx);
};
auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def](auto s, auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def](
typename ScoreModFunction_::TScore s,
ck_tile::index_t q_idx, ck_tile::index_t q_idx,
ck_tile::index_t v_idx) { ck_tile::index_t v_idx) {
return score_mod_def(s, b, h, q_idx, v_idx); return score_mod_def(s, b, h, q_idx, v_idx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment