Unverified Commit 91e1a796 authored by Max Podkorytov
Browse files

move elementwise functors

parent 052a7265
......@@ -1302,24 +1302,33 @@ struct FmhaFwdKernel
}
}();
constexpr auto q_arg_element_func = identity{};
constexpr auto k_arg_element_func = identity{};
constexpr auto v_arg_element_func = identity{};
constexpr auto bias_arg_element_func = identity{};
constexpr auto lse_arg_element_func = identity{};
constexpr auto s_acc_arg_element_func = identity{};
constexpr auto p_compute_arg_element_func = identity{};
constexpr auto o_acc_arg_element_func = identity{};
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
q_arg_element_func,
k_dram_window,
identity{}, // k_element_func
k_arg_element_func,
v_dram_window,
identity{}, // v_element_func
v_arg_element_func,
bias_dram_window,
identity{}, // bias_element_func
bias_arg_element_func,
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
lse_arg_element_func,
s_acc_arg_element_func,
composes(p_compute_arg_element_func, scales{kargs.scale_p}),
composes(o_acc_arg_element_func, saturates<fp8_t>{}, scales{kargs.scale_o}),
mask,
position_encoding,
kargs.scale_s,
......@@ -1329,11 +1338,19 @@ struct FmhaFwdKernel
else
{
return FmhaPipeline{}(q_dram_window,
q_arg_element_func,
k_dram_window,
k_arg_element_func,
v_dram_window,
v_arg_element_func,
bias_dram_window,
bias_arg_element_func,
randval_dram_window,
lse_dram_window,
lse_arg_element_func,
s_acc_arg_element_func,
p_compute_arg_element_func,
o_acc_arg_element_func,
mask,
position_encoding,
kargs.scale_s,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment