Unverified Commit 2a30cfdd authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into codegen-enable-hiprtc

parents 9533a172 78195ccc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
static constexpr const char* name = "shb";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
static constexpr const char* name = "hbs";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1));
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
} // namespace ck_tile
......@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr index_t kNumWarps = Problem::kNumWarps;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kHeadDimV = Problem::kHeadDimV;
......@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const
{
// lse_acc tile in LDS
......@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto lse_acc_tile = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc_tile);
block_sync_lds();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
......@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
}
});
}
block_sync_lds();
if constexpr(kStoreLSE)
{
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_dram_window =
auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
auto o_acc_4_dram_window =
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
o_acc_dram_block_window_tmp.get_window_lengths(),
o_acc_dram_block_window_tmp.get_window_origin(),
o_acc_dist);
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
o_acc_4_dist);
// shape=[4 * KM0, kN1]
auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
clear_tile(o_acc_4);
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// each warp handles a [KM0, kN1] tile
for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
{
auto o_tile = load_tile(o_acc_dram_window);
auto o_tile = load_tile(o_acc_4_dram_window);
const index_t i_split = split_start + get_warp_id();
const index_t row_start = kM0 * get_warp_id();
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
o_acc_4.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
const LSEDataType lse_scale = lse_acc_lds(row, i_split);
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
});
});
}
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
}
// 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
{
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0});
}();
store_tile(o_acc_4_lds_window, o_acc_4);
}
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
}();
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, kNumWarps, 1>{}([&](auto) {
auto o_acc_in = load_tile(o_acc_4_lds_window);
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) += o_acc_in(i_j_idx);
});
});
}
move_tile_window(o_acc_4_lds_window, {kM0, 0});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
......@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window,
......@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{},
identity{},
num_splits,
seqlen_q,
smem_ptr);
}
};
......
......@@ -10,23 +10,38 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template <index_t BlockSize, index_t M, index_t N, typename DataType>
template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxNumWarpsForTile()
{
static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size());
if constexpr(0 < ElemPerThread)
{
return NumWarps;
}
else
{ // try dividing tile by smaller # of warps
return GetMaxNumWarpsForTile<NumWarps / 2, M, N, DataType>();
}
}
template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{
constexpr index_t PixelsPerThread = (M * N) / BlockSize;
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNumWarps = GetMaxNumWarpsForTile<NumWarps, M, N, DataType>();
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size());
return NPerThread;
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
return min(MaxNPerThread, ElemPerThread);
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{
return GetVectorSizeForTile<Problem::kBlockSize,
return GetVectorSizeForTile<Problem::kNumWarps,
Problem::kMaxSplits,
Problem::kM0,
typename Problem::LSEDataType>();
......@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeLSEacc()
{
return sizeof(typename Problem::LSEDataType) *
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
{
return sizeof(typename Problem::OaccDataType) *
MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
}
// shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t MaxNumWarps =
GetMaxNumWarpsForTile<Problem::kNumWarps, kNPerBlock, kMPerBlock, LSEDataType>();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
GetVectorSizeForTile<MaxNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp);
constexpr index_t MPerThread = kMPerBlock / (MaxNumWarps * MThreadsPerWarp);
static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
tile_distribution_encoding<sequence<Replicate>,
tuple<sequence<MPerThread, MaxNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
......@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<NPack>{},
number<1>{});
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
......@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<NPack>{},
number<1>{});
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
......@@ -152,41 +177,95 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_t_lds_block_desc;
}
// 3d + padding, shape=[4 * kM0, kN1]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = 4 * Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t NPack =
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
o_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return o_acc_t_lds_block_desc;
}
// shape=[kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NThreads = 4;
constexpr index_t MaxNThreads = 8;
constexpr index_t NThreads = min(kNPerBlock, MaxNThreads);
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = 1;
constexpr index_t MThreads = kMPerBlock / MPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t MaxNumWarps = (MThreads * NThreads) / get_warp_size();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
tile_distribution_encoding<sequence<Replicate>,
tuple<sequence<MaxNumWarps, MThreadPerWarp, MPerThread>,
sequence<NThreads, NPerThread>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 1>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
// similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
// direction
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0)
constexpr index_t kNPerBlock = Problem::kN1;
static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = 1; // compose encoding base on 1 warp
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1, 1>, sequence<1, 2>>,
tuple<sequence<0, 2>, sequence<3, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
static_assert(kBlockSize <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
template <typename Problem_,
typename Policy_ = BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
}
}();
static constexpr const char* name = "qr_nwarp_sshuffle";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowLengths,
typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& k_element_func,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
// K tile in LDS
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// S tile in LDS
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeSLdsBlockDescriptor<Problem>());
auto s_write_lds_window = make_tile_window(
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto s_read_lds_window =
make_tile_window(s_lds,
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeSRegTileDistribution<Problem>());
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
// load Q here, will store Q into LDS to maximize throughput
auto origin_q = load_tile(q_dram_window);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = OaccBlockTileType{};
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
// init M, L
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t logical_num_total_loop =
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
if(logical_num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
}
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return o_acc;
}
}
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located in page-block (page-block size should be
// divisible by kN0)
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
const index_t aligned_physical_seqlen_k_start =
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
}
else
{
return physical_seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// store Q into LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_store = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
store_tile(q_lds_window_for_store, origin_q);
__builtin_amdgcn_sched_barrier(0);
// load Q from LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_load = make_tile_window(
q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem, decltype(gemm_0)>());
block_sync_lds();
auto q = load_tile(q_lds_window_for_load);
__builtin_amdgcn_sched_barrier(0);
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
auto k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load the first tile of the first iteration and store to LDS
auto k_block_tile = load_tile(k_dram_window);
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
// load the second tile of the first iteration
k_block_tile = load_tile(k_dram_window);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
// position_encoding accept only logical coordinates, do conversion here
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
// mask accept only logical coordinates, do conversion here
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}) - kv_l2p_offset,
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
}
__builtin_amdgcn_sched_barrier(0);
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
// laod the first tile of the first iteration and store to LDS
k_block_tile = load_tile(k_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
// shuffle through LDS so that the tile layout is consistent with required by Gemm1
store_tile(s_write_lds_window, s);
block_sync_lds();
auto s_new = load_tile(s_read_lds_window);
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s_new,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s_new.get_tile_distribution()); // Pcompute{j}
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
}
#else
p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&,
&i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) {
const auto v = load_tile(v_dram_window_); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
});
}
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
v_lds_window);
block_sync_lds();
}
__builtin_amdgcn_sched_barrier(0);
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// store the first tile for next iteration to LDS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
}
} while(++i_total_loops < num_total_loop);
if constexpr(kStoreLSE)
{
// store lse acc
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowLengths,
typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_lengths,
k_page_block_navigator,
identity{},
v_dram_block_window_lengths,
v_page_block_navigator,
identity{},
bias_dram_block_window_tmp,
identity{},
lse_acc_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
kv_l2p_offset,
smem_ptr);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return static_cast<index_t>(16 / sizeof(OaccDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
// TODO: this is for 3d layout
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return static_cast<index_t>(16 / sizeof(QDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
{
using SDataType = remove_cvref_t<typename Problem::SaccDataType>;
return static_cast<index_t>(16 / sizeof(SDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kNPack = GetSmemNPackS<Problem>();
constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / kNPack>{}, number<kMPerBlock>{}, number<kNPack>{}),
make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
number<kNPack>{},
number<1>{});
constexpr auto s_lds_block_desc = transform_tensor_descriptor(
s_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kNPerBlock / kNPack>{}, number<kNPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return s_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1, "Check failed!");
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = kKPerBlock / (K2 * K3);
constexpr index_t K0 = kTileK / kKPerBlock;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr auto s2_block_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2, K3>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
sequence<1, 2, 2, 2>,
sequence<0, 0, 1, 3>>{};
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
return s2_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
{
return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::SaccDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
}
};
} // namespace ck_tile
......@@ -103,31 +103,47 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
// extract tile size attributes to remove dependency on traits
template <typename OaccDataType_, ck_tile::index_t kN1_>
struct BlockFmhaSplitKVCombinePipelineTileSizes
{
static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
static constexpr index_t kN1 = kN1_;
static constexpr index_t NThreads = kN1 / MaxVectorSize;
static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
};
template <typename LSEDataType_,
typename OaccDataType_,
typename ODataType_,
index_t HeadDimV_,
index_t kM0_,
index_t kN1_,
bool kIsGroupMode_,
ck_tile::index_t kN1_,
typename Traits_>
struct BlockFmhaSplitKVCombinePipelineProblem
: BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
{
using BaseType = BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4);
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static_assert(std::is_same_v<LSEDataType, OaccDataType>);
static constexpr index_t kHeadDimV = HeadDimV_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN1 = kN1_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
using BaseType::kM0;
using BaseType::kN1;
static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
......@@ -136,6 +152,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
static_assert(8 <= kMaxSplits);
static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
(kM0 * kMaxSplits) % get_warp_size() == 0);
};
template <typename QDataType_,
......
......@@ -5,14 +5,14 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile {
/// NOTICE: we no-longer use this pipeline.
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
struct [[deprecated]] BlockFmhaPipelineQSKSVS
struct BlockFmhaPipelineQSKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
......@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
......@@ -81,20 +99,18 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static constexpr const char* name = "qs";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return Policy::template GetSmemSizeQ<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
......@@ -114,6 +130,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& /* unused_randval_dram_block_window_tmp */,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
......@@ -122,7 +139,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
void* smem_ptr,
DropoutType& /* unused_dropout */) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......@@ -222,11 +240,11 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
{seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
......@@ -305,7 +323,6 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc, q_lds_window, k_lds_window);
......@@ -318,6 +335,10 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
gemm_0(s_acc, q_lds_window, k_lds_window);
}
__builtin_amdgcn_sched_barrier(0);
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
......@@ -439,6 +460,12 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
__builtin_amdgcn_sched_barrier(0);
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
......@@ -486,9 +513,6 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
}
move_tile_window(v_dram_window, {0, kK1});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
......@@ -583,6 +607,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
......@@ -590,11 +615,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
void* smem_ptr,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......@@ -604,6 +631,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
......@@ -612,7 +640,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
mask,
position_encoding,
scale_s,
smem_ptr);
smem_ptr,
dropout);
}
};
......
......@@ -9,11 +9,33 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaPipelineQSKSVSDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
struct BlockFmhaPipelineQSKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
/* NumPrefetchV = */ 1>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
} // namespace ck_tile
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>() + GetSmemSizeK<Problem>(), GetSmemSizeV<Problem>()) +
GetSmemSizeDropout<Problem>();
}
};
} // namespace ck_tile
......@@ -41,52 +41,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K0 = kKPerBlock / (K1 * K2);
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
if constexpr(1 < Problem::kNumGemm0Warps)
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
else
{
static_assert(MWarp == 1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
return BlockGemm::template MakeABlockTileDistribution<
Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kSubQKHeaddim>();
}
template <typename Problem>
......@@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 16 || WarpGemmM == 32);
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
......@@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
......@@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
......@@ -152,9 +125,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
}
};
/// NOTICE: we no-longer use this policy.
template <>
struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
{
static constexpr bool QLoadOnce = false;
......@@ -174,8 +146,16 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
template <typename Problem>
......@@ -184,19 +164,25 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
......@@ -243,18 +229,31 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
......@@ -43,8 +43,6 @@ struct TileFmhaShape
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
......
......@@ -43,6 +43,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kHasUnevenSplits_,
bool kMergeNumHeadGroupsSeqLenQ_ = false,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVTraits
{
......@@ -57,6 +58,7 @@ struct TileFmhaFwdSplitKVTraits
static constexpr bool kIsPagedKV = kIsPagedKV_;
// determine if some split (length) is not divisible by tile size
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
......@@ -14,6 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......@@ -111,7 +111,7 @@ struct FusedMoeGemmHostArgs
const void* num_sorted_tiles_ptr; // [1]
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t intermediate_size; // n / TP, for Gate/UP/Down
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
......@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
return base_str;
}();
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
return _SS_("fused_moe_") + _SS_(prec_str) + "_" + (IsGateOnly ? "g1u0_":"g1u1_") +
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
......@@ -204,7 +204,7 @@ struct FusedMoeGemmKernel
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t intermediate_size; // n / TP, for Gate/Up/Down
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
......@@ -239,7 +239,7 @@ struct FusedMoeGemmKernel
{
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
......@@ -298,6 +298,9 @@ struct FusedMoeGemmKernel
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
token_id &= 0xffffff;
#endif
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
kargs.sorted_weight_ptr)[sorted_token_id];
......
......@@ -15,6 +15,10 @@ namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
......@@ -28,7 +32,7 @@ namespace ck_tile {
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......@@ -55,6 +59,34 @@ namespace ck_tile {
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
// if enabled, the expert with no tokens will be skipped, in stead of padding to at least 1 unit_size(M_a)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
// num_tokens_post_padded_ptr : [24]
//
// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case)
// and modify the output expert-ID, because we will only have enbaled expert on specific GPU.
// we call expert input to this kernel as "global expert id", output as "local expert id"
//
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note original it was exper-id= 0, 2, 3, 5, but we produce "local expert id")
// num_tokens_post_padded_ptr : [20]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
......@@ -67,10 +99,80 @@ namespace ck_tile {
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
{
/* num_experts + 1
* +--------------------------------------+
* | |
* | |
* | | * -> sub-tokens
* | |
* | |
* +--------------------------------------+
* | | 2 -> cumsum buffer
* +--------------------------------------+
*
*/
int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
int smem_rows = [&](){
index_t target_occupancy_ = 2;
constexpr index_t total_ = 65536 / sizeof(int);
constexpr index_t sub_unroll = 8;
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
// at lease 2 lines, one for sub_token unroll, one for cumsum
// should be enough
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
throw std::runtime_error("too many num_experts, can't allocate smem");
target_occupancy_ = 1;
}
int r = total_ / target_occupancy_ / smem_cols;
// round to sub_unroll multipl
int r_for_sub_token = r - cumsum_bufs;
r_for_sub_token = min(r_for_sub_token, num_tokens_);
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
r_for_sub_token = max(r_for_sub_token, 1);
if(r_for_sub_token > 1)
{
int r_unroll_ = r_for_sub_token / sub_unroll;
// round to 1x/2x/4x/8x number of sub_unroll
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
int mask_ = (1 << (31 - clz_)) - 1;
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
mask_ = ~mask_;
//printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
}
// final check
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
throw std::runtime_error("can't run this kernel, request LDS over size");
}
return r_for_sub_token + cumsum_bufs;
}();
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
return ck_tile::make_tuple(smem_rows, smem_cols);
}
struct MoeSortingHostArgs
{
const void* p_topk_ids; // [token, topk]
const void* p_weights; // [token, topk]
const void* p_local_expert_mask;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
......@@ -101,6 +203,7 @@ struct MoeSortingKernel
{
const void* p_topk_ids;
const void* p_weights;
const void* p_local_expert_mask;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
......@@ -111,8 +214,11 @@ struct MoeSortingKernel
index_t moe_buf_bytes;
index_t tokens_per_thread;
index_t smem_rows;
mdiv unit_size_mdiv;
mdiv topk_mdiv;
mdiv expert_mdiv;
// mdiv sub_tokens_mdiv;
};
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
......@@ -123,14 +229,25 @@ struct MoeSortingKernel
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
(void)h;
return dim3(256);
#else
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
#endif
}
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
return smem_rows * smem_cols * sizeof(int);
#else
const auto blocks = BlockSize(h);
return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
// usually num_experts is power of 2, we pad 1 dword here for the row-size
return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
#endif
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
......@@ -138,6 +255,7 @@ struct MoeSortingKernel
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
......@@ -151,9 +269,150 @@ struct MoeSortingKernel
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.smem_rows = [&](){
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
(void) c_;
return r_;
}();
k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
// k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
return k;
}
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
// NOTE: wave_size need at least be 16!! dpp 16 is one row
template <typename data_t, int wave_size>
__device__ inline void wave_cumsum(data_t& thread_data) const
{
// wave_size must be power of 2
constexpr int row_mask = 0xf;
constexpr int bank_mask = 0xf;
constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
if constexpr(wave_size > 1)
{
thread_data = reduce_op(
thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x111,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:1
}
if constexpr(wave_size > 2)
{
thread_data = reduce_op(
thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x112,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:2
}
if constexpr(wave_size > 4)
{
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x114,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:4
}
if constexpr(wave_size == 8) {
// wave-size=8 need one extra shift
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x118,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
#if 0
constexpr int bank_mask_0_7 = 0b1100;
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
__builtin_amdgcn_update_dpp(0, /* old value */
__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask_0_7,
bound_ctrl))// row_newbcast:7
);
#else
data_t xxx =__builtin_bit_cast(data_t,
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask,
bound_ctrl)); // row_newbcast:7
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
thread_data = thread_data - yyy;
#endif
}
if constexpr(wave_size > 8)
{
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x118,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
}
if constexpr(wave_size > 16)
{
// now row-0, row-0+row-1, row-1+row-2, row-2+row-3
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
}
if constexpr(wave_size > 32)
{
// lane-id 48...63->31
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
}
}
// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = warpSize>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
// constexpr int reduce_stage = 6; // 1<<6=64
// clang-format off
constexpr int reduce_stage = [](){
if constexpr(wave_size_ == 2) return 1;
else if constexpr(wave_size_ == 4) return 2;
else if constexpr(wave_size_ == 8) return 3;
else if constexpr(wave_size_ == 16) return 4;
else if constexpr(wave_size_ == 32) return 5;
else if constexpr(wave_size_ == 64) return 6;
else return 0;
}();
// clang-format on
T v_local = local;
#pragma unroll reduce_stage
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
{
int src_lane = __lane_id() ^ (1 << i_stage);
int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
T v_remote = bit_cast<T>(v_remote_tmp);
v_local = reduce_f(v_local, v_remote);
}
return v_local;
}
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{
return row * total_col + col;
......@@ -187,36 +446,98 @@ struct MoeSortingKernel
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i)
{
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
}
__syncthreads();
#if 1
if(tid < num_experts)
{
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
index_t local_c[8];
index_t prev_c = 0;
// TODO: manually unroll. pragma unroll does not work well when we have dependency
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
{
local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
local_c[0] += prev_c;
local_c[1] += local_c[0];
local_c[2] += local_c[1];
local_c[3] += local_c[2];
local_c[4] += local_c[3];
local_c[5] += local_c[4];
local_c[6] += local_c[5];
local_c[7] += local_c[6];
prev_c = local_c[7];
tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
}
}
#else
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
// heuristic
{
if(tid < num_experts)
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
for(int i = 0; i < num_experts; i += 8)
{
index_t local_c[8];
#pragma unroll
for(int j = 0; j < 8; j++)
{
tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
}
#pragma unroll
for(int j = 0; j < 8; j++)
{
wave_cumsum<int, 64>(local_c[j]);
}
#pragma unroll
for(int j = 0; j < 8; j++)
{
tokens_cnts[calc_index(num_experts, i, tid)] +=
tokens_cnts[calc_index(num_experts, i - 1, tid)];
tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
}
}
}
#endif
// __syncthreads();
__syncthreads();
if constexpr(Problem::ExpertTile == 0)
{
if(tid == 0)
{
cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i)
{
auto current_units = [&]() {
index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor;
......@@ -225,10 +546,34 @@ struct MoeSortingKernel
}
*p_total_tokens_post_pad = cumsum[num_experts];
}
}
else
{
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
// expert) for simplicity, not check experts here.
int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
int local_cumsum = padded_tokens_per_expert;
wave_cumsum<int, 64>(local_cumsum);
if(tid == (num_experts - 1))
{
cumsum[0] = 0;
*p_total_tokens_post_pad = local_cumsum;
}
if(tid < num_experts)
{
cumsum[tid + 1] = local_cumsum;
}
}
__syncthreads();
if(tid < num_experts)
{
for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
int e_start = cumsum[tid];
int e_end = cumsum[tid + 1];
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
}
......@@ -238,8 +583,8 @@ struct MoeSortingKernel
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
index_t expert_id = topk_id[i];
index_t rank_post_pad =
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
......@@ -248,15 +593,18 @@ struct MoeSortingKernel
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
p_sorted_weights[rank_post_pad] = weights[i];
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
}
if constexpr(Problem::ExpertTile == 0)
{
const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts)
{
index_t expert_offset =
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
while(expert_offset < cumsum[tid + 1])
cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
index_t expert_end = cumsum[tid + 1];
while(expert_offset < expert_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[expert_offset] =
......@@ -269,6 +617,387 @@ struct MoeSortingKernel
}
}
}
else
{
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset = cumsum[eid] +
tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
tid % experts_per_wave;
index_t expert_end = cumsum[eid + 1];
if(eid < num_experts)
{
while(expert_offset < expert_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[expert_offset] =
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
#else
p_sorted_token_ids[expert_offset] = prefill_token;
#endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset += experts_per_wave;
}
}
}
}
}
// only support index_t, and single pixel access
struct simple_smem_indexer
{
index_t* smem;
index_t row_stride;
// this is 2D
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
: smem(smem_), row_stride(row_stride_)
{
}
CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
{
return smem[i_row * row_stride + i_col];
}
CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
{
return smem[i_row * row_stride + i_col];
}
// this is 1D or linear
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
};
CK_TILE_DEVICE void
moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
const IndexType* __restrict__ local_expert_mask,
index_t* p_sorted_token_ids,
WeightType* p_sorted_weights,
index_t* p_sorted_expert_ids,
index_t* p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
const mdiv expert_mdiv,
const index_t smem_rows,
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
const index_t topk = topk_mdiv.divisor;
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
const index_t smem_cols = num_experts + 1;
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
smem_cols};
// #pragma unroll 8
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
// NOTE: below for loop can't have barrier inside!!
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
if constexpr(Problem::SubTokenOneShot)
smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
else
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
}
__syncthreads(); // make sure different i_token iteration not overlap by different wave
}
// counting
if(tid == 0)
{
smem_cumsum(0) = 0;
// smem_cumdup(0) = 0;
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
{
index_t local_c[Problem::SubTokenTile];
index_t cnt = 0;
for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
{
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
if constexpr(Problem::SubTokenOneShot)
{
local_c[j] = local_c[j] != 0 ? 1 : 0;
}
}
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
}
}
if(lane_group_os == 0)
smem_cumsum(i_e + 1) = cnt;
}
}
if constexpr(Problem::LocalExpertMasking)
{
smem_cumdup(0) = 0;
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
// reuse this buffer
smem_cumdup(i_e + 1) = local_expert_mask[i_e];
}
}
__syncthreads();
{
if(wid == 0)
{
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += warpSize)
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
int blocks_pers_expert =
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int pre_cumsum_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(lid == 0 ? i_e_ : 0);
else
return 0; // not used
}();
int local_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(i_e_ + lid + 1);
else
return 0; // not used
}();
int padded_tokens_per_expert = [&]() {
int x_ = [&]() {
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
// if local_cnt is zero, blocks_pers_expert will be zero
// this is what we want to achieve
return blocks_pers_expert * unit_size_mdiv.divisor;
}
else
{
return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
}
}();
if constexpr(Problem::LocalExpertMasking)
{
return local_masking ? x_ : 0;
}
else
return x_;
}();
local_cumsum_ = padded_tokens_per_expert;
local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
// cumsum padded in case local cumsum is zero, but
// pre_sumsum has value, which will result int
// zero local cumsum(but we want at least padded)
wave_cumsum<int, warpSize>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
wave_cumsum<int, warpSize>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
// NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
// for above write however __syncthreads will cause barrier with waves other
// than 0(which is not we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
if((lid + i_e_ - warpSize) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
}
__syncthreads();
}
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
int e_start = smem_cumsum(i_e);
int e_end = smem_cumsum(i_e + 1);
int expert_id = [&]() {
if constexpr(Problem::LocalExpertMasking)
{
// local expert id from cumsum
return smem_cumdup(i_e);
}
else
return i_e;
}();
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[i_e] == 0)
continue;
}
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
if constexpr(!Problem::SubTokenOneShot)
{
// clear every time
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
// load again
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id_, curr_topk_id_;
topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
int curr_token_id = static_cast<int>(curr_token_id_);
int curr_topk_id = static_cast<int>(curr_topk_id_);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
}
}
__syncthreads();
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
{
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[eid] == 0)
continue;
}
int position = smem_cumsum(eid);
for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
i_sub_token += lane_group_sz)
{
auto x = smem_tokens(i_sub_token, eid);
int local_cnt_cache = x != 0 ? 1 : 0;
int local_cnt = local_cnt_cache;
wave_cumsum<int, lane_group_sz>(local_cnt);
if(x != 0)
{
// now x is topk value
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[position + local_cnt - 1] =
MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
#endif
p_sorted_weights[position + local_cnt - 1] =
weights[(i_token + i_sub_token) * topk + x - 1];
}
int remote_cnt = __builtin_amdgcn_ds_bpermute(
(lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
position += remote_cnt;
}
smem_cumsum(eid) = position;
}
}
__syncthreads();
}
// add the skip number
for(int eid = tid; eid < num_experts; eid += block_size)
{
int e_start = smem_cumsum(eid);
int e_end = smem_cumdup(eid + 1);
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
while(e_start < e_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
#else
p_sorted_token_ids[e_start] = tokens;
#endif
p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
e_start++;
}
}
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
......@@ -283,6 +1012,24 @@ struct MoeSortingKernel
}
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
(void)numel;
return moe_align_block_size_kernel_ex(
static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<const IndexType*>(kargs.p_local_expert_mask),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
static_cast<WeightType*>(kargs.p_sorted_weights),
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
kargs.expert_mdiv,
kargs.smem_rows,
smem);
#else
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
......@@ -295,6 +1042,7 @@ struct MoeSortingKernel
kargs.unit_size_mdiv,
kargs.topk_mdiv,
smem);
#endif
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename IndexType_,
typename WeightType_,
index_t InternalLoadUnroll_,
index_t ExpertTile_ = 0>
struct MoeSortingProblem
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t InternalLoadUnroll =
InternalLoadUnroll_; // TODO: need better design(like tile size)
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
template <typename IndexType_,
typename WeightType_,
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
bool SubTokenOneShot_, // if we only loop over once or not
bool LocalExpertMasking_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment