"tests/pipelines/vscode:/vscode.git/clone" did not exist on "beb848e2b6cc888bd5039e6f6cac7c932c6c3225"
Unverified Commit 0d96a891 authored by Max Podkorytov's avatar Max Podkorytov
Browse files

fix numeric mismatches

parent ee0654c4
...@@ -7,7 +7,7 @@ if(FMHA_FWD_ENABLE_APIS STREQUAL "all") ...@@ -7,7 +7,7 @@ if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
endif() endif()
variable_watch(FMHA_SCORE_MOD_F) variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>(q_idx - v_idx)]]) set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]])
foreach(api ${FMHA_FWD_ENABLE_APIS}) foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
......
...@@ -1384,17 +1384,20 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1384,17 +1384,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
#endif #endif
auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) { auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) {
(void) s; (void) b; (void) h; (void) q_idx; (void) v_idx; ck_tile::detail::swallow(s, b, h, q_idx, v_idx);
return CK_TILE_SCORE_MOD_F; return CK_TILE_SCORE_MOD_F;
}; };
s_host_ref.ForEach([&](auto& self, auto i) { s_host_ref.ForEach([&](auto& self, auto i) {
self(i) = score_mod(self(i), i[0], i[1], i[2], i[3]); auto new_score = score_mod(self(i), wb, i[0], i[1], i[2]);
// printf("host score_mod at (%d %lu %lu %lu), score before: %f, score after: %f\n",
// wb, i[0], i[1], i[2], self(i), new_score);
self(i) = new_score;
}); });
auto scale_def = ck_tile::scales(scale_s); auto scale_def = ck_tile::scales(scale_s);
s_host_ref.ForEach([&](auto& self, auto i) { s_host_ref.ForEach([&](auto& self, auto i) {
scale_def(self(i)); self(i) = scale_def(self(i));
}); });
if(bias.type == bias_enum::elementwise_bias) if(bias.type == bias_enum::elementwise_bias)
......
...@@ -1309,7 +1309,10 @@ struct FmhaFwdKernel ...@@ -1309,7 +1309,10 @@ struct FmhaFwdKernel
typename ScoreModFunction_::TScore s, typename ScoreModFunction_::TScore s,
ck_tile::index_t q_idx, ck_tile::index_t q_idx,
ck_tile::index_t v_idx) { ck_tile::index_t v_idx) {
return score_mod_def(s, b, h, q_idx, v_idx); auto new_score = score_mod_def(s, b, h, q_idx, v_idx);\
// printf("device score_mod at (%d %d %d %d), score before: %f, score after: %f score_clip: %f\n",
// b, h, q_idx, v_idx, s, new_score, new_score_after_clip);
return new_score;
}; };
auto o_acc_tile = [&]() { auto o_acc_tile = [&]() {
......
...@@ -350,9 +350,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -350,9 +350,8 @@ struct BlockFmhaPipelineQRKSVS
s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col); s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
}); });
}); });
......
...@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col); s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment