Commit 21dc4596 authored by Qianfeng Zhang
Browse files

Remove KLoadOnce and use NumPrefetchK > 1

parent 00fe0752
...@@ -1082,20 +1082,10 @@ struct FmhaFwdKernel ...@@ -1082,20 +1082,10 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentK>{}, number<FmhaPipeline::kAlignmentK>{},
number<1>{}); number<1>{});
if constexpr(FmhaPipeline::kKLoadOnce) return pad_tensor_view(
{ k_dram_naive,
return pad_tensor_view( make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
k_dram_naive, sequence<false, kPadHeadDimQ>{});
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
}
}(); }();
const auto v_dram = [&]() { const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -1147,15 +1137,7 @@ struct FmhaFwdKernel ...@@ -1147,15 +1137,7 @@ struct FmhaFwdKernel
{i_m0, 0}); {i_m0, 0});
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram, k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
[&]() {
if constexpr(FmhaPipeline::kKLoadOnce)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
}(),
{0, 0});
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram, make_tile_window(v_dram,
......
...@@ -14,10 +14,12 @@ namespace ck_tile { ...@@ -14,10 +14,12 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false, /* AsyncCopy = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1> /* NumPrefetchV = */ 1>
{ {
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false, /* AsyncCopy = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>; /* NumPrefetchV = */ 1>;
template <typename Problem> template <typename Problem>
......
...@@ -12,6 +12,7 @@ namespace ck_tile { ...@@ -12,6 +12,7 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false, /* AsyncCopy = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1> /* NumPrefetchV = */ 1>
{ {
template <typename Problem> template <typename Problem>
......
...@@ -35,9 +35,6 @@ struct BlockFmhaPipelineQRKSVS ...@@ -35,9 +35,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = false;
static_assert(kKLoadOnce == Policy::KLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kM0 = BlockFmhaShape::kM0;
......
...@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = true;
static_assert(kKLoadOnce == Policy::KLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kM0 = BlockFmhaShape::kM0;
...@@ -154,16 +151,18 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -154,16 +151,18 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim == kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>(); constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
static_assert(NumKLdsBuffers >= 2, "At least two LDS buffers needed for K");
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
...@@ -181,8 +180,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -181,8 +180,8 @@ struct BlockFmhaPipelineQRKSVSAsync
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr); KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>( auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>()); k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window = auto k_lds_window = make_tile_window(
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kSubQKHeaddim>{}), {0, 0}); k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// V tile in LDS // V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>( auto v_lds = make_tensor_view<address_space_enum::lds>(
...@@ -258,7 +257,12 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -258,7 +257,12 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window.get_window_origin(), k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load // load
auto k_tile = load_tile(k_dram_window); // prefetch two K tiles
auto k_tile_0 = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
auto k_tile_1 = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -299,7 +303,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -299,7 +303,6 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// prefetch K tile
index_t i_total_loops = 0; index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
...@@ -310,42 +313,76 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -310,42 +313,76 @@ struct BlockFmhaPipelineQRKSVSAsync
// ensure loading of Q from LDS completely done // ensure loading of Q from LDS completely done
block_sync_lds(); block_sync_lds();
do __builtin_amdgcn_sched_barrier(0);
{
store_tile(k_lds_window, k_tile);
__builtin_amdgcn_sched_barrier(0); // store first K tile to LDS
auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tile_0);
do
{
// STAGE 1, QK gemm // STAGE 1, QK gemm
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
if(i_total_loops < num_total_loop - 1) static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
{ if constexpr(i_k0 > 0 && i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {kN0, 0}); {
k_tile = load_tile(k_dram_window); if constexpr(i_k0 % 2 == 1)
} k_tile_0 = load_tile(k_dram_window);
else
k_tile_1 = load_tile(k_dram_window);
__builtin_amdgcn_sched_barrier(0); move_tile_window(k_dram_window, {0, kK0});
};
__builtin_amdgcn_sched_barrier(0);
// ensure K data needed by this gemm iteration completely available on LDS
block_sync_lds();
k_lds_window_tmp =
get_slice_tile(k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
// store K data needed by next gemm iteration to LDS
if constexpr(i_k0 % 2 == 0)
store_tile(k_lds_window_tmp, tile_elementwise_in(k_element_func, k_tile_1));
else
store_tile(k_lds_window_tmp, tile_elementwise_in(k_element_func, k_tile_0));
__builtin_amdgcn_sched_barrier(0);
gemm_0(
s_acc,
get_slice_tile(q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{}));
__builtin_amdgcn_sched_barrier(0);
});
// ensure k is completely updated on LDS
block_sync_lds(); block_sync_lds();
// for kQKHeaddim == 96 (kSubQKHeaddim == 128), we need to use k0_loops gemm_0(s_acc,
if constexpr(kQKHeaddim == kSubQKHeaddim) get_slice_tile(
{ q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
gemm_0(s_acc, q, k_lds_window); get_slice_tile(k_lds_window,
} sequence<((k0_loops - 1) % NumKLdsBuffers) * kN0, 0>{},
else sequence<(((k0_loops - 1) % NumKLdsBuffers) + 1) * kN0, kK0>{}));
__builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
if(i_total_loops < num_total_loop - 1)
{ {
move_tile_window(k_dram_window, {kN0, -k0_loops * kK0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { k_tile_0 = load_tile(k_dram_window);
gemm_0(s_acc, move_tile_window(k_dram_window, {0, kK0});
get_slice_tile( k_tile_1 = load_tile(k_dram_window);
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}), move_tile_window(k_dram_window, {0, kK0});
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
});
} }
__builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
...@@ -427,8 +464,10 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -427,8 +464,10 @@ struct BlockFmhaPipelineQRKSVSAsync
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{}); block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1} const auto m_old = m; // m{j-1}
tile_elementwise_inout( tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); },
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} m,
m_old,
m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>( auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j} s.get_tile_distribution()); // Pcompute{j}
...@@ -641,8 +680,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -641,8 +680,7 @@ struct BlockFmhaPipelineQRKSVSAsync
}); });
} }
} }
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
...@@ -653,7 +691,16 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -653,7 +691,16 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<((k1_loops - 1) % NumVLdsBuffers) * kN1, 0>{}, sequence<((k1_loops - 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((k1_loops - 1) % NumVLdsBuffers) + 1) * kN1, kK1>{})); sequence<(((k1_loops - 1) % NumVLdsBuffers) + 1) * kN1, kK1>{}));
} }
} while(++i_total_loops < num_total_loop);
__builtin_amdgcn_sched_barrier(0);
if(i_total_loops++ < num_total_loop)
{
k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tile_0);
}
} while(i_total_loops < num_total_loop);
// store lse // store lse
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
......
...@@ -11,6 +11,7 @@ namespace ck_tile { ...@@ -11,6 +11,7 @@ namespace ck_tile {
struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true, /* AsyncCopy = */ true,
/* NumPrefetchK = */ 2,
/* NumPrefetchV = */ 2> /* NumPrefetchV = */ 2>
{ {
template <typename Problem> template <typename Problem>
...@@ -60,20 +61,16 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -60,20 +61,16 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim == using GemmProblem =
Problem::BlockFmhaShape::kSubQKHeaddim) BlockGemmProblem<typename Problem::QDataType,
? Problem::BlockFmhaShape::kSubQKHeaddim typename Problem::KDataType,
: Problem::BlockFmhaShape::kK0; typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
using GemmProblem = BlockGemmProblem< TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
typename Problem::QDataType, Problem::BlockFmhaShape::kN0,
typename Problem::KDataType, Problem::BlockFmhaShape::kK0>,
typename Problem::SaccDataType, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
Problem::kNumGemm0Warps * get_warp_size(), typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
TileGemmShape<
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0, BlockGemmK>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
......
...@@ -11,6 +11,7 @@ namespace ck_tile { ...@@ -11,6 +11,7 @@ namespace ck_tile {
using BlockFmhaPipelineQRKSVSDefaultPolicy = using BlockFmhaPipelineQRKSVSDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false, /* AsyncCopy = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>; /* NumPrefetchV = */ 1>;
} // namespace ck_tile } // namespace ck_tile
...@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQSKSVS ...@@ -34,9 +34,6 @@ struct BlockFmhaPipelineQSKSVS
static constexpr bool kQLoadOnce = false; static constexpr bool kQLoadOnce = false;
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = false;
static_assert(kKLoadOnce == Policy::KLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kM0 = BlockFmhaShape::kM0;
......
...@@ -12,6 +12,7 @@ namespace ck_tile { ...@@ -12,6 +12,7 @@ namespace ck_tile {
struct BlockFmhaPipelineQSKSVSDefaultPolicy struct BlockFmhaPipelineQSKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
/* AsyncCopy = */ false, /* AsyncCopy = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1> /* NumPrefetchV = */ 1>
{ {
template <typename Problem> template <typename Problem>
......
...@@ -276,21 +276,26 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -276,21 +276,26 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
} }
}; };
template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchV_> template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchK_, index_t NumPrefetchV_>
struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLoadOnce_> struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLoadOnce_>
{ {
static constexpr index_t NumPrefetchK = NumPrefetchK_;
static constexpr index_t NumPrefetchV = NumPrefetchV_; static constexpr index_t NumPrefetchV = NumPrefetchV_;
// 1) When Async == true, we preload whole K-tile for next iteration using single LDS buffer,
// and preload V-slice for next unroll using multiple LDS buffers
// 2) When Async == false, we preload K-slice for next unroll using single LDS buffer, and
// preload V-slice for next unroll using single LDS buffer
static constexpr bool AsyncCopy = AsyncCopy_; static constexpr bool AsyncCopy = AsyncCopy_;
static constexpr bool KLoadOnce = AsyncCopy;
using QXPolicy = BlockFmhaPipelineQXCustomPolicy<QLoadOnce_>; using QXPolicy = BlockFmhaPipelineQXCustomPolicy<QLoadOnce_>;
// Number of K buffers to lay out in LDS for the given Problem.
//
// The result is NumPrefetchK capped by the number of kK0-wide steps needed to
// cover kQKHeaddim (integer division), so no more buffers are requested than
// there are K slices in one pass over the head dimension.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers()
{
    using FmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;

    // How many kK0-sized slices make up the QK head dimension.
    constexpr index_t num_k0_steps = FmhaShape::kQKHeaddim / FmhaShape::kK0;

    return min(NumPrefetchK, num_k0_steps);
}
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumVLdsBuffers() CK_TILE_DEVICE static constexpr auto GetNumVLdsBuffers()
{ {
...@@ -317,8 +322,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -317,8 +322,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType); constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
...@@ -382,29 +386,51 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -382,29 +386,51 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return WG::WarpGemmAttribute::Impl::kCM1PerLane; return WG::WarpGemmAttribute::Impl::kCM1PerLane;
} }
/*
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{},
number<kKPack>{}))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{},
sequence<1>{}));
return k_lds_block_desc;
}
*/
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t SingleKSize = [&]() {
constexpr index_t kKPerBlock = using KDataType = remove_cvref_t<typename Problem::KDataType>;
KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t kKPack = GetSmemKPackK<Problem>(); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( static_assert(PixelsPerRow % kKPack == 0);
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}), constexpr index_t NPerRow = PixelsPerRow / kKPack;
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}), constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
number<8>{}, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
number<1>{}); static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
constexpr auto k_lds_block_desc = transform_tensor_descriptor( return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
k_lds_block_desc_0, }();
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc; return SingleKSize;
} }
template <typename Problem> template <typename Problem>
...@@ -428,6 +454,48 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -428,6 +454,48 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return SingleVSize; return SingleVSize;
} }
// Builds the tensor descriptor for the K tile(s), spanning all
// GetNumKLdsBuffers<Problem>() buffers in a single descriptor.
//
// The returned descriptor is 2-D:
//   dim 0: NumKLdsBuffers * kN0 (buffer index merged with the N rows)
//   dim 1: kK0
// Rows are stored with padding: each run of PixelsPerRow elements is followed by
// kKPack unused elements (row stride is PixelsPerRow + kKPack).
// NOTE(review): the padding is presumably there to avoid LDS bank conflicts — confirm.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
// PixelsPerRow = number of K elements that fit in Banks * 4 bytes.
// NOTE(review): the factor 4 is presumably the bank width in bytes — confirm.
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
// Number of kKPack-wide groups (one per N index) packed into a single row.
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
// 5-D naive descriptor, lengths then strides:
//   (buffer, K / kKPack, N / NPerRow, NPerRow, kKPack)
// - buffer stride: element space of one K buffer
// - K-group stride: all padded rows of one K group
// - row stride: PixelsPerRow + kKPack (PixelsPerRow used + kKPack padding)
// NOTE(review): the trailing number<kKPack>{}, number<1>{} arguments are presumably
// the guaranteed vector length / offset hints — confirm against
// make_naive_tensor_descriptor.
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<GetKSingleSmemElementSpaceSize<Problem>()>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
// Collapse to 2-D: input dims (0, 2, 3) -> output dim 0 (buffer-major N index),
// input dims (1, 4) -> output dim 1 (K index).
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(
number<NumKLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
...@@ -532,8 +600,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -532,8 +600,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment