Commit 3ee41b40 authored by Qianfeng Zhang

Re-implement qr_ks_vs_async pipeline by using kLoadOnce

parent c0b90f13
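For orientation: the core of this change is that the async pipeline now claims kKLoadOnce, i.e. the whole N0 x kSubQKHeaddim K tile is brought into LDS once per outer iteration and the QK gemm then consumes kK0-wide slices of it, instead of streaming K through LDS one kK0 slice at a time. A minimal host-side sketch of just that control-flow difference (tile sizes and function names here are illustrative, not ck_tile APIs):

    #include <cstdio>

    // Illustrative tile sizes only.
    constexpr int kN0 = 128, kK0 = 32, kSubQKHeaddim = 128;

    void qk_gemm_k_sliced() // before: K streamed through LDS in kK0-wide slices
    {
        for(int k0 = 0; k0 < kSubQKHeaddim; k0 += kK0)
            std::printf("load K slice [:, %d..%d) to LDS, then partial QK gemm on it\n", k0, k0 + kK0);
    }

    void qk_gemm_k_load_once() // after (kKLoadOnce): one whole-tile load, then gemm over LDS slices
    {
        std::printf("load whole K tile [0..%d) x [0..%d) to LDS once\n", kN0, kSubQKHeaddim);
        for(int k0 = 0; k0 < kSubQKHeaddim; k0 += kK0)
            std::printf("partial QK gemm on LDS slice [:, %d..%d)\n", k0, k0 + kK0);
    }

    int main()
    {
        qk_gemm_k_sliced();
        qk_gemm_k_load_once();
    }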
@@ -1064,14 +1064,14 @@ struct FmhaFwdKernel
             return pad_tensor_view(
                 q_dram_naive,
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
-                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                sequence<false, kPadHeadDimQ>{});
         }
         else
         {
             return pad_tensor_view(
                 q_dram_naive,
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
-                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                sequence<false, kPadHeadDimQ>{});
         }
     }();
     const auto k_dram = [&]() {
@@ -1082,10 +1082,20 @@ struct FmhaFwdKernel
             number<FmhaPipeline::kAlignmentK>{},
             number<1>{});
-        return pad_tensor_view(
-            k_dram_naive,
-            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-            sequence<kPadSeqLenK, kPadHeadDimQ>{});
+        if constexpr(FmhaPipeline::kKLoadOnce)
+        {
+            return pad_tensor_view(
+                k_dram_naive,
+                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
+                sequence<false, kPadHeadDimQ>{});
+        }
+        else
+        {
+            return pad_tensor_view(
+                k_dram_naive,
+                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
+                sequence<false, kPadHeadDimQ>{});
+        }
     }();
     const auto v_dram = [&]() {
         if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1107,7 +1117,7 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     v_dram_transposed,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<kPadHeadDimV, false>{});
             }
             else
             {
@@ -1121,7 +1131,7 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     v_dram_naive,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<false, kPadSeqLenK>{});
             }
         }();
@@ -1137,7 +1147,15 @@ struct FmhaFwdKernel
             {i_m0, 0});
         auto k_dram_window = make_tile_window(
-            k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
+            k_dram,
+            [&]() {
+                if constexpr(FmhaPipeline::kKLoadOnce)
+                    return make_tuple(number<FmhaPipeline::kN0>{},
+                                      number<FmhaPipeline::kSubQKHeaddim>{});
+                else
+                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
+            }(),
+            {0, 0});
         auto v_dram_window =
             make_tile_window(v_dram,
...
@@ -316,11 +316,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
         // load Q from LDS
         __builtin_amdgcn_sched_barrier(0);
-        auto q_lds_window_for_load = make_tile_window(
-            q_lds,
-            Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
-            {0, 0},
-            Policy::template MakeQRegTileDistribution<Problem, decltype(gemm_0)>());
+        auto q_lds_window_for_load =
+            make_tile_window(q_lds,
+                             Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
+                             {0, 0},
+                             Policy::template MakeQRegTileDistribution<Problem>());
         block_sync_lds();
         auto q = load_tile(q_lds_window_for_load);
         __builtin_amdgcn_sched_barrier(0);
...
@@ -13,15 +13,11 @@ namespace ck_tile {
 // This pipeline is qkv all located in LDS
 struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
     : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                          /* AsyncCopyK = */ false,
-                                          /* AsyncCopyV = */ false,
-                                          /* NumPrefetchK = */ 1,
+                                          /* AsyncCopy = */ false,
                                           /* NumPrefetchV = */ 1>
 {
     using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                                           /* AsyncCopyK = */ false,
-                                                           /* AsyncCopyV = */ false,
-                                                           /* NumPrefetchK = */ 1,
+                                                           /* AsyncCopy = */ false,
                                                            /* NumPrefetchV = */ 1>;
     template <typename Problem>
@@ -76,10 +72,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
                                        sequence<0, 1>>{});
     }
-    template <typename Problem, typename BlockGemm>
+    template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
     {
-        return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
+        return BasePolicy::template MakeQDramTileDistribution<Problem>();
     }
     template <typename Problem>
...
@@ -180,11 +180,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
-        auto q_dram_window = make_tile_window(
-            q_dram_block_window_tmp.get_bottom_tensor_view(),
-            q_dram_block_window_tmp.get_window_lengths(),
-            q_dram_block_window_tmp.get_window_origin(),
-            Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
+        auto q_dram_window =
+            make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
+                             q_dram_block_window_tmp.get_window_lengths(),
+                             q_dram_block_window_tmp.get_window_origin(),
+                             Policy::template MakeQDramTileDistribution<Problem>());
         auto q = load_tile(q_dram_window);
...
@@ -11,9 +11,7 @@ namespace ck_tile {
 // This pipeline is qkv all located in LDS
 struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
     : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                          /* AsyncCopyK = */ false,
-                                          /* AsyncCopyV = */ false,
-                                          /* NumPrefetchK = */ 1,
+                                          /* AsyncCopy = */ false,
                                           /* NumPrefetchV = */ 1>
 {
     template <typename Problem>
...
@@ -35,6 +35,9 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
     static_assert(kQLoadOnce == Policy::QLoadOnce);
+    static constexpr bool kKLoadOnce = false;
+    static_assert(kKLoadOnce == Policy::KLoadOnce);
+
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
@@ -178,11 +181,11 @@ struct BlockFmhaPipelineQRKSVS
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
-        auto q_dram_window = make_tile_window(
-            q_dram_block_window_tmp.get_bottom_tensor_view(),
-            q_dram_block_window_tmp.get_window_lengths(),
-            q_dram_block_window_tmp.get_window_origin(),
-            Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
+        auto q_dram_window =
+            make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
+                             q_dram_block_window_tmp.get_window_lengths(),
+                             q_dram_block_window_tmp.get_window_origin(),
+                             Policy::template MakeQDramTileDistribution<Problem>());
         auto q = load_tile(q_dram_window);
...
@@ -4,7 +4,6 @@
 #pragma once
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -12,7 +11,6 @@
 namespace ck_tile {
-// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
 template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
 struct BlockFmhaPipelineQRKSVSAsync
 {
@@ -36,6 +34,9 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
     static_assert(kQLoadOnce == Policy::QLoadOnce);
+    static constexpr bool kKLoadOnce = true;
+    static_assert(kKLoadOnce == Policy::KLoadOnce);
+
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
@@ -47,68 +48,51 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
-    // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
-    // only need special care about seq_k padding (oob need set -INF of p instead of zero)
-    static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
-                  Problem::kPadHeadDimV == true);
-    static constexpr bool kPadSeqLenQ  = true;
+    static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
     static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
-    static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
-    static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
+    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = Problem::kStoreLSE;
     static constexpr bool kHasDropout  = Problem::kHasDropout;

     // last dimension vector length used to create tensor view(and decide buffer_load vector length)
     // ... together with tensor distribution. tensor dist should able to overwrite this
-    static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
-    static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
+    static constexpr index_t kAlignmentQ =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
+    static constexpr index_t kAlignmentK =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
     static constexpr index_t kAlignmentV = []() {
         if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-            return Policy::template GetAlignmentV<Problem>();
+            return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
         else
             return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
     }();
-    static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
+    static constexpr index_t kAlignmentO =
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
     static constexpr index_t kAlignmentBias =
         kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();

-#if CK_TILE_FMHA_FWD_FAST_EXP2
-    static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
-#endif
-
     static constexpr index_t kBlockPerCu = []() {
         if constexpr(Problem::kBlockPerCu != -1)
             return Problem::kBlockPerCu;
         else
         {
-            // minimize occupancy
-            if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
-            {
-                return 1;
-            }
-
             if constexpr(kQKHeaddim <= 32)
             {
-                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
-                             FmhaMask::IsMasking)
-                    return 1;
-                else
-                    return 2;
+                return 2;
             }
             else if constexpr(kQKHeaddim <= 64)
             {
-                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
-                    return 2;
-                else
-                    return 3;
+                return 2;
             }
             else if constexpr(kQKHeaddim <= 128)
             {
-                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
-                    return 2;
+                    return 1;
             }
             else if constexpr(kQKHeaddim <= 256)
             {
@@ -142,10 +126,10 @@ struct BlockFmhaPipelineQRKSVSAsync
               typename OAccElementFunction,
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
-    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
+    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
                const QElementFunction& q_element_func,
-               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
-               const KElementFunction& /*k_element_func*/,
+               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
+               const KElementFunction& k_element_func,
                const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
                const VElementFunction& v_element_func,
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
@@ -170,50 +154,28 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>(); constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
// K tile in LDS // K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr); KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
auto k_lds_store = generate_tuple( static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
[&](auto i_buf) { auto k_lds = make_tensor_view<address_space_enum::lds>(
return make_tile_window( k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
make_tensor_view<address_space_enum::lds>( auto k_lds_window =
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)), make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kSubQKHeaddim>{}), {0, 0});
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto k_lds_load = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0});
},
number<Policy::NumPrefetchK>{});
#else
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
#endif
// V tile in LDS // V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>( auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr), reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeK<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem>()); Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window( auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0}); v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
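Illustrative note on the LDS carve-up above: K now lives at byte offset GetSmemSizeQ() and V at byte offset GetSmemSizeK() inside the one shared-memory block, with the dropout random values (further down in the diff) at GetSmemSizeK() + GetSmemSizeV(). A hedged sketch of that pointer arithmetic with made-up sizes (the real values and any region reuse come from the policy's GetSmemSize* helpers):

    #include <cstdio>

    int main()
    {
        // Placeholder sizes; in the pipeline these come from Policy::GetSmemSizeQ/K/V<Problem>().
        constexpr unsigned smem_size_q = 0;
        constexpr unsigned smem_size_k = 128 * 128 * 2;    // e.g. kN0 * kSubQKHeaddim * sizeof(fp16)
        constexpr unsigned smem_size_v = 2 * 128 * 32 * 2; // e.g. NumVLdsBuffers * kN1 * kK1 * sizeof(fp16)

        alignas(16) static char smem[smem_size_k + smem_size_v + 4096];

        char* k_lds    = smem + smem_size_q;               // K tile region
        char* v_lds    = smem + smem_size_k;               // V ring buffers
        char* randvals = smem + smem_size_k + smem_size_v; // dropout random values
        std::printf("K at +%td, V at +%td, randval at +%td bytes\n",
                    k_lds - smem, v_lds - smem, randvals - smem);
    }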
@@ -222,21 +184,13 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window( auto q_dram_window =
q_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeQDramTileDistribution<Problem>());
q_dram_window.init_raw();
auto q = load_tile(q_dram_window);
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{}; auto s_acc = SaccBlockTileType{};
@@ -262,7 +216,6 @@ struct BlockFmhaPipelineQRKSVSAsync
set_tile(m, -numeric<SMPLComputeDataType>::infinity()); set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l); clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin(); const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] = const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
@@ -283,13 +236,11 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
} }
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it // Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return o_acc; return o_acc;
} }
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
} }
auto k_dram_block_window = auto k_dram_block_window =
@@ -303,16 +254,7 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window.get_window_origin(), k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load // load
k_dram_window.init_raw(); auto k_tile = load_tile(k_dram_window);
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK &&
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = auto bias_dram_window =
@@ -330,81 +272,58 @@ struct BlockFmhaPipelineQRKSVSAsync
{0, seqlen_k_start}, // TODO: hdim split? {0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>()); Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile auto q_tile = tile_elementwise_in(q_element_func, q);
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
// prefetch K tile
index_t i_total_loops = 0; index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops); static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops); static_assert(1 <= k1_loops);
// main loop
do do
{ {
// STAGE 1, QK gemm // STAGE 1, QK gemm
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
store_tile(k_lds_window, k_tile);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
if(i_total_loops < num_total_loop - 1)
{ {
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { move_tile_window(k_dram_window, {kN0, 0});
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}), k_tile = load_tile(k_dram_window);
k_dram_window, }
number<-1>{},
k_oob_ck, __builtin_amdgcn_sched_barrier(0);
k_pre_np);
if constexpr(i_k0 < k0_loops - 1) // for kQKHeaddim == 96 (kSubQKHeaddim == 128), we need to use k0_loops
move_tile_window(k_dram_window, {0, kK0}); if constexpr(kQKHeaddim == kSubQKHeaddim)
{
async_load_fence(k_dram_window.get_num_of_access()); gemm_0(s_acc, q, k_lds_window);
__builtin_amdgcn_s_barrier(); }
__builtin_amdgcn_sched_barrier(0); else
{
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0(s_acc, gemm_0(s_acc,
get_slice_tile( get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}), q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM get_slice_tile(k_lds_window,
k_lds_load[number<LdsSeq.at(number<i_k0>{})>{}]); sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
#else
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
#endif
}); });
} }
// TODO: this to fix a bug when loop smaller than 2, __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
// the following fence/barrier will be scheduled inside 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<k0_loops - 1>{})>{}]);
#else auto v_buf = load_tile(v_dram_window); // prefetch load v tile
get_slice_tile(
k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
#endif
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -457,7 +376,6 @@ struct BlockFmhaPipelineQRKSVSAsync
k_origin.at(number<0>{}), k_origin.at(number<0>{}),
number<kM0>{}, number<kM0>{},
number<kN0>{}); number<kN0>{});
if(need_perpixel_check) if(need_perpixel_check)
{ {
set_tile_if( set_tile_if(
@@ -484,7 +402,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>( auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j} s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F); __builtin_amdgcn_sched_barrier(0);
// store & prefetch next v, after the max reduction // store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
@@ -493,9 +411,7 @@ struct BlockFmhaPipelineQRKSVSAsync
shuffle_tile(v_shuffle_tmp, v_buf); shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = auto v_lds_window_tmp =
get_slice_tile(v_lds_window, get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile( store_tile(
v_lds_window_tmp, v_lds_window_tmp,
@@ -504,26 +420,25 @@ struct BlockFmhaPipelineQRKSVSAsync
else else
{ {
auto v_lds_window_tmp = auto v_lds_window_tmp =
get_slice_tile(v_lds_window, get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
} }
move_tile_window(v_dram_window, {0, kK1});
if constexpr(k1_loops > 1) __builtin_amdgcn_sched_barrier(0);
if constexpr(NumVLdsBuffers > 1)
{ {
move_tile_window( v_buf = load_tile(v_dram_window); // load next v_buf
v_dram_window, move_tile_window(v_dram_window, {0, kK1});
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) { static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need /// NOTICE: bias might be materialized mask including -inf values, need
/// consideration. alibi does not have this problem /// consideration
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking) FmhaMask::IsMasking)
{ {
@@ -583,7 +498,7 @@ struct BlockFmhaPipelineQRKSVSAsync
} }
}(); }();
#else #else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif #endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
@@ -597,97 +512,120 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
auto randval_ptr = auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>(); Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>( dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr, smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
} }
const auto p = [&]() { const auto p =
if constexpr(std::is_same_v<PDataType, fp16_t>) cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
return impl::cast_tile_pk_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
}();
// STAGE 3, KV gemm // STAGE 3, KV gemm
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
{ {
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(NumVLdsBuffers == 1)
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) {
{ static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
v_buf = load_tile( v_buf = load_tile(v_dram_window); // load next v_buf
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf block_sync_lds();
} gemm_1(
block_sync_lds(); o_acc,
gemm_1(o_acc, get_slice_tile(
get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}), get_slice_tile(v_lds_window,
get_slice_tile( sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
v_lds_window, sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{})); if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) {
{ auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>( Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
Policy::template MakeShuffledVRegBlockDescriptor<Problem>()); shuffle_tile(v_shuffle_tmp, v_buf);
shuffle_tile(v_shuffle_tmp, v_buf); auto v_lds_window_tmp = get_slice_tile(
auto v_lds_window_tmp = get_slice_tile( v_lds_window,
v_lds_window, sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{}, sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{}); block_sync_lds();
store_tile(v_lds_window_tmp, store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch v_shuffle_tmp)); // store the prefetch
} }
else else
{ {
auto v_lds_window_tmp = get_slice_tile( auto v_lds_window_tmp = get_slice_tile(
v_lds_window, v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{}, sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{}); sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, block_sync_lds();
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf store_tile(
} v_lds_window_tmp,
if constexpr(i_k1 < k1_loops - 1) tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
move_tile_window(v_dram_window, {0, kK1}); move_tile_window(v_dram_window, {0, kK1});
}); });
} }
i_total_loops++; else
if(i_total_loops < num_total_loop) {
{ static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
// move K tile windows if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
move_tile_window(k_dram_block_window, {kN0, 0}); v_buf = load_tile(v_dram_window); // load next v_buf
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
block_sync_lds();
if constexpr(k1_loops >= 2 && gemm_1(
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{})) o_acc,
__builtin_amdgcn_s_barrier(); get_slice_tile(
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
k_dram_window, get_slice_tile(v_lds_window,
number<-1>{}, sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
k_oob_ck, sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
k_pre_np);
move_tile_window(k_dram_window, {0, kK0}); if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
} }
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
gemm_1( gemm_1(
o_acc, o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}), get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile( get_slice_tile(v_lds_window,
v_lds_window, sequence<((k1_loops - 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{}, sequence<(((k1_loops - 1) % NumVLdsBuffers) + 1) * kN1, kK1>{}));
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{})); block_sync_lds();
} }
} while(i_total_loops < num_total_loop); } while(++i_total_loops < num_total_loop);
// store lse // store lse
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
@@ -701,11 +639,11 @@ struct BlockFmhaPipelineQRKSVSAsync
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                              BiasEnum == BlockAttentionBiasEnum::ALIBI)
                 {
-                    lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
+                    lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                 }
                 else
                 {
-                    lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
+                    lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
                 }
 #else
                 lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
...
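The two LSE expressions are numerically identical: with the fast-exp2 path the running max m is kept in log2 space, so lse = m / log2(e) + log(l). The removed constant R_LOG2E was defined as 1 / log2(e), so multiplying by it equals dividing by C_LOG2E, assuming C_LOG2E is log2(e). A quick check:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const double C_LOG2E = std::log2(std::exp(1.0)); // log2(e); assumed definition of the constant
        const double R_LOG2E = 1.0 / C_LOG2E;            // the constant this commit removes
        const double m = 3.7, l = 42.0, scale_s = 0.125;

        const double lse_old = m * scale_s * R_LOG2E + std::log(l); // old expression
        const double lse_new = m * scale_s / C_LOG2E + std::log(l); // new expression
        std::printf("old=%.17g new=%.17g\n", lse_old, lse_new);     // prints the same value
    }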
@@ -8,12 +8,80 @@
 namespace ck_tile {
-// This pipeline is qkv all located in LDS
-using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                        /* AsyncCopyK = */ true,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 3,
-                                        /* NumPrefetchV = */ 3>;
+struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
+    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
+                                          /* AsyncCopy = */ true,
+                                          /* NumPrefetchV = */ 2>
+{
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim ==
Problem::BlockFmhaShape::kSubQKHeaddim)
? Problem::BlockFmhaShape::kSubQKHeaddim
: Problem::BlockFmhaShape::kK0;
using GemmProblem = BlockGemmProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0, BlockGemmK>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
swizzle_factor>{};
} // TODO - bf8_t
}();
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile
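The BlockGemmK selection at the top of GetQKBlockGemm() is the compile-time hinge of this policy: when K sits in LDS in full and the head dim is not split (kQKHeaddim == kSubQKHeaddim), the QK block gemm runs over the whole kSubQKHeaddim in one shot, otherwise it falls back to kK0 slices (e.g. hdim 96, which is padded up to kSubQKHeaddim 128). A hedged restatement of just that selection with example numbers:

    #include <cstdio>

    // Mirrors the BlockGemmK ternary in GetQKBlockGemm(); names and values are illustrative.
    constexpr int pick_block_gemm_k(bool k_load_once, int qk_head_dim, int sub_qk_head_dim, int k0)
    {
        return (k_load_once && qk_head_dim == sub_qk_head_dim) ? sub_qk_head_dim : k0;
    }

    int main()
    {
        std::printf("hdim 128 -> gemm K = %d\n", pick_block_gemm_k(true, 128, 128, 32)); // whole head dim
        std::printf("hdim  96 -> gemm K = %d\n", pick_block_gemm_k(true, 96, 128, 32));  // falls back to kK0
    }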
@@ -8,12 +8,9 @@
 namespace ck_tile {
-// This pipeline is qkv all located in LDS
 using BlockFmhaPipelineQRKSVSDefaultPolicy =
     BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                        /* AsyncCopyK = */ false,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 1,
+                                        /* AsyncCopy = */ false,
                                         /* NumPrefetchV = */ 1>;
 } // namespace ck_tile
@@ -34,6 +34,9 @@ struct BlockFmhaPipelineQSKSVS
     static constexpr bool kQLoadOnce = false;
     static_assert(kQLoadOnce == Policy::QLoadOnce);
+    static constexpr bool kKLoadOnce = false;
+    static_assert(kKLoadOnce == Policy::KLoadOnce);
+
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
@@ -94,6 +97,8 @@ struct BlockFmhaPipelineQSKSVS
             {
                 return 1;
             }
+            else
+                return 1;
         }
     }();
...
@@ -11,9 +11,7 @@ namespace ck_tile {
 // This pipeline is qkv all located in LDS
 struct BlockFmhaPipelineQSKSVSDefaultPolicy
     : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
-                                          /* AsyncCopyK = */ false,
-                                          /* AsyncCopyV = */ false,
-                                          /* NumPrefetchK = */ 1,
+                                          /* AsyncCopy = */ false,
                                           /* NumPrefetchV = */ 1>
 {
     template <typename Problem>
...
@@ -17,9 +17,6 @@
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
-// TODO: remove this
-#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
-
 namespace ck_tile {

 template <bool QLoadOnce_>
@@ -50,9 +47,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
     }
-    template <typename Problem, typename BlockGemm>
+    template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
     {
+        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
+
         return BlockGemm::template MakeABlockTileDistribution<
             Problem::BlockFmhaShape::kM0,
             Problem::BlockFmhaShape::kSubQKHeaddim>();
@@ -277,72 +276,32 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
     }
 };
-// This pipeline is qkv all located in LDS
-template <bool QLoadOnce_,
-          bool AsyncCopyK_,
-          bool AsyncCopyV_,
-          index_t NumPrefetchK_,
-          index_t NumPrefetchV_>
+template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchV_>
 struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLoadOnce_>
 {
-    static constexpr bool AsyncCopyK = AsyncCopyK_;
-    static constexpr bool AsyncCopyV = AsyncCopyV_; // TODO: this not supported yet
-
-    static constexpr index_t NumPrefetchK = NumPrefetchK_;
-    static constexpr index_t NumPrefetchV = NumPrefetchK_;
+    static constexpr index_t NumPrefetchV = NumPrefetchV_;
+
+    // 1) When Async == true, we preload whole K-tile for next iteration using single LDS buffer,
+    // and preload V-slice for next unroll using multiple LDS buffers
+    // 2) When Async == false, we preload K-slice for next unroll using single LDS buffer, and
+    // preload V-slice for next unroll using single LDS buffer
+    static constexpr bool AsyncCopy = AsyncCopy_;
+    static constexpr bool KLoadOnce = AsyncCopy;

     using QXPolicy = BlockFmhaPipelineQXCustomPolicy<QLoadOnce_>;

-    template <index_t k_prefetches_, index_t v_prefetches_, index_t k_loops_, index_t v_loops_>
-    struct LdsBufferSequence
-    {
-        static constexpr auto Make()
-        {
-            return transform_sequences(
-                [&](auto i) {
-                    if(i < k_loops_)
-                        return i % k_prefetches_;
-                    return (i - k_loops_) % v_prefetches_;
-                },
-                typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{});
-        };
-
-        using type = remove_cvref_t<decltype(Make())>;
-    };
-
-    // clang-format off
-    template<> struct
-    LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; };
-    template<> struct
-    LdsBufferSequence<3, 3, 4, 2> { using type = sequence<1, 2, 0, 1, 2, 0>; };
-    template<> struct
-    LdsBufferSequence<3, 3, 2, 4> { using type = sequence<1, 2, 0, 1, 2, 0>; };
-    template<> struct
-    LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
-    template<> struct
-    LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
-    template<> struct
-    LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
-    // clang-format on
-
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetLdsBufferSequence()
+    CK_TILE_DEVICE static constexpr auto GetNumVLdsBuffers()
     {
         using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;

         constexpr index_t kN0 = BlockFmhaShape::kN0;
-        constexpr index_t kK0 = BlockFmhaShape::kK0;
-        constexpr index_t kK1 = BlockFmhaShape::kK1;
-        constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
-
-        constexpr index_t k0_loops = kQKHeaddim / kK0;
+        constexpr index_t kK1 = BlockFmhaShape::kK1;
         constexpr index_t k1_loops = kN0 / kK1;
-        return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
+        return min(NumPrefetchV, k1_loops);
     }
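So with this policy's NumPrefetchV = 2 for the async pipeline, a tile shape such as kN0 = 128 and kK1 = 32 gives k1_loops = 4 and therefore 2 V LDS buffers, while a shape with kN0 == kK1 collapses to a single buffer. A one-line restatement with example values:

    #include <algorithm>
    #include <cstdio>

    // min(NumPrefetchV, k1_loops), as in GetNumVLdsBuffers(); values below are examples.
    constexpr int num_v_lds_buffers(int num_prefetch_v, int kN0, int kK1)
    {
        return std::min(num_prefetch_v, kN0 / kK1);
    }

    int main()
    {
        std::printf("%d\n", num_v_lds_buffers(2, 128, 32)); // 2 buffers
        std::printf("%d\n", num_v_lds_buffers(2, 32, 32));  // 1 buffer
    }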
    template <typename Problem>
@@ -356,15 +315,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
     {
-        using KDataType = remove_cvref_t<typename Problem::KDataType>;
-        if constexpr(AsyncCopyK)
-        {
-            return 4 / sizeof(KDataType);
-        }
-        else
-        {
-            return 16 / sizeof(KDataType);
-        }
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
+        constexpr index_t kKPerBlock =
+            KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
+
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
+
+        constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+
+        return min(ElemPerThread, MaxVectorSize);
     }
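With KLoadOnce the per-thread element count is now taken over the whole kN0 x kSubQKHeaddim tile rather than a kK0 slice; e.g. kN0 = 128, kSubQKHeaddim = 128, a 256-thread block and fp16 data give ElemPerThread = 64 and MaxVectorSize = 8, so the load vector length stays 8. A hedged restatement of that arithmetic with example sizes:

    #include <algorithm>
    #include <cstdio>

    // Same arithmetic as the new GetAlignmentK(); parameter values below are examples.
    constexpr int alignment_k(int kNPerBlock, int kKPerBlock, int block_size, int elem_bytes)
    {
        const int max_vector_size = 16 / elem_bytes;                        // 16-byte loads
        const int elem_per_thread = (kNPerBlock * kKPerBlock) / block_size; // per-thread workload
        return std::min(elem_per_thread, max_vector_size);
    }

    int main()
    {
        std::printf("KLoadOnce : %d\n", alignment_k(128, 128, 256, 2)); // kKPerBlock = kSubQKHeaddim -> 8
        std::printf("kK0 slice : %d\n", alignment_k(128, 32, 256, 2));  // kKPerBlock = kK0 -> 8
    }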
    template <typename Problem>
@@ -382,17 +342,17 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         using VDataType = remove_cvref_t<typename Problem::VDataType>;
         if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
         {
             constexpr index_t kBlockSize = Problem::kBlockSize;
             constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
             constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
-            constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
+            constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;

             constexpr index_t kMaxVecLoad =
-                min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
+                min(ElemPerThread, static_cast<index_t>(16 / sizeof(VDataType)));
             constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);

-            constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+            constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
                                              ? kMaxVecLoad
-                                             : (total_pixels / kMinVecLoad);
+                                             : (ElemPerThread / kMinVecLoad);

             return kVecLoad;
         }
@@ -422,61 +382,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return WG::WarpGemmAttribute::Impl::kCM1PerLane; return WG::WarpGemmAttribute::Impl::kCM1PerLane;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
// this function assume K/V can share smem
constexpr index_t SingleKSize = [&]() {
if constexpr(!AsyncCopyK)
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
}
else
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack;
static_assert(warpSize * KVector >= kKPerBlock &&
warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
constexpr index_t LaneGroups = warpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
return NumIssues * NumWarps * (warpSize * KVector + kPad);
}
}();
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return max(SingleKSize, SingleVSize);
}
// TODO: this is used for non async copy desc. unify in the future
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
     {
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
-        constexpr index_t kKPack     = GetSmemKPackK<Problem>();
+        constexpr index_t kKPerBlock =
+            KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
+        constexpr index_t kKPack = GetSmemKPackK<Problem>();

         constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
@@ -495,164 +407,26 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         return k_lds_block_desc;
     }
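Because kKPerBlock now spans the whole kSubQKHeaddim when KLoadOnce is set, the K tile kept in LDS grows from kN0 x kK0 to kN0 x kSubQKHeaddim elements (the extra padding the descriptor adds per kKPack row is ignored here). A rough footprint comparison under example sizes:

    #include <cstdio>

    int main()
    {
        // Example sizes; descriptor padding is not counted.
        constexpr int kN0 = 128, kK0 = 32, kSubQKHeaddim = 128;
        constexpr int elem_bytes = 2; // fp16

        constexpr int k_slice_bytes      = kN0 * kK0 * elem_bytes;            // 8 KiB per kK0 slice
        constexpr int k_whole_tile_bytes = kN0 * kSubQKHeaddim * elem_bytes;  // 32 KiB for the full tile

        std::printf("per-slice K LDS: %d bytes, whole-tile K LDS: %d bytes\n",
                    k_slice_bytes, k_whole_tile_bytes);
    }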
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
warpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<warpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
}
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
number<warpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
#else
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
{ {
// K is always k-major; we use async-copy to load it into LDS constexpr index_t SingleVSize = [&]() {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; constexpr index_t Banks = 32; // TODO: may need to change based on arch
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t warpSize = ck_tile::get_warp_size(); static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
// constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
// constexpr index_t SingleVSize =
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t BufferSize =
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumPrefetchK>{}, // num_buffers
number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<BufferSize>{},
number<NumWarps*(warpSize * KVector + kPad)>{},
number<warpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor( return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
k_lds_block_desc_0, }();
make_tuple(
make_merge_transform(make_tuple(number<NumPrefetchK>{},
number<NumIssues>{},
number<LaneGroups>{},
number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc; return SingleVSize;
} }
#endif
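// Editorial note (illustrative, not part of the kernel): both MakeKLdsLoadBlockDescriptor
// variants read the padded layout written by MakeKLdsStoreBlockDescriptor but address the
// prefetch buffers differently. The offset-transform variant (K_LDS_LOAD_USE_OFFSET_TRANSFORM)
// builds one descriptor per buffer and bakes IBuf * GetSingleSmemElementSpaceSize() into the
// descriptor offset; the default variant builds a single descriptor covering all NumPrefetchK
// buffers and merges the buffer index into N, giving a (NumPrefetchK * kNPerBlock, kKPerBlock)
// 2-D view. With the assumed sizes from the sketch above and NumPrefetchK = 2 that would be a
// (256, 32) element view over the padded multi-buffer LDS region.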
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
...@@ -669,13 +443,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -669,13 +443,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
static_assert(kNPerBlock % NPerRow == 0); static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0); static_assert(kKPerBlock % kKPack == 0);
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumPrefetchV>{}, make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKPack>{}, number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{}, number<kNPerBlock / NPerRow>{},
number<NPerRow>{}, number<NPerRow>{},
number<kKPack>{}), number<kKPack>{}),
make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{}, make_tuple(number<GetVSingleSmemElementSpaceSize<Problem>()>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{}, number<PixelsPerRow + kKPack>{},
number<kKPack>{}, number<kKPack>{},
...@@ -687,7 +463,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -687,7 +463,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
v_lds_block_desc_0, v_lds_block_desc_0,
make_tuple( make_tuple(
make_merge_transform(make_tuple( make_merge_transform(make_tuple(
number<NumPrefetchV>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})), number<NumVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))), make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -696,28 +472,26 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -696,28 +472,26 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
} }
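// Editorial sketch (illustrative, not part of the kernel): the V LDS descriptor above appears
// to pad each PixelsPerRow-wide row by kKPack elements to avoid bank conflicts. Assuming fp16
// V data, 32 four-byte banks, kKPack = 8, kNPerBlock = 128 and kKPerBlock = 32:
namespace v_lds_size_example {
constexpr int Banks        = 32;                    // assumed number of LDS banks
constexpr int PixelsPerRow = Banks * 4 / 2;         // = 64 fp16 pixels per bank row
constexpr int kKPack       = 8;                     // assumed GetSmemKPackV()
constexpr int NPerRow      = PixelsPerRow / kKPack; // = 8
constexpr int kNPerBlock   = 128;                   // assumed kN1
constexpr int kKPerBlock   = 32;                    // assumed kK1
// single-buffer element count, matching the return expression of GetVSingleSmemElementSpaceSize
constexpr int SingleVSize =
    (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
static_assert(SingleVSize == 4608);
} // namespace v_lds_size_example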
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{ {
// TODO: assume Q is in register return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
// TODO: assume K/V has same data type sizeof(typename Problem::KDataType);
constexpr index_t single_smem_size = }
GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
return QXPolicy::template GetSmemSizeQ<Problem>() + template <typename Problem>
single_smem_size * max(NumPrefetchK, NumPrefetchV); CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
if constexpr(AsyncCopyK) // assume Q can reuse the shared memory with K or V
{ return max(QXPolicy::template GetSmemSizeQ<Problem>(),
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0); GetSmemSizeK<Problem>() + GetSmemSizeV<Problem>()) +
} GetSmemSizeDropout<Problem>(0);
else
{
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
}
} }
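// Editorial note (illustrative, not part of the kernel): the new GetSmemSize() takes
// max(SmemQ, SmemK + SmemV) rather than a sum because, per the comment above, the Q staging
// LDS is assumed reusable once Q has been preloaded into registers, so it can alias the K/V
// LDS. For example, with assumed SmemQ = 8 KiB, SmemK = 16 KiB, SmemV = 9 KiB and no dropout
// scratch, the total would be max(8, 16 + 9) + 0 = 25 KiB rather than 8 + 25 = 33 KiB.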
// this method is only available when Problem::kHasDropout is present // this method is only available when Problem::kHasDropout is present
...@@ -754,58 +528,33 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -754,58 +528,33 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{ {
if constexpr(!AsyncCopyK) using KDataType = remove_cvref_t<typename Problem::KDataType>;
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = 16 / sizeof(KDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution( constexpr index_t kBlockSize = Problem::kBlockSize;
tile_distribution_encoding<sequence<1>, constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, constexpr index_t kKPerBlock =
tuple<sequence<1>, sequence<1, 2>>, KLoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave static_assert(0 < ElemPerThread);
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues; constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t N1 = LaneGroups; constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t N2 = NumWarps; constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t K0 = LanesPerK; constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t K1 = KVector; constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
tuple<sequence<1>, sequence<1, 2>>, sequence<KThreads, KPerThread>>,
tuple<sequence<2>, sequence<1, 0>>, tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2>, tuple<sequence<1>, sequence<2, 0>>,
sequence<0, 1>>{}); sequence<1, 2>,
} sequence<0, 1>>{});
} }
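// Editorial sketch (illustrative, not part of the kernel): a worked example of the new K DRAM
// distribution above, assuming fp16 K, kBlockSize = 256, kN0 = 128 and, with kKLoadOnce
// enabled, kKPerBlock = kSubQKHeaddim = 64.
namespace k_dram_dist_example {
constexpr int kBlockSize     = 256;                                    // assumed
constexpr int kNPerBlock     = 128;                                    // assumed kN0
constexpr int kKPerBlock     = 64;                                     // assumed kSubQKHeaddim
constexpr int warpSize       = 64;                                     // gfx9 wave size
constexpr int MaxVectorSize  = 16 / 2;                                 // = 8 for fp16
constexpr int ElemPerThread  = (kNPerBlock * kKPerBlock) / kBlockSize; // = 32
constexpr int KPerThread     = ElemPerThread < MaxVectorSize ? ElemPerThread : MaxVectorSize; // = 8
constexpr int KThreads       = kKPerBlock / KPerThread;                // = 8
constexpr int NThreadPerWarp = warpSize / KThreads;                    // = 8
constexpr int NumWarps       = kBlockSize / warpSize;                  // = 4
constexpr int NPerThread     = kNPerBlock / (NThreadPerWarp * NumWarps); // = 4
static_assert(NPerThread * KPerThread == ElemPerThread); // each thread loads a 4 x 8 sub-tile
} // namespace k_dram_dist_example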
template <typename Problem> template <typename Problem>
...@@ -822,9 +571,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -822,9 +571,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N1 = GetAlignmentV<Problem>(); constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this may not always hold static_assert(ElemPerThread % N1 == 0); // TODO: this may not always hold
constexpr index_t K3 = total_pixels / N1; constexpr index_t K3 = ElemPerThread / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>(); constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0); static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
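// Editorial note (illustrative, not part of the kernel): with assumed values kNPerBlock = 128,
// kKPerBlock = 32, kBlockSize = 256, alignment N1 = 8 and kKPack = 8, the split above gives
// N0 = 128 / 8 = 16, ElemPerThread = 128 * 32 / 256 = 16, K3 = 16 / 8 = 2 and K2 = 8 / 2 = 4.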
...@@ -893,11 +642,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -893,11 +642,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>(); constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this may not always hold static_assert(ElemPerThread % N1 == 0); // TODO: this may not always hold
constexpr index_t K3 = total_pixels / N1; constexpr index_t K3 = ElemPerThread / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>(); constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0); static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
......