"...composable_kernel.git" did not exist on "3ec6360e22fed183113aa89d2773e1fc3b969916"
Unverified Commit 27a2a0a1 authored by Max Podkorytov
Browse files

abstract score_mod from a pipeline

parent 529bda90
...@@ -1302,6 +1302,21 @@ struct FmhaFwdKernel ...@@ -1302,6 +1302,21 @@ struct FmhaFwdKernel
} }
}(); }();
auto score_mod_def = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) h; (void) b;
return s + static_cast<decltype(s)>(q_idx - v_idx);
};
auto score_mod_arg = [b=i_batch, h=i_nhead, score_mod_def](auto s,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
return score_mod_def(s, b, h, q_idx, v_idx);
};
auto o_acc_tile = [&]() { auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -1318,6 +1333,7 @@ struct FmhaFwdKernel ...@@ -1318,6 +1333,7 @@ struct FmhaFwdKernel
lse_dram_window, lse_dram_window,
identity{}, // lse_element_func identity{}, // lse_element_func
identity{}, // s_acc_element_func identity{}, // s_acc_element_func
score_mod_arg,
scales{kargs.scale_p}, // p_compute_element_func scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask, mask,
...@@ -1329,11 +1345,20 @@ struct FmhaFwdKernel ...@@ -1329,11 +1345,20 @@ struct FmhaFwdKernel
else else
{ {
return FmhaPipeline{}(q_dram_window, return FmhaPipeline{}(q_dram_window,
identity{},
k_dram_window, k_dram_window,
identity{},
v_dram_window, v_dram_window,
identity{},
bias_dram_window, bias_dram_window,
identity{},
randval_dram_window, randval_dram_window,
lse_dram_window, lse_dram_window,
identity{},
identity{},
score_mod_arg,
identity{},
identity{},
mask, mask,
position_encoding, position_encoding,
kargs.scale_s, kargs.scale_s,
......
...@@ -120,6 +120,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -120,6 +120,7 @@ struct BlockFmhaPipelineQRKSVS
typename BiasElementFunction, typename BiasElementFunction,
typename LSEElementFunction, typename LSEElementFunction,
typename SAccElementFunction, typename SAccElementFunction,
typename ScoreModFunction,
typename PComputeElementFunction, typename PComputeElementFunction,
typename OAccElementFunction, typename OAccElementFunction,
typename PositionEncoding> typename PositionEncoding>
...@@ -136,6 +137,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -136,6 +137,7 @@ struct BlockFmhaPipelineQRKSVS
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
const ScoreModFunction& score_mod,
const PComputeElementFunction& p_compute_element_func, const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func, const OAccElementFunction& o_acc_element_func,
FmhaMask mask, FmhaMask mask,
...@@ -339,16 +341,6 @@ struct BlockFmhaPipelineQRKSVS ...@@ -339,16 +341,6 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores // FlexAttention score modifier operates directly on scores
{ {
auto score_mod = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) s;
(void) b;
(void) h;
return static_cast<decltype(s)>(q_idx - v_idx);
};
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
...@@ -361,10 +353,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -361,10 +353,7 @@ struct BlockFmhaPipelineQRKSVS
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto b = 0; s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
const auto h = 0;
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
}); });
}); });
} }
...@@ -634,47 +623,6 @@ struct BlockFmhaPipelineQRKSVS ...@@ -634,47 +623,6 @@ struct BlockFmhaPipelineQRKSVS
return o_acc; return o_acc;
} }
// Convenience overload without element-wise functors.
// Forwards every argument to the operator() overload that accepts element
// functions, passing an `identity` object in each element-function slot, and
// returns whatever that overload returns.
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,       // M0*K0 tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp,       // N0*K0 tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp,       // N1*K1 tile
           const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
           RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
           LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
           FmhaMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           void* smem_ptr,
           DropoutType& dropout) const
{
    // A single identity object is reused for all eight element-function slots.
    const identity pass_through{};

    return (*this)(q_dram_block_window_tmp,
                   pass_through,
                   k_dram_block_window_tmp,
                   pass_through,
                   v_dram_block_window_tmp,
                   pass_through,
                   bias_dram_block_window_tmp,
                   pass_through,
                   randval_dram_block_window_tmp,
                   lse_dram_block_window_tmp,
                   pass_through,
                   pass_through,
                   pass_through,
                   pass_through,
                   mask,
                   position_encoding,
                   scale_s,
                   smem_ptr,
                   dropout);
}
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -138,6 +138,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -138,6 +138,7 @@ struct BlockFmhaPipelineQRKSVSAsync
typename BiasElementFunction, typename BiasElementFunction,
typename LSEElementFunction, typename LSEElementFunction,
typename SAccElementFunction, typename SAccElementFunction,
typename ScoreModFunction,
typename PComputeElementFunction, typename PComputeElementFunction,
typename OAccElementFunction, typename OAccElementFunction,
typename PositionEncoding> typename PositionEncoding>
...@@ -154,6 +155,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -154,6 +155,7 @@ struct BlockFmhaPipelineQRKSVSAsync
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
const ScoreModFunction& score_mod,
const PComputeElementFunction& p_compute_element_func, const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func, const OAccElementFunction& o_acc_element_func,
FmhaMask mask, FmhaMask mask,
...@@ -409,16 +411,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -409,16 +411,6 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores // FlexAttention score modifier operates directly on scores
{ {
auto score_mod = [](auto s,
ck_tile::index_t b,
ck_tile::index_t h,
ck_tile::index_t q_idx,
ck_tile::index_t v_idx) {
(void) s;
(void) b;
(void) h;
return static_cast<decltype(s)>(q_idx - v_idx);
};
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
...@@ -431,10 +423,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -431,10 +423,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto b = 0; s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col);
const auto h = 0;
s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
}); });
}); });
} }
...@@ -770,47 +759,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -770,47 +759,6 @@ struct BlockFmhaPipelineQRKSVSAsync
return o_acc; return o_acc;
} }
// Convenience overload without element-wise functors.
// Forwards every argument to the operator() overload that accepts element
// functions, passing an `identity` object in each element-function slot, and
// returns whatever that overload returns.
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,       // M0*K0 tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp,       // N0*K0 tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp,       // N1*K1 tile
           const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
           RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
           LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
           FmhaMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           void* smem_ptr,
           DropoutType& dropout) const
{
    // A single identity object is reused for all eight element-function slots.
    const identity pass_through{};

    return (*this)(q_dram_block_window_tmp,
                   pass_through,
                   k_dram_block_window_tmp,
                   pass_through,
                   v_dram_block_window_tmp,
                   pass_through,
                   bias_dram_block_window_tmp,
                   pass_through,
                   randval_dram_block_window_tmp,
                   lse_dram_block_window_tmp,
                   pass_through,
                   pass_through,
                   pass_through,
                   pass_through,
                   mask,
                   position_encoding,
                   scale_s,
                   smem_ptr,
                   dropout);
}
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment