Commit e6b69a63 authored by Po Yen, Chen
Browse files

Load V once in 2wave pipeline (only support vlayout=c)

parent 9a771b0b
...@@ -36,6 +36,7 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -36,6 +36,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
static_assert(kQLoadOnce == Policy::QLoadOnce); static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = true; static constexpr bool kKLoadOnce = true;
static constexpr bool kVLoadOnce = true;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
...@@ -157,7 +158,7 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -157,7 +158,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0BlockLength == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK0BlockLength == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
...@@ -172,9 +173,9 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -172,9 +173,9 @@ struct BlockFmhaPipelineQRKSVS2Wave
// V tile in LDS // V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>( auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr), reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>()); Policy::template MakeVLdsStoreBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window( auto v_lds_window =
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0}); make_tile_window(v_lds, make_tuple(number<kN1>{}, number<kN0>{}), {0, 0});
// Block GEMM // Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>(); constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
...@@ -312,8 +313,6 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -312,8 +313,6 @@ struct BlockFmhaPipelineQRKSVS2Wave
}); });
} }
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -471,64 +470,42 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -471,64 +470,42 @@ struct BlockFmhaPipelineQRKSVS2Wave
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
} }
block_sync_lds(); const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
const auto v = load_tile(v_dram_window); // load next v
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>( auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>()); Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch); shuffle_tile(v_shuffle_tmp, v);
store_tile( store_tile(v_lds_window,
v_lds_window, tile_elementwise_in(v_element_func,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch v_shuffle_tmp)); // store the prefetch
} }
else else
{ {
store_tile(v_lds_window, store_tile(v_lds_window, tile_elementwise_in(v_element_func, v)); // store next v
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
} }
move_tile_window(v_dram_window, {0, kK1}); move_tile_window(v_dram_window, {0, kN0});
const auto p = auto v_lds_window_for_load =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)); make_tile_window(v_lds, make_tuple(number<kN1>{}, number<kK1>{}), {0, 0});
// STAGE 3, KV gemm // STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{ {
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { block_sync_lds();
const auto v = load_tile(v_dram_window); // load next v static_for<0, k1_loops, 1>{}([&](auto i_k1) {
block_sync_lds();
gemm_1(o_acc, gemm_1(o_acc,
get_slice_tile( get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}), p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window); v_lds_window_for_load);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) move_tile_window(v_lds_window_for_load, {0, kK1});
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
move_tile_window(v_dram_window, {0, kK1});
}); });
} }
// move K tile windows // move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0}); move_tile_window(k_dram_block_window, {kN0, 0});
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_window);
block_sync_lds();
}
} while(++i_total_loops < num_total_loop); } while(++i_total_loops < num_total_loop);
// store lse // store lse
......
...@@ -40,6 +40,73 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -40,6 +40,73 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
/// Builds the per-thread tile distribution used to load a V tile directly
/// from DRAM. Note the K extent is taken from kN0 (the full seqlen-k block),
/// not kK1 -- consistent with loading V only once per K block.
/// NOTE(review): the commit message says only vlayout=c is supported, so the
/// RowMajor branch below may be unused by this pipeline -- confirm.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
    using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;

    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
    {
        // Row-major V: vectorize the load along N (N1 = per-thread
        // contiguous elements, from the alignment of V).
        constexpr index_t N1 = GetAlignmentV<Problem>();
        constexpr index_t N0 = kNPerBlock / N1; // P

        // Elements each thread ends up holding for the whole N x K tile.
        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
        static_assert(total_pixels % N1 == 0); // TODO: this is not always true?

        constexpr index_t K3     = total_pixels / N1;
        constexpr index_t kKPack = GetSmemKPackV<Problem>();
        static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
        if constexpr(get_warp_size() % (K2 * N0) == 0)
        {
            // K2 * N0 divides the warp size: the remaining lanes of a warp
            // replicate along K1, warps tile along K0.
            constexpr index_t K1 = get_warp_size() / (K2 * N0);
            constexpr index_t K0 = kBlockSize / get_warp_size();
            static_assert(kKPerBlock == K0 * K1 * K2 * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
        else
        {
            // K2 * N0 spans more than one warp: split K2 across K1 warps
            // (K2_m is the per-warp remainder of K2).
            constexpr index_t K1   = (K2 * N0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = kBlockSize / get_warp_size() / K1;
            static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
    }
    else
    {
        // Column-major V (the supported layout): vectorize along K
        // (K1 = alignment of V) and spread threads over N.
        constexpr index_t K1 = GetAlignmentV<Problem>();
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;         // lanes of a warp along N
        constexpr index_t N1 = kBlockSize / get_warp_size(); // warps per block
        constexpr index_t N0 = kNPerBlock / (N2 * N1);       // repeats per thread along N
        static_assert(N0 != 0);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
}
// TODO: this is used for non async copy desc. unify in the future // TODO: this is used for non async copy desc. unify in the future
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
...@@ -65,6 +132,47 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -65,6 +132,47 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc; return k_lds_block_desc;
} }
// 3d + padding
/// LDS block descriptor used when STORING the V tile to shared memory.
/// Each row of PixelsPerRow elements is given a stride of
/// (PixelsPerRow + kKPack), i.e. kKPack elements of padding per row --
/// presumably to stagger accesses across the 32 LDS banks; confirm on
/// target arch. The outermost NumPrefetchV dimension selects a prefetch
/// buffer, each GetSingleSmemElementSpaceSize<Problem>() elements apart.
/// NOTE(review): the K extent is kN0 (the full seqlen-k block), matching
/// the load-V-once scheme of this commit.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor()
{
    using VDataType = remove_cvref_t<typename Problem::VDataType>;

    constexpr index_t Banks        = 32; // TODO: need change based on arch
    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
    constexpr index_t kKPack       = GetSmemKPackV<Problem>();
    static_assert(PixelsPerRow % kKPack == 0);
    constexpr index_t NPerRow    = PixelsPerRow / kKPack;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
    static_assert(kNPerBlock % NPerRow == 0);
    static_assert(kKPerBlock % kKPack == 0);

    // 5-D layout: (prefetch buffer, K tiles, N rows, N within row, K pack).
    constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
        make_tuple(number<NumPrefetchV>{},
                   number<kKPerBlock / kKPack>{},
                   number<kNPerBlock / NPerRow>{},
                   number<NPerRow>{},
                   number<kKPack>{}),
        make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
                   number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
                   number<PixelsPerRow + kKPack>{}, // row stride incl. kKPack padding
                   number<kKPack>{},
                   number<1>{}),
        number<kKPack>{},
        number<1>{});

    // Collapse to 2-D (N, K): merge {prefetch, N rows, N within row} into N
    // and {K tiles, K pack} into K.
    constexpr auto v_lds_block_desc = transform_tensor_descriptor(
        v_lds_block_desc_0,
        make_tuple(
            make_merge_transform(make_tuple(
                number<NumPrefetchV>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
            make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
        make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return v_lds_block_desc;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{ {
...@@ -81,7 +189,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -81,7 +189,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
static_assert(PixelsPerRow % kKPack == 0); static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack; constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
static_assert(kNPerBlock % NPerRow == 0); static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0); static_assert(kKPerBlock % kKPack == 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment