Unverified commit 37052173, authored by Max Podkorytov
Browse files

run clang-format

parent 0d96a891
...@@ -850,8 +850,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -850,8 +850,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
else // fmha_fwd_traits or fmha_splitkv_traits else // fmha_fwd_traits or fmha_splitkv_traits
{ {
// traits.is_group_mode = (mode == mode_enum::group); // traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type; traits.mask_type = mask.type;
traits.bias_type = bias.type; traits.bias_type = bias.type;
// traits.has_lse = lse; // traits.has_lse = lse;
// traits.do_fp8_static_quant = squant; // traits.do_fp8_static_quant = squant;
...@@ -1375,20 +1375,24 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1375,20 +1375,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{}, ck_tile::identity{},
ck_tile::identity{}); ck_tile::identity{});
#ifndef CK_TILE_SCORE_MOD_F #ifndef CK_TILE_SCORE_MOD_F
#error "must be defined" #error "must be defined"
#else #else
#define XSTR(x) STR(x) #define XSTR(x) STR(x)
#define STR(x) #x #define STR(x) #x
#pragma message "host score_mod_f: " XSTR(CK_TILE_SCORE_MOD_F) #pragma message "host score_mod_f: " XSTR(CK_TILE_SCORE_MOD_F)
#endif #endif
auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) { auto score_mod = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
ck_tile::detail::swallow(s, b, h, q_idx, v_idx); ck_tile::detail::swallow(s, b, h, q_idx, v_idx);
return CK_TILE_SCORE_MOD_F; return CK_TILE_SCORE_MOD_F;
}; };
s_host_ref.ForEach([&](auto& self, auto i) { s_host_ref.ForEach([&](auto& self, auto i) {
auto new_score = score_mod(self(i), wb, i[0], i[1], i[2]); auto new_score = score_mod(self(i), wb, i[0], i[1], i[2]);
// printf("host score_mod at (%d %lu %lu %lu), score before: %f, score after: %f\n", // printf("host score_mod at (%d %lu %lu %lu), score before: %f, score after: %f\n",
// wb, i[0], i[1], i[2], self(i), new_score); // wb, i[0], i[1], i[2], self(i), new_score);
...@@ -1396,9 +1400,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1396,9 +1400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}); });
auto scale_def = ck_tile::scales(scale_s); auto scale_def = ck_tile::scales(scale_s);
s_host_ref.ForEach([&](auto& self, auto i) { s_host_ref.ForEach([&](auto& self, auto i) { self(i) = scale_def(self(i)); });
self(i) = scale_def(self(i));
});
if(bias.type == bias_enum::elementwise_bias) if(bias.type == bias_enum::elementwise_bias)
{ {
......
...@@ -1305,15 +1305,16 @@ struct FmhaFwdKernel ...@@ -1305,15 +1305,16 @@ struct FmhaFwdKernel
// may have state inside // may have state inside
auto score_mod_def = ScoreModFunction_{}; auto score_mod_def = ScoreModFunction_{};
auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def]( auto score_mod_arg =
typename ScoreModFunction_::TScore s, [b = i_batch, h = i_nhead, score_mod_def](typename ScoreModFunction_::TScore s,
ck_tile::index_t q_idx, ck_tile::index_t q_idx,
ck_tile::index_t v_idx) { ck_tile::index_t v_idx) {
auto new_score = score_mod_def(s, b, h, q_idx, v_idx);\ auto new_score = score_mod_def(
// printf("device score_mod at (%d %d %d %d), score before: %f, score after: %f score_clip: %f\n", s, b, h, q_idx, v_idx); // printf("device score_mod at (%d %d %d %d), score
// b, h, q_idx, v_idx, s, new_score, new_score_after_clip); // before: %f, score after: %f score_clip: %f\n",
return new_score; // b, h, q_idx, v_idx, s, new_score, new_score_after_clip);
}; return new_score;
};
auto o_acc_tile = [&]() { auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
...@@ -1329,8 +1330,8 @@ struct FmhaFwdKernel ...@@ -1329,8 +1330,8 @@ struct FmhaFwdKernel
identity{}, // bias_element_func identity{}, // bias_element_func
randval_dram_window, randval_dram_window,
lse_dram_window, lse_dram_window,
identity{}, // lse_element_func identity{}, // lse_element_func
identity{}, // s_acc_element_func identity{}, // s_acc_element_func
score_mod_arg, score_mod_arg,
scales{kargs.scale_p}, // p_compute_element_func scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
......
...@@ -341,7 +341,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -341,7 +341,7 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores // FlexAttention score modifier operates directly on scores
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
...@@ -352,7 +352,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -352,7 +352,7 @@ struct BlockFmhaPipelineQRKSVS
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col); s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
}); });
}); });
} }
......
...@@ -411,7 +411,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -411,7 +411,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores // FlexAttention score modifier operates directly on scores
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment