Commit 119dd2ac authored by Qianfeng Zhang
Browse files

Let V reuse the LDS of K

parent 512eeecb
...@@ -175,12 +175,11 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -175,12 +175,11 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert(NumKLdsBuffers >= 2); static_assert(NumKLdsBuffers >= 2);
static_assert(NumVLdsBuffers >= 2); static_assert(NumVLdsBuffers >= 2);
auto q_dram_window = auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(),
q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQRegTileDistribution<Problem>());
Policy::template MakeQRegTileDistribution<Problem>()); auto q = load_tile(q_dram_window);
auto q = load_tile(q_dram_window);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -193,8 +192,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -193,8 +192,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// V tile in LDS // V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>( auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) + reinterpret_cast<VDataType*>(smem_ptr),
Policy::template GetSmemSizeK<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem>()); Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window( auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0}); v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
...@@ -296,7 +294,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -296,7 +294,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
if(i_total_loops == 0) // executed by first iteration if(i_total_loops == 0) // executed by first iteration
{ {
if(i_total_loops < num_total_loop) if(num_total_loop > 1) // there are multiple iterations
{ {
auto k_lds_window_tmp = auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}); get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
...@@ -341,7 +339,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -341,7 +339,7 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<kM0, k0_loops * kK0>{}), sequence<kM0, k0_loops * kK0>{}),
k_lds_window_tmp); k_lds_window_tmp);
} }
else else // there is only single iteration
{ {
auto k_lds_window_tmp = auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}); get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
......
...@@ -13,174 +13,13 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -13,174 +13,13 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
/* AsyncCopy = */ true, /* AsyncCopy = */ true,
/* NumPrefetchV = */ 2> /* NumPrefetchV = */ 2>
{ {
// Returns the smaller of (a) the number of Q elements each thread owns when the
// kM0 x kSubQKHeaddim tile is split evenly over kBlockSize threads and (b) the
// number of QDataType elements that fit in 16 bytes.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
    using QElem = typename Problem::QDataType;

    constexpr index_t kThreads = Problem::kBlockSize;
    constexpr index_t kRows    = Problem::BlockFmhaShape::kM0;
    constexpr index_t kCols    = Problem::BlockFmhaShape::kSubQKHeaddim;

    // Widest vector considered: 16 bytes' worth of Q elements.
    constexpr index_t kElemsPer16Bytes = 16 / sizeof(QElem);

    // NOTE: keep this computation consistent with MakeQDramTileDistribution().
    constexpr index_t kElemsPerThread = (kRows * kCols) / kThreads;
    static_assert(0 < kElemsPerThread);

    return min(kElemsPerThread, kElemsPer16Bytes);
}
// Builds the static tile distribution for a Q tile of kM0 x kSubQKHeaddim elements
// spread over Problem::kBlockSize threads. Each thread gets a contiguous run of
// KPerThread elements along K (at most 16 bytes' worth); the remaining threads of a
// warp, the warps of the block, and a per-thread repeat count cover the M dimension.
// NOTE(review): presumably this is the layout used when reading Q from DRAM (per the
// function name) — confirm against the callers.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
// Number of QDataType elements that fit in 16 bytes.
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// Elements owned by each thread when the tile is divided evenly across the block.
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
// Contiguous elements per thread along K.
constexpr index_t KPerThread = kMaxVecLoad;
// Threads needed to cover one full K row.
constexpr index_t KThreads = kKPerBlock / KPerThread;
// Threads of a warp left over to spread along M once K is covered.
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
// Rows along M handled by each thread: kMPerBlock split over all M-direction threads.
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// M is factored as <MPerThread, NumWarps, MThreadPerWarp>, K as <KThreads, KPerThread>.
// NOTE(review): the index sequences below appear to assign NumWarps to the first
// partition dimension, (MThreadPerWarp, KThreads) to the second, and
// (MPerThread, KPerThread) to the per-thread dimensions — confirm against the
// tile_distribution_encoding definition.
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
/*
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim ==
Problem::BlockFmhaShape::kSubQKHeaddim)
? Problem::BlockFmhaShape::kSubQKHeaddim
: Problem::BlockFmhaShape::kK0;
using GemmProblem = BlockGemmProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0,
BlockGemmK>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename
Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); static_assert(WarpGemmM == 4 ||
WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard coded here. Otherwise, it may produce an incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
swizzle_factor>{};
} // TODO - bf8_t
}();
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename
Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
*/
// Returns, as an index_t, how many Q elements (cv/ref-stripped QDataType) fit in 16 bytes.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
    // TODO: this is for 3d layout
    constexpr auto q_elem_bytes = sizeof(remove_cvref_t<typename Problem::QDataType>);
    return static_cast<index_t>(16 / q_elem_bytes);
}
// Builds the tensor descriptor for a kM0 x kSubQKHeaddim Q tile.
// The tile is first described as a 3-D tensor (K / kKPack, M, kKPack), then viewed as
// a 2-D (M, K) tensor by passing M through and merging the two K-related dimensions.
// NOTE(review): per the function name this describes the Q tile's placement in LDS —
// confirm against the callers.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
// Elements owned by each thread when the tile is divided evenly across the block.
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
// Innermost contiguous run length: capped by both the per-thread share and GetSmemKPackQ().
constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
// Lengths (K / kKPack, M, kKPack) with strides ((M + 1) * kKPack, kKPack, 1): each
// outermost slice is followed by one extra kKPack-sized gap.
// NOTE(review): the "+ 1" padding is presumably there to reduce LDS bank conflicts,
// and the two trailing number<> arguments presumably carry guaranteed vector
// length / offset information — confirm against make_naive_tensor_descriptor.
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
// Source dim 1 (M) becomes output dim 0 unchanged; source dims 0 and 2
// (K / kKPack, kKPack) are merged into output dim 1 (K).
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
// Returns the byte size of the Q tile described by MakeQLdsBlockDescriptor():
// the descriptor's element-space size multiplied by sizeof(QDataType).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
    constexpr auto q_lds_desc = MakeQLdsBlockDescriptor<Problem>();
    return q_lds_desc.get_element_space_size() * sizeof(typename Problem::QDataType);
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
// assume Q can reuse the shared memory with K or V // assume V can reuse the shared memory by K
// assume Dropout can reuse the shared memory with V // assume Dropout can reuse the shared memory by V
return max(GetSmemSizeQ<Problem>(), return max(GetSmemSizeK<Problem>(),
GetSmemSizeK<Problem>() + max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment