Unverified Commit 91e1a796 authored by Max Podkorytov
Browse files

move elementwise functors

parent 052a7265
...@@ -1302,24 +1302,33 @@ struct FmhaFwdKernel ...@@ -1302,24 +1302,33 @@ struct FmhaFwdKernel
} }
}(); }();
constexpr auto q_arg_element_func = identity{};
constexpr auto k_arg_element_func = identity{};
constexpr auto v_arg_element_func = identity{};
constexpr auto bias_arg_element_func = identity{};
constexpr auto lse_arg_element_func = identity{};
constexpr auto s_acc_arg_element_func = identity{};
constexpr auto p_compute_arg_element_func = identity{};
constexpr auto o_acc_arg_element_func = identity{};
auto o_acc_tile = [&]() { auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
return FmhaPipeline{}( return FmhaPipeline{}(
q_dram_window, q_dram_window,
identity{}, // q_element_func q_arg_element_func,
k_dram_window, k_dram_window,
identity{}, // k_element_func k_arg_element_func,
v_dram_window, v_dram_window,
identity{}, // v_element_func v_arg_element_func,
bias_dram_window, bias_dram_window,
identity{}, // bias_element_func bias_arg_element_func,
randval_dram_window, randval_dram_window,
lse_dram_window, lse_dram_window,
identity{}, // lse_element_func lse_arg_element_func,
identity{}, // s_acc_element_func s_acc_arg_element_func,
scales{kargs.scale_p}, // p_compute_element_func composes(p_compute_arg_element_func, scales{kargs.scale_p}),
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(o_acc_arg_element_func, saturates<fp8_t>{}, scales{kargs.scale_o}),
mask, mask,
position_encoding, position_encoding,
kargs.scale_s, kargs.scale_s,
...@@ -1329,11 +1338,19 @@ struct FmhaFwdKernel ...@@ -1329,11 +1338,19 @@ struct FmhaFwdKernel
else else
{ {
return FmhaPipeline{}(q_dram_window, return FmhaPipeline{}(q_dram_window,
q_arg_element_func,
k_dram_window, k_dram_window,
k_arg_element_func,
v_dram_window, v_dram_window,
v_arg_element_func,
bias_dram_window, bias_dram_window,
bias_arg_element_func,
randval_dram_window, randval_dram_window,
lse_dram_window, lse_dram_window,
lse_arg_element_func,
s_acc_arg_element_func,
p_compute_arg_element_func,
o_acc_arg_element_func,
mask, mask,
position_encoding, position_encoding,
kargs.scale_s, kargs.scale_s,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment