Unverified commit 0d96a891, authored by Max Podkorytov
Browse files

fix numeric mismatches

parent ee0654c4
......@@ -7,7 +7,7 @@ if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
endif()
variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>(q_idx - v_idx)]])
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]])
foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
......
......@@ -1384,17 +1384,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
#endif
auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) {
(void) s; (void) b; (void) h; (void) q_idx; (void) v_idx;
ck_tile::detail::swallow(s, b, h, q_idx, v_idx);
return CK_TILE_SCORE_MOD_F;
};
s_host_ref.ForEach([&](auto& self, auto i) {
self(i) = score_mod(self(i), i[0], i[1], i[2], i[3]);
auto new_score = score_mod(self(i), wb, i[0], i[1], i[2]);
// printf("host score_mod at (%d %lu %lu %lu), score before: %f, score after: %f\n",
// wb, i[0], i[1], i[2], self(i), new_score);
self(i) = new_score;
});
auto scale_def = ck_tile::scales(scale_s);
s_host_ref.ForEach([&](auto& self, auto i) {
scale_def(self(i));
self(i) = scale_def(self(i));
});
if(bias.type == bias_enum::elementwise_bias)
......
......@@ -1309,7 +1309,10 @@ struct FmhaFwdKernel
typename ScoreModFunction_::TScore s,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
return score_mod_def(s, b, h, q_idx, v_idx);
auto new_score = score_mod_def(s, b, h, q_idx, v_idx);\
// printf("device score_mod at (%d %d %d %d), score before: %f, score after: %f score_clip: %f\n",
// b, h, q_idx, v_idx, s, new_score, new_score_after_clip);
return new_score;
};
auto o_acc_tile = [&]() {
......
......@@ -350,9 +350,8 @@ struct BlockFmhaPipelineQRKSVS
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
});
});
......
......@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment