"...composable_kernel_rocm.git" did not exist on "dbb7002d521c905e23bba79dec257f5fd1276b86"
Unverified Commit 503a7da6 authored by Max Podkorytov's avatar Max Podkorytov
Browse files

pipeline fixes for accuracy issues; disable pre-softmax function until its accuracy is fixed

parent d480a5a6
...@@ -8,9 +8,11 @@ endif() ...@@ -8,9 +8,11 @@ endif()
variable_watch(FMHA_SCORE_MOD_F) variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]]) set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]])
# set(FMHA_SCORE_MOD_F [[s]])
variable_watch(FMHA_PRE_SOFTMAX_F) variable_watch(FMHA_PRE_SOFTMAX_F)
set(FMHA_PRE_SOFTMAX_F [[static_cast<decltype(s)>(tanh(s*1.0)/1.0)]]) # set(FMHA_PRE_SOFTMAX_F [[static_cast<decltype(s)>(tanh(s*1.0)/1.0)]])
set(FMHA_PRE_SOFTMAX_F [[s]])
foreach(api ${FMHA_FWD_ENABLE_APIS}) foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
......
...@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRKSVS
s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col); s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
}); });
......
...@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col); s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment