Commit 3ee41b40 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Re-implement qr_ks_vs_async pipeline by using kLoadOnce

parent c0b90f13
...@@ -1064,14 +1064,14 @@ struct FmhaFwdKernel ...@@ -1064,14 +1064,14 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
} }
else else
{ {
return pad_tensor_view( return pad_tensor_view(
q_dram_naive, q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
} }
}(); }();
const auto k_dram = [&]() { const auto k_dram = [&]() {
...@@ -1082,10 +1082,20 @@ struct FmhaFwdKernel ...@@ -1082,10 +1082,20 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentK>{}, number<FmhaPipeline::kAlignmentK>{},
number<1>{}); number<1>{});
if constexpr(FmhaPipeline::kKLoadOnce)
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view( return pad_tensor_view(
k_dram_naive, k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{}); sequence<false, kPadHeadDimQ>{});
}
}(); }();
const auto v_dram = [&]() { const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel ...@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_transposed, v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<kPadHeadDimV, false>{});
} }
else else
{ {
...@@ -1121,7 +1131,7 @@ struct FmhaFwdKernel ...@@ -1121,7 +1131,7 @@ struct FmhaFwdKernel
return pad_tensor_view( return pad_tensor_view(
v_dram_naive, v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<false, kPadSeqLenK>{});
} }
}(); }();
...@@ -1137,7 +1147,15 @@ struct FmhaFwdKernel ...@@ -1137,7 +1147,15 @@ struct FmhaFwdKernel
{i_m0, 0}); {i_m0, 0});
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0}); k_dram,
[&]() {
if constexpr(FmhaPipeline::kKLoadOnce)
return make_tuple(number<FmhaPipeline::kN0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
}(),
{0, 0});
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram, make_tile_window(v_dram,
......
...@@ -316,11 +316,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS ...@@ -316,11 +316,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
// load Q from LDS // load Q from LDS
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_load = make_tile_window( auto q_lds_window_for_load =
q_lds, make_tile_window(q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0}, {0, 0},
Policy::template MakeQRegTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeQRegTileDistribution<Problem>());
block_sync_lds(); block_sync_lds();
auto q = load_tile(q_lds_window_for_load); auto q = load_tile(q_lds_window_for_load);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
......
...@@ -13,15 +13,11 @@ namespace ck_tile { ...@@ -13,15 +13,11 @@ namespace ck_tile {
// This pipeline is qkv all located in LDS // This pipeline is qkv all located in LDS
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false, /* AsyncCopy = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1> /* NumPrefetchV = */ 1>
{ {
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false, /* AsyncCopy = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>; /* NumPrefetchV = */ 1>;
template <typename Problem> template <typename Problem>
...@@ -76,10 +72,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy ...@@ -76,10 +72,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
template <typename Problem, typename BlockGemm> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{ {
return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>(); return BasePolicy::template MakeQDramTileDistribution<Problem>();
} }
template <typename Problem> template <typename Problem>
......
...@@ -180,11 +180,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -180,11 +180,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window( auto q_dram_window =
q_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeQDramTileDistribution<Problem>());
auto q = load_tile(q_dram_window); auto q = load_tile(q_dram_window);
......
...@@ -11,9 +11,7 @@ namespace ck_tile { ...@@ -11,9 +11,7 @@ namespace ck_tile {
// This pipeline is qkv all located in LDS // This pipeline is qkv all located in LDS
struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false, /* AsyncCopy = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1> /* NumPrefetchV = */ 1>
{ {
template <typename Problem> template <typename Problem>
......
...@@ -35,6 +35,9 @@ struct BlockFmhaPipelineQRKSVS ...@@ -35,6 +35,9 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = false;
static_assert(kKLoadOnce == Policy::KLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kM0 = BlockFmhaShape::kM0;
...@@ -178,11 +181,11 @@ struct BlockFmhaPipelineQRKSVS ...@@ -178,11 +181,11 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window( auto q_dram_window =
q_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeQDramTileDistribution<Problem>());
auto q = load_tile(q_dram_window); auto q = load_tile(q_dram_window);
......
...@@ -8,12 +8,80 @@ ...@@ -8,12 +8,80 @@
namespace ck_tile { namespace ck_tile {
// This pipeline is qkv all located in LDS struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy = : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, /* AsyncCopy = */ true,
/* AsyncCopyK = */ true, /* NumPrefetchV = */ 2>
/* AsyncCopyV = */ false, {
/* NumPrefetchK = */ 3, template <typename Problem>
/* NumPrefetchV = */ 3>; CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim ==
Problem::BlockFmhaShape::kSubQKHeaddim)
? Problem::BlockFmhaShape::kSubQKHeaddim
: Problem::BlockFmhaShape::kK0;
using GemmProblem = BlockGemmProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0, BlockGemmK>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
swizzle_factor>{};
} // TODO - bf8_t
}();
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile } // namespace ck_tile
...@@ -8,12 +8,9 @@ ...@@ -8,12 +8,9 @@
namespace ck_tile { namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaPipelineQRKSVSDefaultPolicy = using BlockFmhaPipelineQRKSVSDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false, /* AsyncCopy = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>; /* NumPrefetchV = */ 1>;
} // namespace ck_tile } // namespace ck_tile
...@@ -34,6 +34,9 @@ struct BlockFmhaPipelineQSKSVS ...@@ -34,6 +34,9 @@ struct BlockFmhaPipelineQSKSVS
static constexpr bool kQLoadOnce = false; static constexpr bool kQLoadOnce = false;
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = false;
static_assert(kKLoadOnce == Policy::KLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kM0 = BlockFmhaShape::kM0;
...@@ -94,6 +97,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -94,6 +97,8 @@ struct BlockFmhaPipelineQSKSVS
{ {
return 1; return 1;
} }
else
return 1;
} }
}(); }();
......
...@@ -11,9 +11,7 @@ namespace ck_tile { ...@@ -11,9 +11,7 @@ namespace ck_tile {
// This pipeline is qkv all located in LDS // This pipeline is qkv all located in LDS
struct BlockFmhaPipelineQSKSVSDefaultPolicy struct BlockFmhaPipelineQSKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
/* AsyncCopyK = */ false, /* AsyncCopy = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1> /* NumPrefetchV = */ 1>
{ {
template <typename Problem> template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment