Commit 019d4b7c authored by illsilin

merge from public repo

parents 07307ea1 5063a39f
@@ -59,14 +59,6 @@ struct BlockGemmARegBSmemCRegV2
        const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
@@ -75,19 +67,14 @@ struct BlockGemmARegBSmemCRegV2
            sequence<1, 2>,
            sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
        // construct A-block-tensor from A-block-tensor-tmp
        // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
        // distribution
        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
            MakeABlockTileDistribution());

        a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
@@ -182,6 +169,33 @@ struct BlockGemmARegBSmemCRegV2
        });
    }
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
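    // Illustrative arithmetic only (the concrete block and warp sizes below are assumed, not
    // taken from this file): with MPerBlock = 128, KPerBlock = 32, MWarp = 2 and a 32x32x8
    // WarpGemm (WG::kM = 32, WG::kK = 8), MakeABlockTileDistribution() above covers
    //   MIterPerWarp = 128 / (2 * 32) = 2 outer M iterations per warp and
    //   KIterPerWarp = 32 / 8 = 4 outer K iterations per warp.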
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
......
@@ -3,90 +3,95 @@
#pragma once

#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
    CK_TILE_HOST BatchedGemmHostArgs() = default;
    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* c_ptr_,
                                     ck_tile::index_t k_batch_,
                                     ck_tile::index_t M_,
                                     ck_tile::index_t N_,
                                     ck_tile::index_t K_,
                                     ck_tile::index_t stride_A_,
                                     ck_tile::index_t stride_B_,
                                     ck_tile::index_t stride_C_,
                                     ck_tile::index_t batch_stride_A_,
                                     ck_tile::index_t batch_stride_B_,
                                     ck_tile::index_t batch_stride_C_,
                                     ck_tile::index_t batch_count_)
        : GemmHostArgs(
              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
          batch_stride_A(batch_stride_A_),
          batch_stride_B(batch_stride_B_),
          batch_stride_C(batch_stride_C_),
          batch_count(batch_count_)
    {
    }

    ck_tile::index_t batch_stride_A;
    ck_tile::index_t batch_stride_B;
    ck_tile::index_t batch_stride_C;
    ck_tile::index_t batch_count;
};
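// Host-side usage sketch (illustrative; `Kernel` stands for a concrete BatchedGemmKernel
// instantiation and every pointer, size and stride below is assumed):
//
//   BatchedGemmHostArgs args(a_dev, b_dev, c_dev, /*k_batch*/ 1, M, N, K,
//                            stride_A, stride_B, stride_C,
//                            batch_stride_A, batch_stride_B, batch_stride_C, batch_count);
//   auto kargs  = Kernel::MakeKernelArgs(args);
//   dim3 grids  = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
//   dim3 blocks = Kernel::BlockSize();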
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using GemmKernelArgs = typename Base::GemmKernelArgs;

    using ADataType = typename Base::ADataType;
    using BDataType = typename Base::BDataType;
    using CDataType = typename Base::CDataType;

    using TilePartitioner  = typename Base::TilePartitioner;
    using GemmPipeline     = typename Base::GemmPipeline;
    using EpiloguePipeline = typename Base::EpiloguePipeline;
    using ALayout          = typename Base::ALayout;
    using BLayout          = typename Base::BLayout;
    using CLayout          = typename Base::CLayout;

    struct BatchedGemmKernelArgs : GemmKernelArgs
    {
        index_t batch_stride_A;
        index_t batch_stride_B;
        index_t batch_stride_C;
        index_t batch_count;
    };

    using KernelArgs = BatchedGemmKernelArgs;

    __host__ static constexpr auto
    GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
    {
        return TilePartitioner::GridSize(M, N, KBatch * batch_count);
    }

    __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }

    CK_TILE_HOST static constexpr BatchedGemmKernelArgs
    MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
    {
        return BatchedGemmKernelArgs{{hostArgs.a_ptr,
                                      hostArgs.b_ptr,
                                      hostArgs.c_ptr,
                                      hostArgs.M,
                                      hostArgs.N,
                                      hostArgs.K,
                                      hostArgs.stride_A,
                                      hostArgs.stride_B,
                                      hostArgs.stride_C,
                                      hostArgs.k_batch},
                                     hostArgs.batch_stride_A,
                                     hostArgs.batch_stride_B,
                                     hostArgs.batch_stride_C,
                                     hostArgs.batch_count};
    }
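    // Launch-geometry note: GridSize() above folds the split-K factor into the Z dimension, so
    // each Z block owns one (batch, k-split) pair and operator() below recovers the two indices
    // as i_batch = blockIdx.z / KBatch and i_k = blockIdx.z % KBatch.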
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -94,164 +99,41 @@ struct BatchedGemmKernel
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
    {
        const auto [i_m, i_n] = TilePartitioner{}();
        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
        const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);

        // options
        const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
        const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
                                 splitk_batch_offset.a_k_split_offset;

        const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
        const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
                                 splitk_batch_offset.b_k_split_offset;
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
}();
auto a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
// clang-format on
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
        const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
        const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        auto c_pad_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
        EpiloguePipeline{}(c_block_window, c_block_tile);

        if(kargs.KBatch == 1)
{
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
else
{
this->template RunGemm<memory_operation_enum::atomic_add>(
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
    }
};
......
@@ -12,6 +12,50 @@
namespace ck_tile {
struct GemmProblem
{
CK_TILE_HOST GemmProblem() = default;
CK_TILE_HOST GemmProblem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
struct GemmHostArgs : public GemmProblem
{
CK_TILE_HOST GemmHostArgs() = default;
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_)
: GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t k_batch;
};
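// Host-side usage sketch (illustrative; pointers, sizes and strides are assumed):
//
//   GemmHostArgs args(a_dev, b_dev, c_dev, /*k_batch*/ 2, M, N, K,
//                     stride_A, stride_B, stride_C);
//   auto kargs = GemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>::MakeKernelArgs(args);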
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
@@ -25,9 +69,12 @@ struct GemmKernel
    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return TilePartitioner::GridSize(M, N, KBatch);
@@ -35,7 +82,7 @@ struct GemmKernel
    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
    struct GemmKernelArgs
    {
        const void* a_ptr;
        const void* b_ptr;
@@ -46,19 +93,21 @@ struct GemmKernel
        index_t stride_A;
        index_t stride_B;
        index_t stride_C;
index_t KBatch;
    };
    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
    {
        return GemmKernelArgs{hostArgs.a_ptr,
                              hostArgs.b_ptr,
                              hostArgs.c_ptr,
                              hostArgs.M,
                              hostArgs.N,
                              hostArgs.K,
                              hostArgs.stride_A,
                              hostArgs.stride_B,
                              hostArgs.stride_C,
                              hostArgs.k_batch};
    }
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -66,8 +115,63 @@ struct GemmKernel
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
    struct SplitKBatchOffset
    {
__device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.KBatch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * KRead;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * KRead * kargs.stride_A;
}
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * KRead * kargs.stride_B;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = k_id * KRead;
}
if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
{
splitted_k = KRead;
}
else
{
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k;
};
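    // Worked example for the offsets above (numbers assumed): with K = 4000, KBatch = 4 and a
    // warp-tile K of K1 = 16,
    //   K_t   = 4 * 16 = 64,
    //   KRead = ceil(4000 / 64) * 16 = 63 * 16 = 1008,
    // so splits k_id = 0..2 each cover splitted_k = 1008 elements of K and the last split covers
    // 4000 - 3 * 1008 = 976.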
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
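        // The check below gates split-K: with KBatch > 1 the epilogue must accumulate C with
        // atomic adds, so fp16/bf16 outputs are rejected unless C is row-major, VectorSizeC is
        // even and the register output is transposed (presumably so that two 16-bit values can
        // be packed into a single 32-bit atomic operation).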
constexpr bool is_output_c_reg_transposed =
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed) ||
!(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
{
if(kargs.KBatch != 1)
{
return false;
}
}
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
@@ -139,19 +243,19 @@ struct GemmKernel
        return true;
    }
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_ptr,
                                                   CDataType* c_ptr,
                                                   const GemmKernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        const auto& a_tensor_view = [&]() {
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_start, a_ptr,
make_tuple(kargs.M, kargs.K), make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1), make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{}, number<GemmPipeline::VectorSizeA>{},
number<1>{}); number<1>{});
...@@ -159,20 +263,20 @@ struct GemmKernel ...@@ -159,20 +263,20 @@ struct GemmKernel
else else
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_start, a_ptr,
make_tuple(kargs.M, kargs.K), make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(1, kargs.stride_A), make_tuple(1, kargs.stride_A),
number<1>{}, number<1>{},
number<1>{}); number<1>{});
} }
}(); }();
auto b_tensor_view = [&]() { const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_start, b_ptr,
make_tuple(kargs.N, kargs.K), make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(1, kargs.stride_B), make_tuple(1, kargs.stride_B),
number<1>{}, number<1>{},
number<1>{}); number<1>{});
...@@ -180,15 +284,43 @@ struct GemmKernel ...@@ -180,15 +284,43 @@ struct GemmKernel
else else
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_start, b_ptr,
make_tuple(kargs.N, kargs.K), make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1), make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{}, number<GemmPipeline::VectorSizeB>{},
number<1>{}); number<1>{});
} }
}(); }();
auto a_pad_view = [&]() { const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -204,14 +336,9 @@ struct GemmKernel ...@@ -204,14 +336,9 @@ struct GemmKernel
sequence<GemmPipeline::kPadM, false>{}); sequence<GemmPipeline::kPadM, false>{});
} }
}(); }();
// clang-format on
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = [&]() { const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -228,43 +355,8 @@ struct GemmKernel ...@@ -228,43 +355,8 @@ struct GemmKernel
} }
}(); }();
auto b_block_window = make_tile_window( const auto& c_pad_view = [&]() {
b_pad_view, const auto& c_tensor_view = views.at(I2);
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -280,12 +372,111 @@ struct GemmKernel ...@@ -280,12 +372,111 @@ struct GemmKernel
sequence<GemmPipeline::kPadM, false>{}); sequence<GemmPipeline::kPadM, false>{});
} }
}(); }();
auto CBlockWindow_pad = make_tile_window(
return make_tuple(a_pad_view, b_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
const auto& b_pad_view = views.at(I1);
const auto& b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
const auto& c_pad_view = views.at(I2);
auto c_block_window = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile); return make_tuple(a_block_window, b_block_window, c_block_window);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr,
const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
constexpr bool is_output_c_reg_transposed =
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
(GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed))
{
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile);
}
}
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
if(kargs.KBatch == 1)
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
else
{
RunGemm<memory_operation_enum::atomic_add>(
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
} }
}; };
......
@@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
        return Policy::template GetSmemSize<Problem>();
    }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
......
@@ -104,9 +104,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
    using CLayout = remove_cvref_t<typename Problem::CLayout>;
    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
@@ -132,6 +133,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
        return Policy::template GetSmemSize<Problem>();
    }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
......
@@ -23,6 +23,8 @@ struct GemmPipelineAGmemBGmemCRegV1
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
    static constexpr index_t BlockSize = Problem::kBlockSize;
    static constexpr index_t kMPerBlock = BlockGemmShape::kM;
@@ -53,6 +55,8 @@ struct GemmPipelineAGmemBGmemCRegV1
        return Policy::template GetSmemSize<Problem>();
    }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
@@ -124,7 +128,7 @@ struct GemmPipelineAGmemBGmemCRegV1
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        auto block_gemm = BlockGemm();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
......
@@ -12,6 +12,11 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true;
#if 0
    // 2d
@@ -114,8 +119,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    {
        constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
        constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
        constexpr index_t smem_size = smem_size_a + smem_size_b;

        return smem_size;
    }
@@ -485,14 +489,11 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
        }
    }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
constexpr bool TransposeC = false;
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
        using AccDataType = float;
        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile = typename Problem::BlockGemmShape::WarpTile;
......
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2
            Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
......
@@ -11,7 +11,6 @@ namespace ck_tile {
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
{ {
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
    static constexpr auto I2 = number<2>{};
@@ -444,6 +443,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
        }
    }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
......
@@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
    2>>;
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
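// The two aliases above compose four issues of the 4x4x4 MFMA (added further below) into a K16
// warp GEMM; that hardware instruction computes 16 independent 4x4 blocks per wavefront, which
// the new implementations lay out along N (M4N64) or along M (M64N4).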
// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
@@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
    2>>;
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp8 // fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
......
...@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma ...@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK ...@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding< static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
sequence<>, "Multi-block on both M & N directions is not supported");
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
tuple<sequence<2, 1>>, {
tuple<sequence<0, 0>>, return tile_distribution_encoding<
sequence<2>, sequence<>,
sequence<1>>; tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
}
using CWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
sequence<Impl::kCNLane>>, {
tuple<sequence<1, 2>>, return tile_distribution_encoding<
tuple<sequence<1, 0>>, sequence<>,
sequence<1, 1>, tuple<sequence<Impl::kBNLane>,
sequence<0, 2>>; sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kBNBlock * Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec // c_vec += a_vec * b_vec
template <bool post_nop_ = false> template <bool post_nop_ = false>
...@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution ...@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB ...@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
...@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution ...@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding< static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
sequence<>, "Multi-block on both M & N directions is not supported");
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
tuple<sequence<2, 1>>, {
tuple<sequence<0, 0>>, return tile_distribution_encoding<
sequence<2>, sequence<>,
sequence<1>>; tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
}
using CWarpDstrEncoding = tile_distribution_encoding< CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
sequence<>, {
tuple<sequence<Impl::kCNLane>, if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>, {
tuple<sequence<2, 1>>, return tile_distribution_encoding<
tuple<sequence<1, 0>>, sequence<>,
sequence<2, 2>, tuple<sequence<Impl::kAMLane>,
sequence<0, 2>>; sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
template <bool post_nop_ = false> template <bool post_nop_ = false>
// c_vec += a_vec * b_vec // c_vec += a_vec * b_vec
...@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB ...@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
...@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA ...@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane), tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
......
@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
    static constexpr index_t kN = 32;
    static constexpr index_t kK = 8;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 32;
    static constexpr index_t kBNLane = 32;
    static constexpr index_t kABKLane = 2;
@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
    static constexpr index_t kN = 16;
    static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16;
    static constexpr index_t kBNLane = 16;
    static constexpr index_t kABKLane = 4;
@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
    }
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 4;
static constexpr index_t kN = 64;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 16;
// we only write down single block (4 threads) thread mapping here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
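// Shape bookkeeping for the implementation above, derived from the constants it declares:
// kM x kN = (kAMBlock * kAMLane) x (kBNBlock * kBNLane) = (1 * 4) x (16 * 4) = 4 x 64,
// and kK = kABKLane * kABKPerLane = 1 * 4 = 4 per MFMA issue.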
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 64;
static constexpr index_t kN = 4;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 16;
static constexpr index_t kBNBlock = 1;
// we only write down single block (4 threads) thread mapping here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// Bf16 // Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static constexpr index_t kN = 32; static constexpr index_t kN = 32;
static constexpr index_t kK = 8; static constexpr index_t kK = 8;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32; static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32; static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2; static constexpr index_t kABKLane = 2;
...@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static constexpr index_t kN = 16; static constexpr index_t kN = 16;
static constexpr index_t kK = 16; static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16; static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16; static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4; static constexpr index_t kABKLane = 4;
...@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
} }
}; };
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 4;
static constexpr index_t kN = 64;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 16;
// we only write down the thread mapping of a single block (4 threads) here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 64;
static constexpr index_t kN = 4;
static constexpr index_t kK = 4;
static constexpr index_t kAMBlock = 16;
static constexpr index_t kBNBlock = 1;
// we only write down the thread mapping of a single block (4 threads) here
static constexpr index_t kAMLane = 4;
static constexpr index_t kBNLane = 4;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABKPerLane = 4;
static constexpr index_t kCMLane = 1;
static constexpr index_t kCNLane = 4;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
return CVecType{0.f};
#endif
}
};
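// Illustrative sketch (not part of the original header): per 4-lane block, the
// v_mfma_*_4x4x4* instructions used above accumulate a 4x4 C tile from a 4(M)x4(K)
// A fragment and a 4(K)x4(N) B fragment. A scalar reference of that block-level
// math, assuming plain row-major fragments:
//
//   void mfma_4x4x4_reference(float c[4][4], const float a[4][4], const float b[4][4])
//   {
//       for(int m = 0; m < 4; ++m)
//           for(int n = 0; n < 4; ++n)
//               for(int k = 0; k < 4; ++k)
//                   c[m][n] += a[m][k] * b[k][n];
//   }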
// FP8 // FP8
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static constexpr index_t kN = 32; static constexpr index_t kN = 32;
static constexpr index_t kK = 16; static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32; static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32; static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2; static constexpr index_t kABKLane = 2;
...@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 ...@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
static constexpr index_t kN = 32; static constexpr index_t kN = 32;
static constexpr index_t kK = 16; static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32; static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32; static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2; static constexpr index_t kABKLane = 2;
......
...@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float ...@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
...@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float ...@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
......
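The two dispatcher entries added per data type above expose the new skinny 4x64 / 64x4 warp tiles through the same lookup used for the existing shapes. A minimal usage sketch; the non-type parameter labels are inferred from the pattern of the surrounding specializations:

// Resolves to WarpGemmMfmaF16F16F32M64N4K16 via the specialization added above.
using WarpGemm64x4 = typename ck_tile::WarpGemmMfmaDispatcher<
    ck_tile::half_t, ck_tile::half_t, float,
    /*M*/ 64, /*N*/ 4, /*K*/ 16, /*transposed C*/ false>::Type;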
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,7 +14,8 @@ struct Layernorm2dFwdHostArgs ...@@ -14,7 +14,8 @@ struct Layernorm2dFwdHostArgs
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -42,15 +43,16 @@ struct Layernorm2dFwd ...@@ -42,15 +43,16 @@ struct Layernorm2dFwd
using Epilogue = remove_cvref_t<Epilogue_>; using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem; using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>; using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>; using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>; using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>; using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>; using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>; using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>; using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>; using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>; using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
// for simplicity, shortcut input/output type is same as X // for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType; using XResidualDataType = XDataType;
...@@ -67,6 +69,7 @@ struct Layernorm2dFwd ...@@ -67,6 +69,7 @@ struct Layernorm2dFwd
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass; static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -81,7 +84,8 @@ struct Layernorm2dFwd ...@@ -81,7 +84,8 @@ struct Layernorm2dFwd
{ {
const void* p_x; // [m ,n], input, fp16/bf16 const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_x_bias; // [1, n], bias, prec same as input
const void* p_gamma; // [1, n], gamma, prec same as input const void* p_gamma; // [1, n], gamma, prec same as input
const void* p_beta; // [1, n], beta, prec same as input const void* p_beta; // [1, n], beta, prec same as input
...@@ -107,7 +111,8 @@ struct Layernorm2dFwd ...@@ -107,7 +111,8 @@ struct Layernorm2dFwd
{ {
return Kargs{hargs.p_x, return Kargs{hargs.p_x,
hargs.p_x_residual, hargs.p_x_residual,
hargs.p_x_scale, hargs.p_sm_scale,
hargs.p_x_bias,
hargs.p_gamma, hargs.p_gamma,
hargs.p_beta, hargs.p_beta,
hargs.p_y, hargs.p_y,
...@@ -152,6 +157,7 @@ struct Layernorm2dFwd ...@@ -152,6 +157,7 @@ struct Layernorm2dFwd
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
auto surfix = [&] () { auto surfix = [&] () {
std::string n; std::string n;
if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name; if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name; if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn"; if (kPadN) n += "_pn";
...@@ -165,7 +171,7 @@ struct Layernorm2dFwd ...@@ -165,7 +171,7 @@ struct Layernorm2dFwd
base_str += _SS_("_") + _SS_(t2s<YDataType>::name); base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name); base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name); base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
} }
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) { if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
...@@ -228,6 +234,27 @@ struct Layernorm2dFwd ...@@ -228,6 +234,27 @@ struct Layernorm2dFwd
} }
}(); }();
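// x_bias is a [1, n] row vector; when kXbias is not ADD_BIAS the lambda below
// returns a null tile window of the same Block_N shape, so the pipeline is
// invoked with the same argument list whether or not bias is enabled.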
const auto x_bias_window = [&]() {
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XBiasDataType*>(kargs.p_x_bias),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
const auto gamma_window = [&]() { const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
...@@ -329,18 +356,18 @@ struct Layernorm2dFwd ...@@ -329,18 +356,18 @@ struct Layernorm2dFwd
return make_null_tile_window(make_tuple(number<Block_M>{})); return make_null_tile_window(make_tuple(number<Block_M>{}));
}(); }();
auto x_scale_window = [&]() { auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
const auto win_ = [&]() { const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>( const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_x_scale), static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n), make_tuple(kargs.n),
number<Vector_N>{}); number<Vector_N>{});
return pad_tensor_view(tmp_0_, return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}), make_tuple(number<Block_N>{}),
sequence<false>{}); // x_scale no need pad sequence<false>{}); // sm_scale no need pad
}(); }();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0}); return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
} }
...@@ -371,13 +398,14 @@ struct Layernorm2dFwd ...@@ -371,13 +398,14 @@ struct Layernorm2dFwd
Pipeline{}(x_window, Pipeline{}(x_window,
x_residual_window, x_residual_window,
x_bias_window,
gamma_window, gamma_window,
beta_window, beta_window,
y_window, y_window,
y_residual_window, y_residual_window,
mean_window, mean_window,
inv_std_window, inv_std_window,
x_scale_window, sm_scale_window,
y_scale_window, y_scale_window,
static_cast<const ComputeDataType>(kargs.epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n, kargs.n,
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford() CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelford<P_>{}; return BlockNormReduce<P_>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelfordSync<P_>{}; return BlockNormReduceSync<P_>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync() CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
return BlockWelfordCrossWarpSync<P_>{}; return BlockNormReduceCrossWarpSync<P_>{};
} }
template <typename Problem> template <typename Problem>
...@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{ {
if constexpr(Problem::kNeedCrossWarpSync) if constexpr(Problem::kNeedCrossWarpSync)
{ {
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
Problem::Traits::kFastFDiv>; Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>;
using block_welford = BlockWelford<P_>; using block_welford = BlockNormReduce<P_>;
using x_block_tile = using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>( decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>())); MakeXBlockTileDistribution<Problem>()));
using mean_var_block_tile = using mean_var_block_tile =
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>()); decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
return GetBlockWelfordCrossWarpSync<Problem>() return GetBlockNormReduceCrossWarpSync<Problem>()
.template GetSmemSize<mean_var_block_tile>(); .template GetSmemSize<mean_var_block_tile>();
} }
else else
......
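The three reducers returned by this policy are meant to be chained; the layernorm pipelines later in this commit drive them in the following order (a condensed sketch taken from the one-pass pipeline below):

auto block_norm_reduce                 = Policy::template GetBlockNormReduce<Problem>();
auto block_norm_reduce_sync            = Policy::template GetBlockNormReduceSync<Problem>();
auto block_norm_reduce_cross_warp_sync = Policy::template GetBlockNormReduceCrossWarpSync<Problem>();

block_norm_reduce(acc, mean, var, cur_count, max_count);       // per-thread partial statistics
block_norm_reduce_sync(mean, var, cur_count);                  // combine across lanes within a warp
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem); // combine across warps through shared memory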
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -37,6 +38,8 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -37,6 +38,8 @@ struct Layernorm2dFwdPipelineOnePass
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -54,24 +57,26 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -54,24 +57,26 @@ struct Layernorm2dFwdPipelineOnePass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window_, YWindow& y_window_,
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_, const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window, YScaleWindow& y_scale_window,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
...@@ -80,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -80,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass
{ {
const auto x_window = const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window( const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window( const auto beta_window = make_tile_window(
...@@ -89,23 +96,38 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -89,23 +96,38 @@ struct Layernorm2dFwdPipelineOnePass
auto y_residual_window = make_tile_window( auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>()); y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
int cur_count = 0; int cur_count = 0;
int max_count = int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size); block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
auto block_welford = Policy::template GetBlockWelford<Problem>(); auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>(); auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_welford_cross_warp_sync = auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>(); Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(x));
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
clear_tile(mean);
clear_tile(var);
// load gamma/beta (TODO: support no gamma/beta?) // load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window); const auto beta = load_tile(beta_window);
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
...@@ -117,12 +139,21 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -117,12 +139,21 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc)); store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
} }
// compute welford each-thread->cross-lane->cross-warp // compute reduce each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(acc, cur_count, max_count); block_norm_reduce(acc, mean, var, cur_count, max_count);
block_welford_sync(mean, var, cur_count); block_norm_reduce_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem); block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{}); if(kWelford)
{
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
}
else
{
sweep_tile(mean, [&](auto idx) {
mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
});
}
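// Note on the two statistics paths above: the Welford path only rescales the
// accumulated variance by the element count, while the plain-sum path divides
// the accumulated sums by the row size and forms the moments
//   mean = sum(x) / N,   var = sum(x^2) / N - mean * mean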
// compute inv-std // compute inv-std
auto inv_std = tile_elementwise_in( auto inv_std = tile_elementwise_in(
[&](const auto& v_) { [&](const auto& v_) {
...@@ -153,14 +184,13 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -153,14 +184,13 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]); const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
ln(idx) = ln_;
ln(idx) = ln_;
}); });
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{ {
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem); Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
} }
else else
Epilogue{}(y_window_, ln); Epilogue{}(y_window_, ln);
......
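For reference, a minimal host-side sketch of the per-row math the one-pass pipeline above implements when kXbias == ADD_BIAS and kFusedAdd == PRE_ADD, assuming the usual inv_std = 1 / sqrt(var + epsilon); the function and variable names are illustrative, not part of the library:

#include <cmath>
#include <vector>

std::vector<float> layernorm_row_reference(std::vector<float> x,
                                           const std::vector<float>& bias,
                                           const std::vector<float>& residual,
                                           const std::vector<float>& gamma,
                                           const std::vector<float>& beta,
                                           float epsilon)
{
    const int n = static_cast<int>(x.size());
    for(int j = 0; j < n; ++j)
        x[j] += bias[j] + residual[j]; // fused bias add + pre-add residual

    float sum = 0.f, sum_sq = 0.f;
    for(int j = 0; j < n; ++j)
    {
        sum    += x[j];
        sum_sq += x[j] * x[j];
    }
    const float mean    = sum / n;
    const float var     = sum_sq / n - mean * mean;
    const float inv_std = 1.f / std::sqrt(var + epsilon);

    std::vector<float> y(n);
    for(int j = 0; j < n; ++j)
        y[j] = (x[j] - mean) * inv_std * gamma[j] + beta[j];
    return y;
}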
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,28 +8,30 @@ ...@@ -8,28 +8,30 @@
namespace ck_tile { namespace ck_tile {
template <typename XDataType_, template <typename XDataType_,
typename XBiasDataType_,
typename GammaDataType_, typename GammaDataType_,
typename BetaDataType_, typename BetaDataType_,
typename ComputeDataType_, typename ComputeDataType_,
typename YDataType_, typename YDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename XScaleDataType_, typename SmoothScaleDataType_,
typename YScaleDataType_, typename YScaleDataType_,
typename BlockShape_, typename BlockShape_,
typename Traits_> typename Traits_>
struct Layernorm2dFwdPipelineProblem struct Layernorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using XBiasDataType = remove_cvref_t<XBiasDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>; using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
......
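With the added XBiasDataType_ and the renamed SmoothScaleDataType_, the problem type now takes twelve template arguments in the order shown above. A hypothetical fp16 instantiation; MyBlockShape and MyTraits are placeholder names, and the mean/inv-std/scale types are only illustrative:

using MyProblem = ck_tile::Layernorm2dFwdPipelineProblem<
    ck_tile::half_t, // XDataType
    ck_tile::half_t, // XBiasDataType (bias, same precision as input)
    ck_tile::half_t, // GammaDataType
    ck_tile::half_t, // BetaDataType
    float,           // ComputeDataType
    ck_tile::half_t, // YDataType
    float,           // MeanDataType
    float,           // InvStdDataType
    float,           // SmoothScaleDataType (smooth scale is fp32)
    float,           // YScaleDataType
    MyBlockShape,    // BlockShape (placeholder)
    MyTraits>;       // Traits     (placeholder)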
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>; using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>; using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
...@@ -36,6 +37,8 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -36,6 +37,8 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
static constexpr bool kWelford = Problem::Traits::kWelford;
static constexpr auto kXbias = Problem::Traits::kXbias;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...@@ -53,32 +56,37 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -53,32 +56,37 @@ struct Layernorm2dFwdPipelineTwoPass
template <typename XWindow, template <typename XWindow,
typename XResidualWindow, typename XResidualWindow,
typename XBiasWindow,
typename GammaWindow, typename GammaWindow,
typename BetaWindow, typename BetaWindow,
typename YWindow, typename YWindow,
typename YResidualWindow, typename YResidualWindow,
typename MeanWindow, typename MeanWindow,
typename InvStdWindow, typename InvStdWindow,
typename XScaleWindow, typename SmoothScaleWindow,
typename YScaleWindow, typename YScaleWindow,
typename Epilogue> typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_, const XResidualWindow& x_residual_window_,
const XBiasWindow& x_bias_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
const BetaWindow& beta_window_, const BetaWindow& beta_window_,
YWindow& y_window, YWindow& y_window,
const YResidualWindow& y_residual_window_, const YResidualWindow& y_residual_window_,
MeanWindow& mean_window, MeanWindow& mean_window,
InvStdWindow& inv_std_window, InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/, const SmoothScaleWindow& /*sm_scale_window*/,
YScaleWindow& /*y_scale_window*/, YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem, void* smem,
Epilogue) const Epilogue) const
{ {
static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window = auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>()); make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto x_bias_window = make_tile_window(
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window( auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>()); gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window( auto beta_window = make_tile_window(
...@@ -102,24 +110,35 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -102,24 +110,35 @@ struct Layernorm2dFwdPipelineTwoPass
int max_count = int max_count =
(num_n_tile_iteration - 1) * count_per_iter + (num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n); block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_welford = Policy::template GetBlockWelford<Problem>(); auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>(); auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
auto block_welford_cross_warp_sync = auto block_norm_reduce_cross_warp_sync =
Policy::template GetBlockWelfordCrossWarpSync<Problem>(); Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window))); using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>(); auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>(); auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
const auto x_bias = load_tile(x_bias_window);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N}); move_tile_window(x_residual_window, {0, Block_N});
move_tile_window(x_bias_window, {Block_N});
auto acc = cast_tile<ComputeDataType>(x); auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{ {
...@@ -133,11 +152,11 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -133,11 +152,11 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(y_residual_window, {0, Block_N}); move_tile_window(y_residual_window, {0, Block_N});
} }
} }
block_welford(acc, mean, var, cur_count, max_count); block_norm_reduce(acc, mean, var, cur_count, max_count);
} }
block_welford_sync(mean, var, cur_count); block_norm_reduce_sync(mean, var, cur_count);
block_welford_cross_warp_sync(mean, var, cur_count, smem); block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{}); block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
// compute inv-std // compute inv-std
...@@ -165,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -165,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
...@@ -172,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -172,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation // layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x); const auto x_bias = load_tile(x_bias_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
{
sweep_tile(x, [&](auto idx) {
// compute x = bias + x
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
});
}
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
...@@ -207,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -207,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(x_bias_window, {-Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N}); move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
......