Unverified Commit 8c96b18b authored by Max Podkorytov's avatar Max Podkorytov
Browse files

use custom score for testing

parent 6ef86201
......@@ -1375,9 +1375,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::scales(scale_s));
auto score_mod = [] (auto score, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) {
(void) score; (void) b; (void) h; (void) q_idx; (void) v_idx;
return score;
auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) {
(void) s; (void) b; (void) h; (void) q_idx; (void) v_idx;
return s + static_cast<decltype(s)>(q_idx - v_idx);
};
s_host_ref.ForEach([&](auto& self, auto i) {
......
......@@ -108,9 +108,9 @@ if __name__ == "__main__":
parser.add_argument(
"--score_mod_expr",
default="s",
# default="s",
# test with
# default="s + static_cast<decltype(s)>(q_idx - v_idx)"
default="s + static_cast<decltype(s)>(q_idx - v_idx)",
required=False,
help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment