Unverified Commit 37052173 authored by Max Podkorytov
Browse files

run clang-format

parent 0d96a891
......@@ -850,8 +850,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
else // fmha_fwd_traits or fmha_splitkv_traits
{
// traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.mask_type = mask.type;
traits.bias_type = bias.type;
// traits.has_lse = lse;
// traits.do_fp8_static_quant = squant;
......@@ -1375,20 +1375,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::identity{});
#ifndef CK_TILE_SCORE_MOD_F
#error "must be defined"
#else
#define XSTR(x) STR(x)
#define STR(x) #x
#pragma message "host score_mod_f: " XSTR(CK_TILE_SCORE_MOD_F)
#endif
// Compile-time guard: the build stops with an error here unless
// CK_TILE_SCORE_MOD_F has been defined.
#ifndef CK_TILE_SCORE_MOD_F
#error "must be defined"
#else
// Two-level stringification: XSTR expands its argument first, then STR turns
// the expanded text into a string literal.
#define XSTR(x) STR(x)
#define STR(x) #x
// Emit the expanded score-mod expression as a compiler message.
#pragma message "host score_mod_f: " XSTR(CK_TILE_SCORE_MOD_F)
#endif
auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) {
// Host-side score modifier: takes a score `s` plus four indices and returns
// the value of the CK_TILE_SCORE_MOD_F expression.
// NOTE(review): CK_TILE_SCORE_MOD_F is a macro, so it presumably refers to the
// parameters by name (s, b, h, q_idx, v_idx) -- do not rename them without
// checking how the macro is defined by the build.
auto score_mod = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
// Pass every parameter to swallow(); presumably this keeps parameters that
// the macro expression does not use from being reported as unused -- TODO confirm.
ck_tile::detail::swallow(s, b, h, q_idx, v_idx);
return CK_TILE_SCORE_MOD_F;
};
s_host_ref.ForEach([&](auto& self, auto i) {
s_host_ref.ForEach([&](auto& self, auto i) {
auto new_score = score_mod(self(i), wb, i[0], i[1], i[2]);
// printf("host score_mod at (%d %lu %lu %lu), score before: %f, score after: %f\n",
// wb, i[0], i[1], i[2], self(i), new_score);
......@@ -1396,9 +1400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
});
auto scale_def = ck_tile::scales(scale_s);
s_host_ref.ForEach([&](auto& self, auto i) {
self(i) = scale_def(self(i));
});
s_host_ref.ForEach([&](auto& self, auto i) { self(i) = scale_def(self(i)); });
if(bias.type == bias_enum::elementwise_bias)
{
......
......@@ -1305,15 +1305,16 @@ struct FmhaFwdKernel
// may have state inside
auto score_mod_def = ScoreModFunction_{};
auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def](
typename ScoreModFunction_::TScore s,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
auto new_score = score_mod_def(s, b, h, q_idx, v_idx);\
// printf("device score_mod at (%d %d %d %d), score before: %f, score after: %f score_clip: %f\n",
// b, h, q_idx, v_idx, s, new_score, new_score_after_clip);
return new_score;
};
auto score_mod_arg =
[b = i_batch, h = i_nhead, score_mod_def](typename ScoreModFunction_::TScore s,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
auto new_score = score_mod_def(
s, b, h, q_idx, v_idx); // printf("device score_mod at (%d %d %d %d), score
// before: %f, score after: %f score_clip: %f\n",
// b, h, q_idx, v_idx, s, new_score, new_score_after_clip);
return new_score;
};
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
......@@ -1329,8 +1330,8 @@ struct FmhaFwdKernel
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
identity{}, // lse_element_func
identity{}, // s_acc_element_func
score_mod_arg,
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
......
......@@ -341,7 +341,7 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
{
const auto k_origin = k_dram_block_window.get_window_origin();
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
......@@ -352,7 +352,7 @@ struct BlockFmhaPipelineQRKSVS
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<1>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
});
});
}
......
......@@ -411,7 +411,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
{
const auto k_origin = k_dram_block_window.get_window_origin();
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment