Unverified commit 503a7da6, authored by Max Podkorytov
Browse files

pipeline fixes for accuracy issues; disable pre-softmax function until its accuracy is fixed

parent d480a5a6
......@@ -8,9 +8,11 @@ endif()
# Score-modification expression. It is written as a bracket argument ([[...]])
# so CMake stores the text verbatim, without variable or escape expansion.
# The names used inside it (s, q_idx, v_idx) are not defined in this file;
# presumably the text is substituted into generated kernel code -- confirm
# against the consumer of FMHA_SCORE_MOD_F.
# variable_watch() makes CMake report changes to the variable during configure.
variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]])
# Identity alternative (returns the score unchanged):
# set(FMHA_SCORE_MOD_F [[s]])

# Pre-softmax expression, stored verbatim in the same way.
# The tanh-based variant is disabled until its accuracy issue is fixed; the
# identity expression is used in the meantime. The variable is assigned exactly
# once so the watch reports a single change instead of a set-then-override pair.
variable_watch(FMHA_PRE_SOFTMAX_F)
# set(FMHA_PRE_SOFTMAX_F [[static_cast<decltype(s)>(tanh(s*1.0)/1.0)]])
set(FMHA_PRE_SOFTMAX_F [[s]])
foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
......
......@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRKSVS
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
});
......
......@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment