Unverified Commit 2a198f14 authored by Max Podkorytov

add a hardcoded score_mod

parent cd69c852
@@ -6,7 +6,7 @@
 #include <ostream>
 #include <string>
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/fmha.hpp"
+#include "ck_tile/ops/flex_fmha.hpp"
 // keep sync with BlockAttentionBiasEnum
 enum class bias_enum
...
@@ -7,7 +7,7 @@
 #include <string>
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/fmha.hpp"
+#include "ck_tile/ops/flex_fmha.hpp"
 // keep this in sync with ck_tile::GenericAttentionMaskEnum
 enum class mask_enum
...
@@ -337,6 +337,38 @@ struct BlockFmhaPipelineQRKSVS
         }
         // STAGE 2, scale_s, add bias, mask, softmax
+        // FlexAttention score modifier operates directly on scores
+        {
+            // Hardcoded modifier: replaces the raw score with the relative
+            // position q_idx - v_idx; s, b and h are deliberately unused.
+            auto score_mod = [](auto s,
+                                ck_tile::index_t b,
+                                ck_tile::index_t h,
+                                ck_tile::index_t q_idx,
+                                ck_tile::index_t v_idx) {
+                (void)s;
+                (void)b;
+                (void)h;
+                return static_cast<decltype(s)>(q_idx - v_idx);
+            };
+            const auto k_origin = k_dram_block_window.get_window_origin();
+            constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+            sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                    // map distributed tile coordinates to global q/k positions
+                    const auto tile_idx = get_x_indices_from_distributed_indices(
+                        s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                    const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                    // batch and head indices are hardcoded to 0 for now
+                    const auto b = 0;
+                    const auto h = 0;
+                    s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
+                });
+            });
+        }
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
             s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...
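For readers unfamiliar with the ck_tile tile machinery above, here is a minimal standalone sketch of the same score_mod pattern applied to a plain row-major score buffer. The buffer layout, the `apply_score_mod` helper, and the nested loop are illustrative assumptions, not the pipeline's actual data structures; the real code sweeps a distributed register tile and recovers global positions from the window origins.

```cpp
// Standalone sketch of the score_mod pattern, on a plain row-major
// [seqlen_q x seqlen_k] score buffer (illustrative, not the real pipeline).
#include <cstdint>
#include <vector>

using index_t = std::int32_t;

template <typename ScoreT, typename ScoreMod>
void apply_score_mod(std::vector<ScoreT>& scores,
                     index_t seqlen_q,
                     index_t seqlen_k,
                     ScoreMod score_mod,
                     index_t b = 0, // batch index, hardcoded to 0 as in the diff
                     index_t h = 0) // head index, hardcoded to 0 as in the diff
{
    for(index_t q_idx = 0; q_idx < seqlen_q; ++q_idx)
        for(index_t kv_idx = 0; kv_idx < seqlen_k; ++kv_idx)
        {
            ScoreT& s = scores[q_idx * seqlen_k + kv_idx];
            s = score_mod(s, b, h, q_idx, kv_idx);
        }
}

int main()
{
    std::vector<float> scores(8 * 8, 0.f);
    // Same hardcoded modifier as the diff: the score becomes the
    // relative position q_idx - kv_idx, ignoring the raw score.
    apply_score_mod(scores, 8, 8,
                    [](auto s, index_t, index_t, index_t q_idx, index_t kv_idx) {
                        (void)s;
                        return static_cast<decltype(s)>(q_idx - kv_idx);
                    });
}
```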
@@ -407,6 +407,38 @@ struct BlockFmhaPipelineQRKSVSAsync
         __builtin_amdgcn_sched_barrier(1);
         // STAGE 2, scale_s, add bias, mask, softmax
+        // FlexAttention score modifier operates directly on scores
+        {
+            // Hardcoded modifier: replaces the raw score with the relative
+            // position q_idx - v_idx; s, b and h are deliberately unused.
+            auto score_mod = [](auto s,
+                                ck_tile::index_t b,
+                                ck_tile::index_t h,
+                                ck_tile::index_t q_idx,
+                                ck_tile::index_t v_idx) {
+                (void)s;
+                (void)b;
+                (void)h;
+                return static_cast<decltype(s)>(q_idx - v_idx);
+            };
+            const auto k_origin = k_dram_block_window.get_window_origin();
+            constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+            sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                    // map distributed tile coordinates to global q/k positions
+                    const auto tile_idx = get_x_indices_from_distributed_indices(
+                        s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                    const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                    const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                    // batch and head indices are hardcoded to 0 for now
+                    const auto b = 0;
+                    const auto h = 0;
+                    s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
+                });
+            });
+        }
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
             s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...
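The two hunks duplicate the same hardcoded lambda. One way the hardcoding could later be removed (a design sketch under assumed names, not part of this commit) is to thread the modifier through the pipeline as a functor template parameter with an identity default, so existing instantiations keep their behavior:

```cpp
// Hypothetical generalization (not in this commit): the modifier as a
// functor type parameter; names below are illustrative assumptions.
#include <cstdint>

using index_t = std::int32_t;

// Default: leave the raw score untouched, so the pipeline is a no-op
// unless a FlexAttention-style score_mod is supplied.
struct IdentityScoreMod
{
    template <typename ScoreT>
    ScoreT operator()(ScoreT s, index_t /*b*/, index_t /*h*/,
                      index_t /*q_idx*/, index_t /*kv_idx*/) const
    {
        return s;
    }
};

// The diff's hardcoded behavior, expressed as a named functor.
struct RelPosScoreMod
{
    template <typename ScoreT>
    ScoreT operator()(ScoreT s, index_t /*b*/, index_t /*h*/,
                      index_t q_idx, index_t kv_idx) const
    {
        (void)s;
        return static_cast<ScoreT>(q_idx - kv_idx);
    }
};

// Inside STAGE 2, the hardcoded lambda would then become
//     s_acc(i_j_idx) = ScoreMod{}(s_acc(i_j_idx), b, h, row, col);
// with ScoreMod a pipeline template parameter defaulting to IdentityScoreMod.
```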