Commit 916daf59 authored by Qianfeng Zhang

Use k0_loops small-tile loads/stores for K instead of one big-tile load/store

parent 4776c8c0
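In short, the kKLoadOnce path used to load the whole [kN0, kSubQKHeaddim] K block in one DRAM access and store it to LDS in one shot; after this change the pipeline loads k0_loops = kQKHeaddim / kK0 small [kN0, kK0] tiles and writes each one into its own LDS buffer slice. A minimal sketch of the new pattern, using the CK-tile helpers that appear in the diff (load_tile, store_tile, move_tile_window, get_slice_tile, static_for); it is not a standalone program, and all tile-size names are the pipeline's compile-time constants shown below:

    // Load K as k0_loops small [kN0, kK0] tiles instead of one [kN0, kQKHeaddim] tile.
    constexpr index_t k0_loops = kQKHeaddim / kK0;
    statically_indexed_array<decltype(load_tile(k_dram_window)), k0_loops> k_tiles;

    static_for<0, k0_loops, 1>{}([&](auto i_k0) {
        k_tiles[i_k0] = load_tile(k_dram_window);          // one [kN0, kK0] slice of K
        move_tile_window(k_dram_window, {0, kK0});          // step along the head dimension
    });
    move_tile_window(k_dram_window, {0, -k0_loops * kK0});  // rewind for the next N0 block

    // Each small tile then goes into its own [kN0, kK0] slice of the K LDS window.
    static_for<0, k0_loops, 1>{}([&](auto i_k0) {
        store_tile(get_slice_tile(k_lds_window,
                                  sequence<i_k0 * kN0, 0>{},
                                  sequence<(i_k0 + 1) * kN0, kK0>{}),
                   k_tiles[i_k0]);
    });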
@@ -1082,20 +1082,10 @@ struct FmhaFwdKernel
                 number<FmhaPipeline::kAlignmentK>{},
                 number<1>{});
-            if constexpr(FmhaPipeline::kKLoadOnce)
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
-                    sequence<false, kPadHeadDimQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<false, kPadHeadDimQ>{});
-            }
+            return pad_tensor_view(
+                k_dram_naive,
+                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
+                sequence<false, kPadHeadDimQ>{});
         }();
         const auto v_dram = [&]() {
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1147,15 +1137,7 @@ struct FmhaFwdKernel
             {i_m0, 0});
         auto k_dram_window = make_tile_window(
-            k_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kKLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kN0>{},
-                                      number<FmhaPipeline::kSubQKHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
-            }(),
-            {0, 0});
+            k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
         auto v_dram_window =
             make_tile_window(v_dram,
...
@@ -154,14 +154,18 @@ struct BlockFmhaPipelineQRKSVSAsync
         static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                           kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                          kSubQKHeaddim ==
-                              KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
+                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                           kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                           kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                           kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                           kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                       "wrong!");

+        constexpr index_t k0_loops = kQKHeaddim / kK0;
+        constexpr index_t k1_loops = kN0 / kK1;
+
+        static_assert(2 <= k0_loops);
+        static_assert(1 <= k1_loops);
+
         constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();

         auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -257,8 +261,18 @@ struct BlockFmhaPipelineQRKSVSAsync
             k_dram_block_window.get_window_lengths(),
             k_dram_block_window.get_window_origin(),
             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                     // load
-        auto k_tile = load_tile(k_dram_window);
+        using k_tile_type = decltype(load_tile(k_dram_window));
+        statically_indexed_array<k_tile_type, k0_loops> k_tiles;
+
+        static_for<0, k0_loops, 1>{}([&](auto i_k0) {
+            k_tiles[i_k0] = load_tile(k_dram_window);
+            move_tile_window(k_dram_window, {0, kK0});
+        });
+        move_tile_window(k_dram_window, {0, -k0_loops * kK0});

         __builtin_amdgcn_sched_barrier(0);
@@ -301,18 +315,18 @@ struct BlockFmhaPipelineQRKSVSAsync
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kQKHeaddim / kK0;
-        constexpr index_t k1_loops = kN0 / kK1;
-
-        static_assert(2 <= k0_loops);
-        static_assert(1 <= k1_loops);

         // ensure loading of Q from LDS completely done
         block_sync_lds();

         do
         {
-            store_tile(k_lds_window, k_tile);
+            static_for<0, k0_loops, 1>{}([&](auto i_k0) {
+                auto k_lds_window_tmp = get_slice_tile(
+                    k_lds_window, sequence<i_k0 * kN0, 0>{}, sequence<(i_k0 + 1) * kN0, kK0>{});
+                store_tile(k_lds_window_tmp, k_tiles[i_k0]);
+            });
+
             __builtin_amdgcn_sched_barrier(0);
@@ -322,7 +336,14 @@ struct BlockFmhaPipelineQRKSVSAsync
             if(i_total_loops < num_total_loop - 1)
             {
                 move_tile_window(k_dram_window, {kN0, 0});
-                k_tile = load_tile(k_dram_window);
+
+                static_for<0, k0_loops, 1>{}([&](auto i_k0) {
+                    k_tiles[i_k0] = load_tile(k_dram_window);
+                    move_tile_window(k_dram_window, {0, kK0});
+                });
+                move_tile_window(k_dram_window, {0, -k0_loops * kK0});
             }

             __builtin_amdgcn_sched_barrier(0);
@@ -335,8 +356,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                     s_acc,
                     get_slice_tile(q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
                     get_slice_tile(k_lds_window,
-                                   sequence<0, i_k0 * kK0>{},
-                                   sequence<kN0, (i_k0 + 1) * kK0>{}));
+                                   sequence<i_k0 * kN0, 0>{},
+                                   sequence<(i_k0 + 1) * kN0, kK0>{}));
             });

             __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
...
@@ -291,6 +291,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     using QXPolicy = BlockFmhaPipelineQXCustomPolicy<QLoadOnce_>;

+    template <typename Problem>
+    CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers()
+    {
+        if constexpr(KLoadOnce)
+        {
+            using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
+
+            constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
+
+            return k0_loops;
+        }
+        else
+            return 1;
+    }
+
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto GetNumVLdsBuffers()
     {
@@ -317,8 +332,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     {
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock =
-            KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

         constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
@@ -382,6 +396,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         return WG::WarpGemmAttribute::Impl::kCM1PerLane;
     }

+    /*
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
     {
@@ -400,12 +415,36 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
             k_lds_block_desc_0,
             make_tuple(
                 make_pass_through_transform(number<kNPerBlock>{}),
-                make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
-            make_tuple(sequence<1>{}, sequence<0, 2>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
+                make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{},
+                                                number<kKPack>{}))),
+            make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));

         return k_lds_block_desc;
     }
+    */

+    /*
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
+    {
+        constexpr index_t SingleKSize = [&]() {
+            using KDataType = remove_cvref_t<typename Problem::KDataType>;
+
+            constexpr index_t Banks        = 32; // TODO: need change based on arch
+            constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
+            constexpr index_t kKPack       = GetSmemKPackK<Problem>();
+            static_assert(PixelsPerRow % kKPack == 0);
+            constexpr index_t NPerRow    = PixelsPerRow / kKPack;
+            constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+            constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+            static_assert(kNPerBlock % NPerRow == 0);
+            static_assert(kKPerBlock % kKPack == 0);
+
+            return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
+        }();
+
+        return SingleKSize;
+    }
+    */

     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
@@ -428,6 +467,78 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         return SingleVSize;
     }

+    // 3d + padding
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
+    {
+        /*
+        using KDataType = remove_cvref_t<typename Problem::KDataType>;
+
+        constexpr index_t Banks        = 32; // TODO: need change based on arch
+        constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
+        constexpr index_t kKPack       = GetSmemKPackK<Problem>();
+        static_assert(PixelsPerRow % kKPack == 0);
+        constexpr index_t NPerRow    = PixelsPerRow / kKPack;
+        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+        static_assert(kNPerBlock % NPerRow == 0);
+        static_assert(kKPerBlock % kKPack == 0);
+
+        constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
+
+        constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(number<NumKLdsBuffers>{},
+                       number<kKPerBlock / kKPack>{},
+                       number<kNPerBlock / NPerRow>{},
+                       number<NPerRow>{},
+                       number<kKPack>{}),
+            make_tuple(number<GetKSingleSmemElementSpaceSize<Problem>()>{},
+                       number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
+                       number<PixelsPerRow + kKPack>{},
+                       number<kKPack>{},
+                       number<1>{}),
+            number<kKPack>{},
+            number<1>{});
+
+        constexpr auto k_lds_block_desc = transform_tensor_descriptor(
+            k_lds_block_desc_0,
+            make_tuple(
+                make_merge_transform(make_tuple(number<NumKLdsBuffers>{},
+                                                number<kNPerBlock / NPerRow>{},
+                                                number<NPerRow>{})),
+                make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{},
+                                                number<kKPack>{}))),
+            make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+
+        return k_lds_block_desc;
+        */
+
+        constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
+        constexpr index_t kNPerBlock     = Problem::BlockFmhaShape::kN0;
+        constexpr index_t kKPerBlock     = Problem::BlockFmhaShape::kK0;
+        constexpr index_t kKPack         = GetSmemKPackK<Problem>();
+
+        constexpr auto k_lds_block_desc_0 =
+            make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
+                                                    number<kKPerBlock / kKPack>{},
+                                                    number<kNPerBlock>{},
+                                                    number<kKPack>{}),
+                                         make_tuple(number<kKPerBlock * (kNPerBlock + 1)>{},
+                                                    number<(kNPerBlock + 1) * kKPack>{},
+                                                    number<kKPack>{},
+                                                    number<1>{}),
+                                         number<8>{},
+                                         number<1>{});
+
+        constexpr auto k_lds_block_desc = transform_tensor_descriptor(
+            k_lds_block_desc_0,
+            make_tuple(
+                make_merge_transform(make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
+                make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
+            make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+
+        return k_lds_block_desc;
+    }
+
     // 3d + padding
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
@@ -532,8 +643,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock =
-            KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

         constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
...
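With this change the K LDS space is split into NumKLdsBuffers = k0_loops buffers, one per [kN0, kK0] tile, and each buffer pads every kKPack-wide column group by one extra row (the kNPerBlock + 1 factor in the new descriptor's strides), a common LDS layout trick to stagger bank accesses. A small arithmetic sketch of the resulting sizes; the concrete tile sizes used here (kN0 = 128, kK0 = 32, kQKHeaddim = 128) are hypothetical and chosen only for illustration:

    // Hypothetical tile sizes, not taken from this commit.
    constexpr int kN0 = 128, kK0 = 32, kQKHeaddim = 128;
    constexpr int k0_loops          = kQKHeaddim / kK0;           // 4 K LDS buffers
    constexpr int single_k_buffer   = kK0 * (kN0 + 1);            // 4128 elements per buffer
    constexpr int total_k_lds_elems = k0_loops * single_k_buffer; // 16512 elements for K in LDS
    // The merged 2-D view exposed to the block GEMM is then
    // (k0_loops * kN0) rows x kK0 columns = 512 x 32.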