Unverified Commit 27a2a0a1 authored by Max Podkorytov's avatar Max Podkorytov
Browse files

abstract score_mod from a pipeline

parent 529bda90
......@@ -1302,6 +1302,21 @@ struct FmhaFwdKernel
}
}();
auto score_mod_def = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) h; (void) b;
return s + static_cast<decltype(s)>(q_idx - v_idx);
};
auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def](auto s,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
return score_mod_def(s, b, h, q_idx, v_idx);
};
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
......@@ -1318,6 +1333,7 @@ struct FmhaFwdKernel
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
score_mod_arg,
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
......@@ -1329,11 +1345,20 @@ struct FmhaFwdKernel
else
{
return FmhaPipeline{}(q_dram_window,
identity{},
k_dram_window,
identity{},
v_dram_window,
identity{},
bias_dram_window,
identity{},
randval_dram_window,
lse_dram_window,
identity{},
identity{},
score_mod_arg,
identity{},
identity{},
mask,
position_encoding,
kargs.scale_s,
......
......@@ -120,6 +120,7 @@ struct BlockFmhaPipelineQRKSVS
typename BiasElementFunction,
typename LSEElementFunction,
typename SAccElementFunction,
typename ScoreModFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
......@@ -136,6 +137,7 @@ struct BlockFmhaPipelineQRKSVS
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const ScoreModFunction& score_mod,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
......@@ -339,16 +341,6 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
{
auto score_mod = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) s;
(void) b;
(void) h;
return static_cast<decltype(s)>(q_idx - v_idx);
};
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
......@@ -361,10 +353,7 @@ struct BlockFmhaPipelineQRKSVS
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto b = 0;
const auto h = 0;
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
});
});
}
......@@ -634,47 +623,6 @@ struct BlockFmhaPipelineQRKSVS
return o_acc;
}
// Convenience overload of operator() for callers that do not supply any
// element-wise functors. It forwards every argument unchanged to another
// operator() overload and fills each functor slot with identity{}: one after
// each of the q/k/v/bias windows and four after the LSE window.
// Returns whatever the forwarded call returns.
//
// NOTE(review): the functor-taking overload visible in this struct declares a
// `score_mod` parameter between `s_acc_element_func` and
// `p_compute_element_func`, but this call passes no score modifier. Confirm
// whether this overload is still callable or is meant to be removed.
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
{
// Each identity{} below occupies one functor slot of the forwarded-to overload.
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr,
dropout);
}
};
} // namespace ck_tile
......@@ -138,6 +138,7 @@ struct BlockFmhaPipelineQRKSVSAsync
typename BiasElementFunction,
typename LSEElementFunction,
typename SAccElementFunction,
typename ScoreModFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
......@@ -154,6 +155,7 @@ struct BlockFmhaPipelineQRKSVSAsync
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const ScoreModFunction& score_mod,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
......@@ -409,16 +411,6 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
{
auto score_mod = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) s;
(void) b;
(void) h;
return static_cast<decltype(s)>(q_idx - v_idx);
};
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
......@@ -431,10 +423,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto b = 0;
const auto h = 0;
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
});
});
}
......@@ -770,47 +759,6 @@ struct BlockFmhaPipelineQRKSVSAsync
return o_acc;
}
// Convenience overload of operator() for callers that do not supply any
// element-wise functors. It forwards every argument unchanged to another
// operator() overload and fills each functor slot with identity{}: one after
// each of the q/k/v/bias windows and four after the LSE window.
// Returns whatever the forwarded call returns.
//
// NOTE(review): the functor-taking overload visible in this struct declares a
// `score_mod` parameter between `s_acc_element_func` and
// `p_compute_element_func`, but this call passes no score modifier. Confirm
// whether this overload is still callable or is meant to be removed.
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
{
// Each identity{} below occupies one functor slot of the forwarded-to overload.
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr,
dropout);
}
};
} // namespace ck_tile
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment