Unverified Commit bca939ce authored by Hashem Hashemi's avatar Hashem Hashemi Committed by GitHub
Browse files

Add pre_softmax functor (#1852)



* Add pre_softmax functor

* remove stray define

* Move op out of pipeline, add it to the reference

---------
Co-authored-by: default avatarroot <root@splinter-126-wr-d1.aus.dcgpu>
Co-authored-by: default avatarMax Podkorytov <4273004+tenpercent@users.noreply.github.com>
parent 81e00bce
...@@ -9,6 +9,9 @@ endif() ...@@ -9,6 +9,9 @@ endif()
variable_watch(FMHA_SCORE_MOD_F) variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]]) set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]])
variable_watch(FMHA_PRE_SOFTMAX_F)
set(FMHA_PRE_SOFTMAX_F [[static_cast<decltype(s)>(tanh(s*1.0)/1.0)]])
foreach(api ${FMHA_FWD_ENABLE_APIS}) foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
...@@ -42,6 +45,7 @@ add_custom_command( ...@@ -42,6 +45,7 @@ add_custom_command(
--api ${FMHA_FWD_APIS} --api ${FMHA_FWD_APIS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
"--score_mod_expr=${FMHA_SCORE_MOD_F}" "--score_mod_expr=${FMHA_SCORE_MOD_F}"
"--pre_softmax_expr=${FMHA_PRE_SOFTMAX_F}"
VERBATIM VERBATIM
) )
...@@ -88,6 +92,7 @@ endif() ...@@ -88,6 +92,7 @@ endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_TILE_SCORE_MOD_F=${FMHA_SCORE_MOD_F}") list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_TILE_SCORE_MOD_F=${FMHA_SCORE_MOD_F}")
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_PRE_SOFTMAX_F=${FMHA_PRE_SOFTMAX_F}")
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
......
...@@ -96,8 +96,17 @@ struct score_mod_def_{F_idx} {{ ...@@ -96,8 +96,17 @@ struct score_mod_def_{F_idx} {{
}} }}
}}; }};
// Generated functor wrapping the pre-softmax expression supplied to the
// code generator. It is handed to ck_tile::FmhaFwdKernel as its fourth
// template argument.
struct pre_softmax_def_{F_idx} {{
// Score type is taken from the matching generated pipeline.
using TScore = typename fmha_pipeline_{F_idx}::SaccDataType;
// The parameter must stay named `s`: the substituted expression refers to it.
CK_TILE_HOST_DEVICE TScore operator()(TScore s
) const {{
// Reference `s` explicitly so an expression that ignores it still compiles cleanly.
(void) s;
return {F_pre_softmax_expr};
}}
}};
using fmha_kernel_{F_idx} = using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}, score_mod_def_{F_idx}>; ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}, score_mod_def_{F_idx}, pre_softmax_def_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
...@@ -348,6 +357,7 @@ class FmhaFwdKernel: ...@@ -348,6 +357,7 @@ class FmhaFwdKernel:
F_tile : FmhaFwdTileSize F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline F_pipeline : FmhaFwdPipeline
F_score_mod_expr: str F_score_mod_expr: str
F_pre_softmax_expr:str
mask_impl : str mask_impl : str
@property @property
...@@ -390,7 +400,8 @@ class FmhaFwdKernel: ...@@ -390,7 +400,8 @@ class FmhaFwdKernel:
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_score_mod_expr = self.F_score_mod_expr) F_score_mod_expr = self.F_score_mod_expr,
F_pre_softmax_expr = self.F_pre_softmax_expr)
@property @property
def name(self) -> str: def name(self) -> str:
...@@ -445,7 +456,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: ...@@ -445,7 +456,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
else: else:
return None return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr : str) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr : str, pre_softmax_expr : str) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future # support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
...@@ -514,6 +525,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e ...@@ -514,6 +525,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e
F_tile=tile, F_tile=tile,
F_pipeline=pipeline, F_pipeline=pipeline,
F_score_mod_expr=score_mod_expr, F_score_mod_expr=score_mod_expr,
F_pre_softmax_expr=pre_softmax_expr,
mask_impl=mask_impl) mask_impl=mask_impl)
if kernel_filter != None: if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter): if not fnmatch.fnmatch(k.name, kernel_filter):
...@@ -536,15 +548,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: ...@@ -536,15 +548,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None: def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr) api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
for kernel in kernels: for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir) write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir) write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None: def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
with file_path.open('a') as f: with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr) _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
for kernel in kernels: for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
...@@ -1503,6 +1503,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1503,6 +1503,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
real_seqlen_k, real_seqlen_k,
mask.type == mask_enum::mask_top_left)); mask.type == mask_enum::mask_top_left));
} }
// Host-reference counterpart of the pre-softmax op. CK_PRE_SOFTMAX_F is a
// build-injected C++ expression written in terms of `s`, so the lambda
// parameter has to keep that exact name.
auto pre_softmax = [](auto s) { return CK_PRE_SOFTMAX_F; };
// Overwrite every element of the reference scores with its transformed value.
s_host_ref.ForEach([&](auto& tensor, auto idx) {
tensor(idx) = pre_softmax(tensor(idx));
});
if(lse) if(lse)
{ {
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>( ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
......
...@@ -30,7 +30,7 @@ handlers = dict( ...@@ -30,7 +30,7 @@ handlers = dict(
) )
assert 0 < len(handlers) assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None: def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
if output_dir is None: if output_dir is None:
output_dir = Path(__file__).parent output_dir = Path(__file__).parent
else: else:
...@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : ...@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
for api in api_list: for api in api_list:
handler = handlers[api][HandlerId.WRITE_BLOBS] handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, mask_impl, score_mod_expr) handler(output_dir, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
# list all the files that will be generated # list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None: def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
assert output_file is not None assert output_file is not None
file_path = Path(output_file) file_path = Path(output_file)
...@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter ...@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
for api in api_list: for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS] handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr) handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -113,9 +113,16 @@ if __name__ == "__main__": ...@@ -113,9 +113,16 @@ if __name__ == "__main__":
help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables" help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
) )
# C++ expression substituted into the generated pre_softmax functor body.
# The default "s" makes that functor return the score unchanged.
parser.add_argument(
"--pre_softmax_expr",
default="s",
required=False,
help="flex attention's pre_softmax function, a cpp expression with `s` variable"
)
args = parser.parse_args() args = parser.parse_args()
api_list = args.direction.split(',') api_list = args.direction.split(',')
if args.list_blobs is not None: if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr) list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr)
else: else:
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr) write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr)
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace ck_tile { namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_, typename ScoreModFunction_> template <typename FmhaPipeline_, typename EpiloguePipeline_, typename ScoreModFunction_, typename PreSoftmaxFunction_>
struct FmhaFwdKernel struct FmhaFwdKernel
{ {
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
...@@ -1316,6 +1316,12 @@ struct FmhaFwdKernel ...@@ -1316,6 +1316,12 @@ struct FmhaFwdKernel
return new_score; return new_score;
}; };
// Default-construct the pre-softmax functor and wrap it in a lambda that
// captures it by value; the lambda is what gets passed to the pipeline call
// in the s_acc_element_func position.
auto pre_softmax_def = PreSoftmaxFunction_{};
auto pre_softmax_arg = [pre_softmax_def](
typename PreSoftmaxFunction_::TScore s) {
return pre_softmax_def(s);
};
auto o_acc_tile = [&]() { auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -1330,8 +1336,8 @@ struct FmhaFwdKernel ...@@ -1330,8 +1336,8 @@ struct FmhaFwdKernel
identity{}, // bias_element_func identity{}, // bias_element_func
randval_dram_window, randval_dram_window,
lse_dram_window, lse_dram_window,
identity{}, // lse_element_func identity{}, // lse_element_func
identity{}, // s_acc_element_func pre_softmax_arg, // s_acc_element_func
score_mod_arg, score_mod_arg,
scales{kargs.scale_p}, // p_compute_element_func scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
...@@ -1354,7 +1360,7 @@ struct FmhaFwdKernel ...@@ -1354,7 +1360,7 @@ struct FmhaFwdKernel
randval_dram_window, randval_dram_window,
lse_dram_window, lse_dram_window,
identity{}, identity{},
identity{}, pre_softmax_arg,
score_mod_arg, score_mod_arg,
identity{}, identity{},
identity{}, identity{},
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
namespace ck_tile { namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, typename PreSoftmaxFunction_>
struct BlockFmhaPipelineQRKSVSAsync struct BlockFmhaPipelineQRKSVSAsync
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
...@@ -759,6 +759,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -759,6 +759,7 @@ struct BlockFmhaPipelineQRKSVSAsync
void* smem_ptr, void* smem_ptr,
DropoutType& dropout) const DropoutType& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
k_dram_block_window_tmp, k_dram_block_window_tmp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment