Commit b75c9265 authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Load K once in 2wave pipeline

parent ee44cf04
...@@ -35,6 +35,8 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -35,6 +35,8 @@ struct BlockFmhaPipelineQRKSVS2Wave
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = true;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kM0 = BlockFmhaShape::kM0;
...@@ -149,22 +151,23 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -149,22 +151,23 @@ struct BlockFmhaPipelineQRKSVS2Wave
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>, std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK0BlockLength == QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK0BlockLength == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
"wrong!"); kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// K tile in LDS // K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>( auto k_lds = make_tensor_view<address_space_enum::lds>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>())); reinterpret_cast<KDataType*>(smem_ptr),
auto k_lds = make_tensor_view<address_space_enum::lds>( Policy::template MakeKLdsStoreBlockDescriptor<Problem>());
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>()); auto k_lds_window_for_store =
auto k_lds_window = make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0BlockLength>{}), {0, 0});
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
// V tile in LDS // V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>( auto v_lds = make_tensor_view<address_space_enum::lds>(
...@@ -264,7 +267,7 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -264,7 +267,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k1_loops = kN0 / kK1; constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops); static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops); static_assert(1 <= k1_loops);
do do
{ {
...@@ -278,11 +281,12 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -278,11 +281,12 @@ struct BlockFmhaPipelineQRKSVS2Wave
auto k_block_tile = load_tile(k_dram_window); auto k_block_tile = load_tile(k_dram_window);
{ {
move_tile_window(k_dram_window, {0, kK0}); clear_tile(s_acc);
clear_tile(s_acc); // initialize C store_tile(k_lds_window_for_store,
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
} }
auto k_lds_window_for_load =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -296,44 +300,19 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -296,44 +300,19 @@ struct BlockFmhaPipelineQRKSVS2Wave
0); // prevent from messing up the order of global loads 0); // prevent from messing up the order of global loads
} }
if constexpr(k0_loops > 2)
{ {
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds();
block_sync_lds(); static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0(s_acc, gemm_0(s_acc,
get_slice_tile(q_tile, get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{}, sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}), sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window); k_lds_window_for_load);
block_sync_lds(); move_tile_window(k_lds_window_for_load, {0, kK0});
move_tile_window(k_dram_window, {0, kK0});
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
}); });
} }
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
}
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
......
...@@ -9,11 +9,105 @@ ...@@ -9,11 +9,105 @@
namespace ck_tile { namespace ck_tile {
// This pipeline is qkv all located in LDS // This pipeline is qkv all located in LDS
using BlockFmhaPipelineQRKSVS2WaveDefaultPolicy = struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true, : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false, /* AsyncCopyK = */ false,
/* AsyncCopyV = */ false, /* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1, /* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>; /* NumPrefetchV = */ 1>
{
// Builds the per-thread tile distribution used to load the K block from
// global memory (DRAM) into registers. The tile covers kN0 rows by
// kK0BlockLength columns — i.e. the K dimension spans the whole block
// length, so one load fetches the entire K extent at once (matching the
// "load K once" pipeline this policy serves).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
    using KDataType = remove_cvref_t<typename Problem::KDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;

    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;

    // K1: elements per 16-byte vectorized access; K0: number of such
    // vectors needed to cover the K extent.
    constexpr index_t K1 = 16 / sizeof(KDataType);
    constexpr index_t K0 = kKPerBlock / K1;
    // N2: N-rows covered by the lanes of one warp left over after K0;
    // N1: warps per block; N0: per-thread repeats along N so that
    // N0 * N1 * N2 == kNPerBlock.
    constexpr index_t N2 = get_warp_size() / K0;
    constexpr index_t N1 = kBlockSize / get_warp_size();
    constexpr index_t N0 = kNPerBlock / (N2 * N1);

    // Encoding: dim 1 = (N0, N1, N2), dim 2 = (K0, K1); warps map to N1,
    // lanes map to (N2, K0); each thread owns an (N0, K1) sub-tile.
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>{});
}
// TODO: this is used for non async copy desc. unify in the future
// Descriptor for storing the full K block (kN0 x kK0BlockLength) into LDS.
// Physical layout is 3-D: [kKPerBlock / kKPack, kNPerBlock, kKPack], with a
// per-slab stride of (kNPerBlock + 1) * kKPack — the "+ 1" leaves a
// kKPack-element gap between consecutive K-slabs, presumably padding to
// avoid LDS bank conflicts (verify against the matching load-side layout).
// The returned descriptor exposes this as a logical 2-D [N, K] view.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
{
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
    // Whole block length along K, so the entire K tile is LDS-resident at once.
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
    constexpr index_t kKPack     = GetSmemKPackK<Problem>();

    // Naive 3-D layout: (K-slab, N, packed-K); innermost kKPack elements are
    // contiguous, 8-element guaranteed alignment on the base.
    constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
        make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
        make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
        number<8>{},
        number<1>{});

    // Merge (K-slab, packed-K) into one K dimension -> logical [N, K] view.
    constexpr auto k_lds_block_desc = transform_tensor_descriptor(
        k_lds_block_desc_0,
        make_tuple(
            make_pass_through_transform(number<kNPerBlock>{}),
            make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
        make_tuple(sequence<1>{}, sequence<0, 2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return k_lds_block_desc;
}
// Element count of one K/V staging buffer in LDS. K and V are assumed to be
// able to share the same piece of smem, so the footprint is the larger of
// the two single-tile sizes.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
    using VDataType = remove_cvref_t<typename Problem::VDataType>;

    // One K tile: exactly the element space of the LDS store descriptor
    // (padding included).
    constexpr index_t SingleKSize =
        MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size();

    // One V tile: rows of PixelsPerRow elements plus a kKPack-wide pad per row.
    constexpr index_t Banks        = 32; // TODO: need change based on arch
    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
    constexpr index_t kKPack       = GetSmemKPackK<Problem>();
    static_assert(PixelsPerRow % kKPack == 0);
    constexpr index_t NPerRow    = PixelsPerRow / kKPack;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
    static_assert(kNPerBlock % NPerRow == 0);
    static_assert(kKPerBlock % kKPack == 0);
    constexpr index_t SingleVSize =
        (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);

    return max(SingleKSize, SingleVSize);
}
// Bytes of LDS needed for the Q + K/V staging buffers.
// TODO: assume Q is in register
// TODO: assume K/V has same data type
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
    // Size in bytes of one shared K/V staging buffer.
    constexpr index_t kv_buffer_bytes =
        sizeof(typename Problem::KDataType) * GetSingleSmemElementSpaceSize<Problem>();
    // Number of staging buffers needed to cover the deeper of the two prefetch pipelines.
    constexpr index_t num_kv_buffers = max(NumPrefetchK, NumPrefetchV);

    return QXPolicy::template GetSmemSizeQ<Problem>() + kv_buffer_bytes * num_kv_buffers;
}
// Total LDS requirement of the pipeline: whichever is larger, the Q+K/V
// working set or the scratch needed by the dropout path.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const auto kv_bytes      = GetSmemSizeKV<Problem>();
    const auto dropout_bytes = GetSmemSizeDropout<Problem>(0);
    return ck_tile::max(kv_bytes, dropout_bytes);
}
};
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment