Commit c881136b authored by Po Yen Chen

Merge branch 'develop' into ck_tile/support-vllm-kcache-layout

parents c5e8e14f 4e076909
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck_tile {
// In this pipeline, the Q, K and V tiles are all located in LDS
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return static_cast<index_t>(16 / sizeof(OaccDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
// TODO: this is for 3d layout
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return static_cast<index_t>(16 / sizeof(QDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
{
using SDataType = remove_cvref_t<typename Problem::SaccDataType>;
return static_cast<index_t>(16 / sizeof(SDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kNPack = GetSmemNPackS<Problem>();
constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / kNPack>{}, number<kMPerBlock>{}, number<kNPack>{}),
make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
number<kNPack>{},
number<1>{});
constexpr auto s_lds_block_desc = transform_tensor_descriptor(
s_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kNPerBlock / kNPack>{}, number<kNPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return s_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1, "this tile distribution assumes MWarp == 1");
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
// K3 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = kKPerBlock / (K2 * K3);
constexpr index_t K0 = kTileK / kKPerBlock;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr auto s2_block_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2, K3>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
sequence<1, 2, 2, 2>,
sequence<0, 0, 1, 3>>{};
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
return s2_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
{
return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::SaccDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
}
};
} // namespace ck_tile
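
To make the Q-tile arithmetic in GetAlignmentQ() and MakeQDramTileDistribution() above concrete, here is a small host-side sketch. The tile sizes are hypothetical values chosen for illustration, not values taken from this commit:

#include <algorithm>

int main()
{
    // assumed problem shape: 256 threads/block, 64x64 Q tile, fp16 Q (2 bytes), wave64
    constexpr int kBlockSize    = 256;
    constexpr int kMPerBlock    = 64;
    constexpr int kKPerBlock    = 64;
    constexpr int MaxVectorSize = 16 / 2; // 16 bytes per load / sizeof(fp16) = 8

    constexpr int ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;  // 16
    constexpr int kMaxVecLoad   = std::min(ElemPerThread, MaxVectorSize); // 8

    constexpr int KPerThread     = kMaxVecLoad;             // 8 contiguous K elements per thread
    constexpr int KThreads       = kKPerBlock / KPerThread; // 8 threads span one K row
    constexpr int MThreadPerWarp = 64 / KThreads;           // 8
    constexpr int NumWarps       = kBlockSize / 64;         // 4
    constexpr int MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps); // 2

    static_assert(MPerThread * KPerThread == ElemPerThread, "distribution covers the tile");
    return 0;
}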
@@ -106,28 +106,43 @@ struct BlockFmhaFwdSplitKVPipelineProblem
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };
 
+// extract tile size attributes to remove dependency on traits
+template <typename OaccDataType_, ck_tile::index_t kN1_>
+struct BlockFmhaSplitKVCombinePipelineTileSizes
+{
+    static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
+
+    static constexpr index_t kN1      = kN1_;
+    static constexpr index_t NThreads = kN1 / MaxVectorSize;
+    static constexpr index_t kM0      = get_warp_size() / NThreads; // MThreadPerWarp
+};
+
 template <typename LSEDataType_,
           typename OaccDataType_,
           typename ODataType_,
           index_t HeadDimV_,
-          index_t kM0_,
-          index_t kN1_,
           bool kIsGroupMode_,
+          ck_tile::index_t kN1_,
           typename Traits_>
 struct BlockFmhaSplitKVCombinePipelineProblem
+    : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
 {
+    using BaseType = BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>;
+
     using LSEDataType  = remove_cvref_t<LSEDataType_>;
     using OaccDataType = remove_cvref_t<OaccDataType_>;
     using ODataType    = remove_cvref_t<ODataType_>;
     using Traits       = remove_cvref_t<Traits_>;
 
-    static constexpr index_t kNumWarps  = kM0_ / (get_warp_size() / 4);
-    static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
-    static constexpr bool kIsGroupMode  = kIsGroupMode_;
+    static_assert(std::is_same_v<LSEDataType, OaccDataType>);
+
     static constexpr index_t kHeadDimV = HeadDimV_;
-    static constexpr index_t kM0       = kM0_;
-    static constexpr index_t kN1       = kN1_;
+    static constexpr bool kIsGroupMode = kIsGroupMode_;
+
+    using BaseType::kM0;
+    using BaseType::kN1;
+
+    static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
 
     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -136,6 +151,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
     static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
     static constexpr index_t kMaxSplits     = Traits::kMaxSplits;
+
+    static_assert(8 <= kMaxSplits);
+
+    static constexpr index_t kNumWarps  = 4; // always use 4 warps for each workgroup
+    static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
+
+    static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
+                  (kM0 * kMaxSplits) % get_warp_size() == 0);
 };
 
 template <typename QDataType_,
...
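
The new BlockFmhaSplitKVCombinePipelineTileSizes base derives kM0 from kN1 instead of taking it as a template parameter. A worked instance, assuming hypothetical values for illustration (OaccDataType = float, kN1 = 32, wave64):

constexpr int MaxVectorSize = 16 / 4;             // 16 bytes / sizeof(float) = 4 elements
constexpr int NThreads      = 32 / MaxVectorSize; // 8 threads cover the kN1 = 32 columns
constexpr int kM0           = 64 / NThreads;      // 8 rows per warp (MThreadPerWarp)
static_assert(kM0 * NThreads == 64, "one wave covers an 8 x 32 tile");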
@@ -41,52 +41,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
     {
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
+
         using BlockGemm       = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
         constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
         using WG              = remove_cvref_t<decltype(config.template at<0>())>;
-        return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
+
+        return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
     }
 
     template <typename Problem, typename BlockGemm>
     CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
     {
-        constexpr auto config   = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
-        constexpr index_t MWarp = config.template at<1>();
-
-        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
-
-        constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
-        constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
-        constexpr index_t K0 = kKPerBlock / (K1 * K2);
-
-        constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
-        constexpr index_t M1 = MWarp;
-        constexpr index_t M0 = kMPerBlock / (M2 * M1);
-
-        if constexpr(1 < Problem::kNumGemm0Warps)
-        {
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-                                           tuple<sequence<1>, sequence<2, 1>>,
-                                           tuple<sequence<1>, sequence<1, 2>>,
-                                           sequence<1, 2, 2>,
-                                           sequence<0, 0, 2>>{});
-        }
-        else
-        {
-            static_assert(MWarp == 1);
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-                                           tuple<sequence<2, 1>>,
-                                           tuple<sequence<1, 2>>,
-                                           sequence<1, 2, 2>,
-                                           sequence<0, 0, 2>>{});
-        }
+        return BlockGemm::template MakeABlockTileDistribution<
+            Problem::BlockFmhaShape::kM0,
+            Problem::BlockFmhaShape::kSubQKHeaddim>();
     }
 
     template <typename Problem>
@@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         constexpr auto warp_gemm = []() {
             constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-            static_assert(WarpGemmM == 16 || WarpGemmM == 32);
+            static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
 
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                          std::is_same_v<typename Problem::KDataType, half_t> &&
@@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
             {
                 if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else // WarpGemmM == 16
+                else if constexpr(WarpGemmM == 16)
                     return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaF16F16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                               std::is_same_v<typename Problem::KDataType, bf16_t> &&
@@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
             {
                 if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else // WarpGemmM == 16
+                else if constexpr(WarpGemmM == 16)
                     return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                               std::is_same_v<typename Problem::KDataType, fp8_t> &&
...
@@ -43,8 +43,6 @@ struct TileFmhaShape
     static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
 
-    static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
-
     static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
     static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
     static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
@@ -130,7 +130,8 @@ struct MoeSortingKernel
     CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
     {
         const auto blocks = BlockSize(h);
-        return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
+        // usually num_experts is a power of 2; we pad the row size by 1 dword here
+        return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
     }
 
     CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
@@ -154,6 +155,75 @@ struct MoeSortingKernel
         return k;
     }
 
+    // [a, b, c, d, ...] -> [a, a+b, a+b+c, a+b+c+d, ...]
+    template <typename data_t, int wave_size>
+    __device__ inline void wave_cumsum(data_t& thread_data) const
+    {
+        // wave_size must be a power of 2
+        constexpr int row_mask    = 0xf;
+        constexpr int bank_mask   = 0xf;
+        constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
+
+        auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
+
+        if constexpr(wave_size > 1)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x111,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:1
+        }
+        if constexpr(wave_size > 2)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x112,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:2
+        }
+        if constexpr(wave_size > 4)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x114,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:4
+        }
+        if constexpr(wave_size > 8)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x118,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:8
+        }
+        if constexpr(wave_size > 16)
+        {
+            // now row-0, row-0+row-1, row-1+row-2, row-2+row-3
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+        if constexpr(wave_size > 32)
+        {
+            // lane ids 48...63 read from lane 31
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+    }
+
     CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
     {
         return row * total_col + col;
@@ -187,48 +257,124 @@ struct MoeSortingKernel
         index_t* shared_mem = reinterpret_cast<index_t*>(smem);
 
         index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts), row stride num_experts + 1
-        index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1d: (num_experts + 1)
+        index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1d: (num_experts + 1)
 
         for(int i = 0; i < num_experts; ++i)
         {
-            tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
+            tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
         }
 
 #pragma unroll Problem_::InternalLoadUnroll
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
-            ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
+            ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
         }
 
         __syncthreads();
+#if 1
         if(tid < num_experts)
         {
-            tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
-            for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
+            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
+            index_t local_c[8];
+            index_t prev_c = 0;
+            // TODO: manually unrolled; pragma unroll does not work well when there is a dependency
+            for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
             {
-                tokens_cnts[calc_index(num_experts, i, tid)] +=
-                    tokens_cnts[calc_index(num_experts, i - 1, tid)];
+                local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
+                local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
+                local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
+                local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
+                local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
+                local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
+                local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
+                local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
+                local_c[0] += prev_c;
+                local_c[1] += local_c[0];
+                local_c[2] += local_c[1];
+                local_c[3] += local_c[2];
+                local_c[4] += local_c[3];
+                local_c[5] += local_c[4];
+                local_c[6] += local_c[5];
+                local_c[7] += local_c[6];
+                prev_c = local_c[7];
+                tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
+                tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
+                tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
+                tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
+                tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
+                tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
+                tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
+                tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
             }
         }
+#else
+        // TODO: the code below still works, but is slow in the expert=32/topk=5 case;
+        // kept here for a future heuristic
+        {
+            if(tid < num_experts)
+                tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
+            for(int i = 0; i < num_experts; i += 8)
+            {
+                index_t local_c[8];
+#pragma unroll
+                for(int j = 0; j < 8; j++)
+                {
+                    local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
+                }
+#pragma unroll
+                for(int j = 0; j < 8; j++)
+                {
+                    wave_cumsum<int, 64>(local_c[j]);
+                }
+#pragma unroll
+                for(int j = 0; j < 8; j++)
+                {
+                    tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
+                }
+            }
+        }
+#endif
+        __syncthreads();
 
-        // __syncthreads();
-        if(tid == 0)
-        {
-            cumsum[0] = 0;
-            for(int i = 1; i <= num_experts; ++i)
-            {
-                auto current_units = [&]() {
-                    index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
-                                 unit_size_mdiv.divisor - 1;
-                    index_t y_ = unit_size_mdiv.div(x_);
-                    return max(y_, 1) * unit_size_mdiv.divisor;
-                }();
-                cumsum[i] = cumsum[i - 1] + current_units;
-            }
-            *p_total_tokens_post_pad = cumsum[num_experts];
-        }
+        if constexpr(Problem::ExpertTile == 0)
+        {
+            if(tid == 0)
+            {
+                cumsum[0] = 0;
+                for(int i = 1; i <= num_experts; ++i)
+                {
+                    auto current_units = [&]() {
+                        index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
+                                     unit_size_mdiv.divisor - 1;
+                        index_t y_ = unit_size_mdiv.div(x_);
+                        return max(y_, 1) * unit_size_mdiv.divisor;
+                    }();
+                    cumsum[i] = cumsum[i - 1] + current_units;
+                }
+                *p_total_tokens_post_pad = cumsum[num_experts];
+            }
+        }
+        else
+        {
+            // TODO: we have an out-of-bound read here, but the result is still OK
+            // (tid >= num_experts is ignored); for simplicity, experts are not checked here
+            int local_cnt                = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
+            int blocks_pers_expert       = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
+            int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
+            int local_cumsum             = padded_tokens_per_expert;
+            wave_cumsum<int, 64>(local_cumsum);
+            if(tid == (num_experts - 1))
+            {
+                cumsum[0]                = 0;
+                *p_total_tokens_post_pad = local_cumsum;
+            }
+            if(tid < num_experts)
+            {
+                cumsum[tid + 1] = local_cumsum;
+            }
+        }
 
         __syncthreads();
         if(tid < num_experts)
         {
-            for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
+            int e_start = cumsum[tid];
+            int e_end   = cumsum[tid + 1];
+            for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
             {
                 p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
             }
@@ -238,8 +384,8 @@ struct MoeSortingKernel
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
         {
             index_t expert_id = topk_id[i];
-            index_t rank_post_pad =
-                tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
+            index_t local_cnt     = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
+            index_t rank_post_pad = local_cnt + cumsum[expert_id];
 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
             uint32_t curr_token_id, curr_topk_id;
             topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
@@ -247,27 +393,54 @@ struct MoeSortingKernel
 #else
             p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
 #endif
             p_sorted_weights[rank_post_pad] = weights[i];
-            ++tokens_cnts[calc_index(num_experts, tid, expert_id)];
+            tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
         }
 
-        const index_t prefill_token = topk_mdiv.div(numel);
-        if(tid < num_experts)
-        {
-            index_t expert_offset =
-                cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
-            while(expert_offset < cumsum[tid + 1])
-            {
+        if constexpr(Problem::ExpertTile == 0)
+        {
+            const index_t prefill_token = topk_mdiv.div(numel);
+            if(tid < num_experts)
+            {
+                index_t expert_offset =
+                    cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
+                index_t expert_end = cumsum[tid + 1];
+                while(expert_offset < expert_end)
+                {
 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
-                p_sorted_token_ids[expert_offset] =
-                    MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
+                    p_sorted_token_ids[expert_offset] =
+                        MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
#else
-                p_sorted_token_ids[expert_offset] = prefill_token;
+                    p_sorted_token_ids[expert_offset] = prefill_token;
 #endif
-                p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
-                expert_offset++;
+                    p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
+                    expert_offset++;
+                }
             }
         }
+        else
+        {
+            const index_t prefill_token = topk_mdiv.div(numel);
+            // TODO: only supports expert tiles like 8, 16, 32
+            static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
+            {
+                index_t eid           = tid / experts_per_wave;
+                index_t expert_offset = cumsum[eid] +
+                                        tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
+                                        tid % experts_per_wave;
+                index_t expert_end = cumsum[eid + 1];
+                if(eid < num_experts)
+                {
+                    while(expert_offset < expert_end)
+                    {
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+                        p_sorted_token_ids[expert_offset] =
+                            MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
+#else
+                        p_sorted_token_ids[expert_offset] = prefill_token;
+#endif
+                        p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
+                        expert_offset += experts_per_wave;
+                    }
+                }
+            }
+        }
     }
 
     CK_TILE_DEVICE void operator()(Kargs kargs) const
...
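
For reference, wave_cumsum() above computes an inclusive prefix sum across the lanes of one wave using DPP row shifts and ds_bpermute, with no LDS traffic. A scalar model of the same semantics (assuming wave64; this reference is for illustration only, not part of the commit):

#include <array>

template <int wave_size>
std::array<int, wave_size> wave_cumsum_reference(std::array<int, wave_size> v)
{
    // lane i ends up holding v[0] + v[1] + ... + v[i]
    for(int lane = 1; lane < wave_size; ++lane)
        v[lane] += v[lane - 1];
    return v;
}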
@@ -9,15 +9,20 @@
 namespace ck_tile {
 
-template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
+template <typename IndexType_,
+          typename WeightType_,
+          index_t InternalLoadUnroll_,
+          index_t ExpertTile_ = 0>
 struct MoeSortingProblem
 {
     // TODO: this kernel only supports one warp per row
     using WeightType = remove_cvref_t<WeightType_>;
     using IndexType  = remove_cvref_t<IndexType_>;
 
     static constexpr index_t WarpSize      = get_warp_size();
     static constexpr index_t WarpsPerBlock = 1;
-    static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
+    static constexpr index_t InternalLoadUnroll =
+        InternalLoadUnroll_; // TODO: needs a better design (like tile size)
+    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store-out
 };
 } // namespace ck_tile
 
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
...
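
A hypothetical instantiation of the extended problem type; the concrete index/weight types and unroll factors used by callers are not shown in this diff:

// ExpertTile_ = 0 keeps the original store-out path;
// nonzero values (8/16/32) enable the new per-expert-tile prefill loop
using Problem = ck_tile::MoeSortingProblem<ck_tile::index_t,
                                           float,
                                           /*InternalLoadUnroll_=*/4,
                                           /*ExpertTile_=*/0>;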
@@ -65,14 +65,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1
         const index_t iNWarp = 0;
 
-        constexpr auto a_block_outer_dstr_encoding =
-            tile_distribution_encoding<sequence<NWarp>,
-                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
-                                       tuple<sequence<1, 0>>,
-                                       tuple<sequence<1, 0>>,
-                                       sequence<1, 2>,
-                                       sequence<0, 0>>{};
-
         constexpr auto c_block_outer_dstr_encoding =
             tile_distribution_encoding<sequence<>,
                                        tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
@@ -81,19 +73,14 @@ struct BlockGemmARegBSmemCRegOneWarpV1
                                        sequence<1, 2>,
                                        sequence<0, 0>>{};
 
-        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
         constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
             c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
 
-        constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
-
         // construct the A-block-tensor from the A-block-tensor-tmp
         // FIXME: need a method to check that a_block_tensor and a_block_tensor_tmp have
         // equivalent distributions
-        auto a_block_tensor =
-            make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
+        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
+            MakeABlockTileDistribution());
 
         a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
@@ -187,6 +174,33 @@ struct BlockGemmARegBSmemCRegOneWarpV1
         });
     }
 
+    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
+    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
+    {
+        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
+        using WG              = remove_cvref_t<decltype(config.template at<0>())>;
+
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+
+        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
+        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+
+        constexpr auto a_block_outer_dstr_encoding =
+            tile_distribution_encoding<sequence<NWarp>,
+                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                       tuple<sequence<1, 0>>,
+                                       tuple<sequence<1, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+
+        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
+
+        return make_static_tile_distribution(a_block_dstr_encode);
+    }
+
     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
     {
         constexpr index_t MPerBlock = BlockGemmShape::kM;
...
@@ -59,14 +59,6 @@ struct BlockGemmARegBSmemCRegV2
         const index_t iNWarp = get_warp_id() % NWarp;
 
-        constexpr auto a_block_outer_dstr_encoding =
-            tile_distribution_encoding<sequence<NWarp>,
-                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
-                                       tuple<sequence<1, 0>>,
-                                       tuple<sequence<1, 0>>,
-                                       sequence<1, 2>,
-                                       sequence<0, 0>>{};
-
         constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
             sequence<>,
             tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
@@ -75,19 +67,14 @@ struct BlockGemmARegBSmemCRegV2
             sequence<1, 2>,
             sequence<0, 0>>{};
 
-        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
         constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
             c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
 
-        constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
-
         // construct the A-block-tensor from the A-block-tensor-tmp
         // FIXME: need a method to check that a_block_tensor and a_block_tensor_tmp have
         // equivalent distributions
-        auto a_block_tensor =
-            make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
+        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
+            MakeABlockTileDistribution());
 
         a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
@@ -182,6 +169,33 @@ struct BlockGemmARegBSmemCRegV2
         });
     }
 
+    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
+    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
+    {
+        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
+        using WG              = remove_cvref_t<decltype(config.template at<0>())>;
+
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+
+        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
+        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+
+        constexpr auto a_block_outer_dstr_encoding =
+            tile_distribution_encoding<sequence<NWarp>,
+                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
+                                       tuple<sequence<1, 0>>,
+                                       tuple<sequence<1, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{};
+
+        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
+
+        return make_static_tile_distribution(a_block_dstr_encode);
+    }
+
     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
     {
         constexpr index_t MPerBlock = BlockGemmShape::kM;
...
@@ -67,9 +67,10 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     using KernelArgs = BatchedGemmKernelArgs;
 
-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
+    __host__ static constexpr auto
+    GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
     {
-        return TilePartitioner::GridSize(M, N, batch_count);
+        return TilePartitioner::GridSize(M, N, KBatch * batch_count);
     }
 
     __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
@@ -85,7 +86,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
                                hostArgs.K,
                                hostArgs.stride_A,
                                hostArgs.stride_B,
-                               hostArgs.stride_C},
+                               hostArgs.stride_C,
+                               hostArgs.k_batch},
                               hostArgs.batch_stride_A,
                               hostArgs.batch_stride_B,
                               hostArgs.batch_stride_C,
@@ -100,22 +102,38 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
-        const auto i_batch    = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const auto i_batch    = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
+        const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
+
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
 
         // options
         const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
         const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
+                                 splitk_batch_offset.a_k_split_offset;
 
         const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
         const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
+                                 splitk_batch_offset.b_k_split_offset;
 
         const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
         const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
 
-        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        if(kargs.KBatch == 1)
+        {
+            this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
+        else
+        {
+            this->template RunGemm<memory_operation_enum::atomic_add>(
+                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
     }
 };
...
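
With the extra KBatch dimension, gridDim.z becomes KBatch * batch_count, and each workgroup decodes its (batch, k-split) pair exactly as operator() above does. A sketch with hypothetical counts (KBatch = 4, batch_count = 3, so gridDim.z = 12):

constexpr int KBatch = 4, batch_count = 3;
for(int z = 0; z < KBatch * batch_count; ++z)
{
    const int i_batch = z / KBatch;           // 0..2; 4 consecutive z values share a batch
    const int i_k     = z - i_batch * KBatch; // 0..3; the K split this workgroup reduces
}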
@@ -93,6 +93,7 @@ struct GemmKernel
         index_t stride_A;
         index_t stride_B;
         index_t stride_C;
+        index_t KBatch;
     };
 
     CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
@@ -105,28 +106,72 @@ struct GemmKernel
                               hostArgs.K,
                               hostArgs.stride_A,
                               hostArgs.stride_B,
-                              hostArgs.stride_C};
+                              hostArgs.stride_C,
+                              hostArgs.k_batch};
     }
 
-    // CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
-    //                                                             const void* b_ptr,
-    //                                                             void* c_ptr,
-    //                                                             index_t M,
-    //                                                             index_t N,
-    //                                                             index_t K,
-    //                                                             index_t stride_A,
-    //                                                             index_t stride_B,
-    //                                                             index_t stride_C)
-    // {
-    //     return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
-    // }
-
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }
+    struct SplitKBatchOffset
+    {
+        __device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
+                                     const std::size_t k_id = blockIdx.z)
+        {
+            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+            const index_t K_t   = kargs.KBatch * K1;
+            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead * kargs.stride_A;
+            }
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead * kargs.stride_B;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead;
+            }
+
+            if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
+            {
+                splitted_k = KRead;
+            }
+            else
+            {
+                splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
+            }
+        }
+
+        index_t a_k_split_offset;
+        index_t b_k_split_offset;
+        index_t splitted_k;
+    };
+
     CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
     {
+        constexpr bool is_output_c_reg_transposed =
+            EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
+        if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
+                        std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
+                        is_output_c_reg_transposed) ||
+                       !(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
+        {
+            if(kargs.KBatch != 1)
+            {
+                return false;
+            }
+        }
+
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
             if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
@@ -198,17 +243,19 @@ struct GemmKernel
         return true;
     }
-    CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
-                                            const BDataType* b_ptr,
-                                            CDataType* c_ptr,
-                                            const GemmKernelArgs& kargs) const
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
+                                                   const BDataType* b_ptr,
+                                                   CDataType* c_ptr,
+                                                   const GemmKernelArgs& kargs,
+                                                   const SplitKBatchOffset& splitk_batch_offset)
     {
         const auto& a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_ptr,
-                    make_tuple(kargs.M, kargs.K),
+                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                     make_tuple(kargs.stride_A, 1),
                     number<GemmPipeline::VectorSizeA>{},
                     number<1>{});
@@ -217,7 +264,7 @@ struct GemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_ptr,
-                    make_tuple(kargs.M, kargs.K),
+                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                     make_tuple(1, kargs.stride_A),
                     number<1>{},
                     number<1>{});
@@ -229,7 +276,7 @@ struct GemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_ptr,
-                    make_tuple(kargs.N, kargs.K),
+                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                     make_tuple(1, kargs.stride_B),
                     number<1>{},
                     number<1>{});
@@ -238,7 +285,7 @@ struct GemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_ptr,
-                    make_tuple(kargs.N, kargs.K),
+                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                     make_tuple(kargs.stride_B, 1),
                     number<GemmPipeline::VectorSizeB>{},
                     number<1>{});
@@ -248,7 +295,7 @@ struct GemmKernel
         const auto& c_tensor_view = [&]() {
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(kargs.stride_C, 1),
@@ -257,7 +304,7 @@ struct GemmKernel
             }
             else
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(1, kargs.stride_C),
@@ -270,7 +317,7 @@ struct GemmKernel
     }
     template <typename TensorView>
-    CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
+    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
     {
         const auto& a_pad_view = [&]() {
             const auto& a_tensor_view = views.at(I0);
@@ -330,8 +377,8 @@ struct GemmKernel
     }
 
     template <typename PadView>
-    CK_TILE_DEVICE auto
-    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
+    CK_TILE_DEVICE static auto
+    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
     {
         const auto& a_pad_view = views.at(I0);
         const auto& a_block_window = make_tile_window(
@@ -363,23 +410,27 @@ struct GemmKernel
      * @param kargs GEMM kernel arguments
      * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
      * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     * @tparam DstInMemOp Destination memory operation (default: set).
      */
-    CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
-                                const BDataType* b_ptr,
-                                CDataType* c_ptr,
-                                const GemmKernelArgs& kargs,
-                                const index_t block_idx_m,
-                                const index_t block_idx_n) const
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       CDataType* c_ptr,
+                                       void* smem_ptr,
+                                       const GemmKernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
     {
         // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
 
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
+        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
 
         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -389,18 +440,43 @@ struct GemmKernel
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);
-        EpiloguePipeline{}(c_block_window, c_block_tile);
+
+        constexpr bool is_output_c_reg_transposed =
+            EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
+        if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
+                     (GemmPipeline::VectorSizeC % 2 == 0 &&
+                      std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
+                      is_output_c_reg_transposed))
+        {
+            EpiloguePipeline{}
+                .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+                    c_block_window, c_block_tile);
+        }
     }
     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
+        const SplitKBatchOffset splitk_batch_offset(kargs);
 
         // options
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
-        CDataType* c_ptr       = static_cast<CDataType*>(kargs.c_ptr);
+        const ADataType* a_ptr =
+            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
+        const BDataType* b_ptr =
+            static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
+        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
+
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
 
-        RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        if(kargs.KBatch == 1)
+        {
+            RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
+        else
+        {
+            RunGemm<memory_operation_enum::atomic_add>(
+                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
     }
 };
...
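
Worked example of the SplitKBatchOffset arithmetic above, using hypothetical sizes (K = 1000, KBatch = 4, warp-tile K1 = 16); splits 0..2 each read KRead elements of K, and the last split takes the remainder:

constexpr int K = 1000, KBatch = 4, K1 = 16;
constexpr int K_t   = KBatch * K1;                     // 64
constexpr int KRead = (K + K_t - 1) / K_t * K1;        // ceil(1000 / 64) * 16 = 256
constexpr int last_split_k = K - KRead * (KBatch - 1); // 1000 - 3 * 256 = 232
static_assert(KRead * (KBatch - 1) + last_split_k == K, "splits cover all of K");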
@@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
         return Policy::template GetSmemSize<Problem>();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
@@ -132,6 +132,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         return Policy::template GetSmemSize<Problem>();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
@@ -53,6 +53,8 @@ struct GemmPipelineAGmemBGmemCRegV1
         return Policy::template GetSmemSize<Problem>();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
@@ -13,6 +13,8 @@ namespace ck_tile {
 struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
 {
+    static constexpr bool TransposeC = false;
+
 #if 0
     // 2d
     template <typename Problem>
@@ -114,8 +116,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
-        index_t smem_size = 0;
-        smem_size += smem_size_a + smem_size_b;
+        constexpr index_t smem_size = smem_size_a + smem_size_b;
 
         return smem_size;
     }
@@ -485,13 +486,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         }
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
-        constexpr bool TransposeC = false;
-        constexpr auto I0 = number<0>{};
-        constexpr auto I1 = number<1>{};
-        constexpr auto I2 = number<2>{};
+        constexpr auto I0 = number<0>{};
+        constexpr auto I1 = number<1>{};
+        constexpr auto I2 = number<2>{};
 
         using AccDataType = float;
         using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
...
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2
             Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
@@ -444,6 +444,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
         }
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
...
@@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
         WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
         2>>;

+using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
+using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
 // bf16
 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
@@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
         WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
         2>>;

+using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
+using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
+    4>>;
+
 // fp8
 using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
...
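The new warp-GEMM aliases build a K=16 operation out of a base MFMA implementation that consumes K=4 per instruction (e.g. WarpGemmAttributeMfmaImplF16F16F32M4N64K4) by wrapping it in WarpGemmAtrributeMfmaIterateK<Impl, 4>. The sketch below reduces that iterate-K composition to scalar arithmetic; BaseOpK4 and IterateK are toy stand-ins for the MFMA intrinsic and the ck_tile wrapper, not real ck_tile symbols.

#include <array>
#include <cstdio>

// Toy base op: consumes kK elements of a and b per call, standing in for a
// single MFMA instruction such as the M4N64K4 impl referenced above.
struct BaseOpK4
{
    static constexpr int kK = 4;

    static void run(float& c, const float* a, const float* b)
    {
        for(int k = 0; k < kK; ++k)
            c += a[k] * b[k]; // c_vec += a_vec * b_vec, one K-step
    }
};

// Toy analogue of WarpGemmAtrributeMfmaIterateK: repeat the base op kKIter
// times over consecutive K-slices, presenting K = Impl::kK * kKIter.
template <typename Impl, int kKIter>
struct IterateK
{
    static constexpr int kK = Impl::kK * kKIter;

    static void run(float& c, const float* a, const float* b)
    {
        for(int i = 0; i < kKIter; ++i)
            Impl::run(c, a + i * Impl::kK, b + i * Impl::kK);
    }
};

int main()
{
    using OpK16 = IterateK<BaseOpK4, 4>; // same shape idea as ...M4N64K16
    std::array<float, OpK16::kK> a{}, b{};
    a.fill(1.0f);
    b.fill(2.0f);
    float c = 0.0f;
    OpK16::run(c, a.data(), b.data());
    std::printf("c = %f\n", c); // 32.0 = 16 * 1 * 2
    return 0;
}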
@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

+    static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
+                  "Multi-block on both M & N directions is not supported");
+
-    using AWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using BWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using CWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
-              sequence<Impl::kCNLane>>,
-        tuple<sequence<1, 2>>,
-        tuple<sequence<1, 0>>,
-        sequence<1, 1>,
-        sequence<0, 2>>;
+    CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // each M block shares the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kBNBlock>,
+                tuple<sequence<Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // single-block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // single-block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // each N block shares the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kAMBlock>,
+                tuple<sequence<Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
+                      sequence<Impl::kCNLane>>,
+                tuple<sequence<1, 2>>,
+                tuple<sequence<1, 0>>,
+                sequence<1, 1>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
+                      sequence<Impl::kBNBlock * Impl::kCNLane>>,
+                tuple<sequence<1, 2>>,
+                tuple<sequence<1, 0>>,
+                sequence<1, 1>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<
+                    sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
+                    sequence<Impl::kCNLane>>,
+                tuple<sequence<1, 2>>,
+                tuple<sequence<1, 0>>,
+                sequence<1, 1>,
+                sequence<0, 2>>{};
+        }
+    }
+
+    using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
+    using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
+    using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());

     // c_vec += a_vec * b_vec
     template <bool post_nop_ = false>
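The hunk above replaces the fixed AWarpDstrEncoding/BWarpDstrEncoding/CWarpDstrEncoding aliases with getters whose if constexpr branches return different tile_distribution_encoding instantiations, then recovers the selected type with decltype. A minimal sketch of this select-a-type-at-compile-time idiom follows; Attr and the tag types are toy stand-ins for the real encodings.

#include <type_traits>

template <int MBlock, int NBlock>
struct Attr
{
    struct SingleBlock {};
    struct MultiN {};
    struct MultiM {};

    // Discarded if-constexpr branches are never instantiated, so each branch
    // may return a distinct type; auto deduces whichever branch survives.
    static constexpr auto get_encoding()
    {
        if constexpr(MBlock == 1 && NBlock == 1)
            return SingleBlock{};
        else if constexpr(MBlock == 1 && 1 < NBlock)
            return MultiN{};
        else
            return MultiM{}; // MBlock > 1 && NBlock > 1 is ruled out by the
                             // static_assert in the real class
    }

    using Encoding = decltype(get_encoding());
};

static_assert(std::is_same_v<Attr<1, 1>::Encoding, Attr<1, 1>::SingleBlock>);
static_assert(std::is_same_v<Attr<1, 4>::Encoding, Attr<1, 4>::MultiN>);
static_assert(std::is_same_v<Attr<4, 1>::Encoding, Attr<4, 1>::MultiM>);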
@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

+    static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
+                  "Multi-block on both M & N directions is not supported");
+
-    using AWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using BWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<0, 0>>,
-        sequence<2>,
-        sequence<1>>;
-
-    using CWarpDstrEncoding = tile_distribution_encoding<
-        sequence<>,
-        tuple<sequence<Impl::kCNLane>,
-              sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
-        tuple<sequence<2, 1>>,
-        tuple<sequence<1, 0>>,
-        sequence<2, 2>,
-        sequence<0, 2>>;
+    CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // single-block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // each N block shares the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kAMBlock>,
+                tuple<sequence<Impl::kBNLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            // each M block shares the same data
+            return tile_distribution_encoding<
+                sequence<Impl::kBNBlock>,
+                tuple<sequence<Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<0, 2, 1>>,
+                tuple<sequence<0, 0, 0>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            // single-block to multi-block thread mapping
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
+                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
+                tuple<sequence<1, 2, 1>>,
+                tuple<sequence<0, 0, 1>>,
+                sequence<2>,
+                sequence<1>>{};
+        }
+    }
+
+    CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
+    {
+        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kCNLane>,
+                      sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<1, 0>>,
+                sequence<2, 2>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
+                      sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<1, 0>>,
+                sequence<2, 2>,
+                sequence<0, 2>>{};
+        }
+        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
+        {
+            return tile_distribution_encoding<
+                sequence<>,
+                tuple<
+                    sequence<Impl::kCNLane>,
+                    sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>>,
+                tuple<sequence<2, 1>>,
+                tuple<sequence<1, 0>>,
+                sequence<2, 2>,
+                sequence<0, 2>>{};
+        }
+    }
+
+    using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
+    using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
+    using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());

     // c_vec += a_vec * b_vec
     template <bool post_nop_ = false>
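In the multi-block branches above, the operand that does not span the extra blocks carries a leading replication dimension (sequence<Impl::kAMBlock> or sequence<Impl::kBNBlock>), so every block along that direction reads identical data, while the other operand folds the block count into its lane dimension. Reduced to plain index arithmetic, the sharing looks roughly like the toy model below; the kAMBlock/kAMLane values and the a_row/b_col helpers are assumptions for illustration, not ck_tile coordinate transforms.

#include <cassert>

constexpr int kAMBlock = 2; // assumed multi-block count along M
constexpr int kAMLane  = 4; // assumed lanes per block along M

// A folds the block id into its row: each M-block reads distinct rows.
constexpr int a_row(int m_block, int m_lane) { return m_block * kAMLane + m_lane; }

// B ignores the block id: "each M block shares the same data".
constexpr int b_col(int /*m_block*/, int n_lane) { return n_lane; }

int main()
{
    static_assert(a_row(1, 3) == 7);           // second M-block, distinct rows
    static_assert(b_col(0, 5) == b_col(1, 5)); // B replicated across M-blocks
    assert(a_row(0, 0) != a_row(1, 0));
    return 0;
}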
@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

+    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
+                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
+
     using AWarpDstrEncoding = tile_distribution_encoding<
         sequence<>,
         tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
...