Commit 4da5206d, authored by Po Yen Chen, committed by GitHub
Browse files

Revert "qsksvs pipeline changes to mirror qrksvs"

This reverts commit f7942b99.
parent 1862b27f
...@@ -95,8 +95,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -95,8 +95,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{ {
constexpr std::array occupancy{2, 2, 2, 2, 2, 1}; constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2]; return occupancy[detail::log2<kMaxSplits>::value - 2];
} else if constexpr(kHeadDimV <= 512) {
return 1;
} }
} }
}(); }();
......
...@@ -96,10 +96,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -96,10 +96,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{ {
return 1; return 1;
} }
else if constexpr(kQKHeaddim <= 512)
{
return 1;
}
} }
}(); }();
......
...@@ -12,7 +12,7 @@ namespace ck_tile { ...@@ -12,7 +12,7 @@ namespace ck_tile {
/// NOTICE: we no-longer use this pipeline. /// NOTICE: we no-longer use this pipeline.
// This pipeline is qkv all located in LDS // This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
struct BlockFmhaPipelineQSKSVS struct [[deprecated]] BlockFmhaPipelineQSKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
...@@ -51,24 +51,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -51,24 +51,6 @@ struct BlockFmhaPipelineQSKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kBlockPerCu = []() { static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1) if constexpr(Problem::kBlockPerCu != -1)
...@@ -99,9 +81,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -99,9 +81,6 @@ struct BlockFmhaPipelineQSKSVS
static constexpr const char* name = "qs"; static constexpr const char* name = "qs";
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
using DropoutType = int32_t; // unused
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
...@@ -116,7 +95,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -116,7 +95,6 @@ struct BlockFmhaPipelineQSKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
...@@ -128,23 +106,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -128,23 +106,6 @@ struct BlockFmhaPipelineQSKSVS
typename OAccElementFunction, typename OAccElementFunction,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
// const QElementFunction& q_element_func,
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
// const KElementFunction& k_element_func,
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
// const VElementFunction& v_element_func,
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
// const BiasElementFunction& bias_element_func,
// LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
// const LSEElementFunction& lse_element_func,
// const SAccElementFunction& s_acc_element_func,
// const PComputeElementFunction& p_compute_element_func,
// const OAccElementFunction& o_acc_element_func,
// FmhaMask mask,
// PositionEncoding position_encoding,
// float scale_s,
// void* smem_ptr) const
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...@@ -153,7 +114,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -153,7 +114,6 @@ struct BlockFmhaPipelineQSKSVS
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
...@@ -162,8 +122,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -162,8 +122,7 @@ struct BlockFmhaPipelineQSKSVS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr) const
DropoutType& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......
...@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
/// NOTICE: we no-longer use this policy. /// NOTICE: we no-longer use this policy.
template <> template <>
struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
{ {
static constexpr bool QLoadOnce = false; static constexpr bool QLoadOnce = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment