Unverified commit bca939ce authored by Hashem Hashemi, committed by GitHub

Add pre_softmax functor (#1852)



* Add pre_softmax functor

* Remove stray define

* Move op out of pipeline, add it to reference

---------
Co-authored-by: root <root@splinter-126-wr-d1.aus.dcgpu>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
parent 81e00bce
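For context, the pre_softmax functor added by this commit is an element-wise transform applied to the attention scores just before the softmax; the default expression wired up in CMake below is a tanh soft-cap. A minimal sketch of that transform on a single score, assuming a plain float score type (the function name is illustrative, not part of the commit):

#include <cmath>

// Minimal sketch (illustrative, not commit code): the default expression
// tanh(s*1.0)/1.0 applied to one attention score before it enters softmax.
inline float pre_softmax_example(float s)
{
    return static_cast<float>(std::tanh(s * 1.0) / 1.0);
}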
@@ -9,6 +9,9 @@ endif()
variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]])
variable_watch(FMHA_PRE_SOFTMAX_F)
set(FMHA_PRE_SOFTMAX_F [[static_cast<decltype(s)>(tanh(s*1.0)/1.0)]])
foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
@@ -42,6 +45,7 @@ add_custom_command(
--api ${FMHA_FWD_APIS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
"--score_mod_expr=${FMHA_SCORE_MOD_F}"
"--pre_softmax_expr=${FMHA_PRE_SOFTMAX_F}"
VERBATIM
)
@@ -88,6 +92,7 @@ endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_TILE_SCORE_MOD_F=${FMHA_SCORE_MOD_F}")
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_PRE_SOFTMAX_F=${FMHA_PRE_SOFTMAX_F}")
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
......
@@ -96,8 +96,17 @@ struct score_mod_def_{F_idx} {{
}}
}};
struct pre_softmax_def_{F_idx} {{
using TScore = typename fmha_pipeline_{F_idx}::SaccDataType;
CK_TILE_HOST_DEVICE TScore operator()(TScore s) const {{
(void) s;
return {F_pre_softmax_expr};
}}
}};
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}, score_mod_def_{F_idx}>;
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}, score_mod_def_{F_idx}, pre_softmax_def_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
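To make the pre_softmax_def_{F_idx} template above concrete, one generated instantiation could look roughly like the sketch below, with an illustrative index 0, the default tanh expression substituted for {F_pre_softmax_expr}, and float standing in for the pipeline's SaccDataType (assumes ck_tile headers are in scope for CK_TILE_HOST_DEVICE; this is not generator output):

// Sketch of a generated functor; the index, the expression, and the float
// alias are placeholders (the real TScore is fmha_pipeline_0::SaccDataType).
struct pre_softmax_def_0
{
    using TScore = float;
    CK_TILE_HOST_DEVICE TScore operator()(TScore s) const
    {
        (void) s; // kept so expressions that ignore s still compile warning-free
        return static_cast<decltype(s)>(tanh(s * 1.0) / 1.0);
    }
};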
@@ -348,6 +357,7 @@ class FmhaFwdKernel:
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
F_score_mod_expr: str
F_pre_softmax_expr : str
mask_impl : str
@property
@@ -390,7 +400,8 @@ class FmhaFwdKernel:
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_score_mod_expr = self.F_score_mod_expr)
F_score_mod_expr = self.F_score_mod_expr,
F_pre_softmax_expr = self.F_pre_softmax_expr)
@property
def name(self) -> str:
@@ -445,7 +456,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr : str) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr : str, pre_softmax_expr : str) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@@ -514,6 +525,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_e
F_tile=tile,
F_pipeline=pipeline,
F_score_mod_expr=score_mod_expr,
F_pre_softmax_expr=pre_softmax_expr,
mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
@@ -536,15 +548,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr)
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
@@ -1503,6 +1503,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
}
auto pre_softmax = [] (auto s) {
//ck_tile::detail::swallow(s);
return CK_PRE_SOFTMAX_F;
};
s_host_ref.ForEach([&](auto& self, auto i) {
auto new_val = pre_softmax(self(i));
self(i) = new_val;
});
if(lse)
{
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
......
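The ForEach call above runs the pre_softmax lambda over every element of s_host_ref before the reference softmax is taken. The same idea on a flat row of scores, with the default CK_PRE_SOFTMAX_F expression inlined, might look like this stand-alone helper (illustrative, not part of the commit):

#include <algorithm>
#include <cmath>
#include <vector>

// Apply the pre-softmax transform element-wise, then a naive, numerically
// stable softmax over one row; conceptually mirrors the reference path.
inline void softmax_with_pre_softmax(std::vector<float>& row)
{
    for(float& s : row)
        s = static_cast<float>(std::tanh(s * 1.0) / 1.0); // default CK_PRE_SOFTMAX_F

    if(row.empty())
        return;
    const float m = *std::max_element(row.begin(), row.end());
    float sum = 0.f;
    for(float& s : row)
    {
        s = std::exp(s - m);
        sum += s;
    }
    for(float& s : row)
        s /= sum;
}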
@@ -30,7 +30,7 @@ handlers = dict(
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
for api in api_list:
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, mask_impl, score_mod_expr)
handler(output_dir, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr) -> None:
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
assert output_file is not None
file_path = Path(output_file)
@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr)
handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -113,9 +113,16 @@ if __name__ == "__main__":
help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
)
parser.add_argument(
"--pre_softmax_expr",
default="s",
required=False,
help="flex attention's pre_softmax function, a cpp expression with `s` variable"
)
args = parser.parse_args()
api_list = args.direction.split(',')
if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr)
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr)
else:
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr)
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr)
@@ -20,7 +20,7 @@
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_, typename ScoreModFunction_>
template <typename FmhaPipeline_, typename EpiloguePipeline_, typename ScoreModFunction_, typename PreSoftmaxFunction_>
struct FmhaFwdKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
@@ -1316,6 +1316,12 @@ struct FmhaFwdKernel
return new_score;
};
auto pre_softmax_def = PreSoftmaxFunction_{};
auto pre_softmax_arg = [pre_softmax_def](
typename PreSoftmaxFunction_::TScore s) {
return pre_softmax_def(s);
};
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
@@ -1331,7 +1337,7 @@ struct FmhaFwdKernel
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
pre_softmax_arg, // s_acc_element_func
score_mod_arg,
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
@@ -1354,7 +1360,7 @@ struct FmhaFwdKernel
randval_dram_window,
lse_dram_window,
identity{},
identity{},
pre_softmax_arg,
score_mod_arg,
identity{},
identity{},
......
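In the kernel changes above, the PreSoftmaxFunction_ functor is default-constructed and wrapped in a lambda (pre_softmax_arg) that replaces identity{} in the s_acc_element_func position of the pipeline call. A stand-alone sketch of that wrapping pattern, where every name other than the pattern itself is a placeholder:

#include <cmath>

// Placeholder functor standing in for an instantiated PreSoftmaxFunction_.
struct ExamplePreSoftmax
{
    using TScore = float;
    TScore operator()(TScore s) const { return std::tanh(s); }
};

// Default-construct the functor and capture it in a lambda, mirroring how
// pre_softmax_arg is built before being handed to the pipeline.
inline auto make_pre_softmax_arg()
{
    ExamplePreSoftmax pre_softmax_def{};
    return [pre_softmax_def](ExamplePreSoftmax::TScore s) { return pre_softmax_def(s); };
}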
@@ -13,7 +13,7 @@
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, typename PreSoftmaxFunction_>
struct BlockFmhaPipelineQRKSVSAsync
{
using Problem = remove_cvref_t<Problem_>;
@@ -759,6 +759,7 @@ struct BlockFmhaPipelineQRKSVSAsync
void* smem_ptr,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
......