Unverified Commit 9533a172 authored by Illia Silin, committed by GitHub

Merge branch 'develop' into codegen-enable-hiprtc

parents c2cf0733 50ee4267
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
struct BatchedGemmHostArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
index_t batch_count;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
struct BatchedGemmKargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
index_t batch_count;
};
using Kargs = BatchedGemmKargs;
using Hargs = BatchedGemmHostArgs;
__host__ static constexpr auto GridSize(const Hargs& h)
{
return TilePartitioner::GridSize(h.M, h.N, h.batch_count);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h)
{
Kargs k;
k.a_ptr = h.a_ptr;
k.b_ptr = h.b_ptr;
k.c_ptr = h.c_ptr;
k.M = h.M;
k.N = h.N;
k.K = h.K;
k.stride_A = h.stride_A;
k.stride_B = h.stride_B;
k.stride_C = h.stride_C;
k.batch_stride_A = h.batch_stride_A;
k.batch_stride_B = h.batch_stride_B;
k.batch_stride_C = h.batch_stride_C;
k.batch_count = h.batch_count;
return k;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
// Compute per-batch offsets (kept uniform across the wavefront via readfirstlane).
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
}();
auto a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by the whole workgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(c_block_window, c_block_tile);
}
};
} // namespace ck_tile
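
For orientation, host code could drive this kernel roughly as in the sketch below. This is an assumption-laden illustration, not code from this commit: MyTilePartitioner, MyGemmPipeline, MyEpilogue, the device pointers, the sizes, and the hand-rolled __global__ entry point are all hypothetical stand-ins (the repository's example folder is the authoritative reference).

// Hypothetical usage sketch; the My* aliases are placeholders, not real types.
template <typename Kernel>
__global__ void batched_gemm_entry(typename Kernel::Kargs kargs)
{
    Kernel{}(kargs); // BatchedGemmKernel::operator() does all per-block work
}

void run_batched_gemm(const void* a_dev, const void* b_dev, void* c_dev, hipStream_t stream)
{
    using Kernel = ck_tile::BatchedGemmKernel<MyTilePartitioner, MyGemmPipeline, MyEpilogue>;

    // Illustrative sizes; strides are in elements, one batch_stride_* step per batch on z.
    ck_tile::BatchedGemmHostArgs args{a_dev, b_dev, c_dev,
                                      /*M=*/1024, /*N=*/1024, /*K=*/512,
                                      /*stride_A=*/512, /*stride_B=*/512, /*stride_C=*/1024,
                                      /*batch_stride_A=*/1024 * 512,
                                      /*batch_stride_B=*/1024 * 512,
                                      /*batch_stride_C=*/1024 * 1024,
                                      /*batch_count=*/8};

    const auto kargs = Kernel::MakeKargs(args); // field-by-field copy of the host args
    const auto grid  = Kernel::GridSize(args);  // M/N tile grid with batch_count on z
    const auto block = Kernel::BlockSize();     // dim3(GemmPipeline::BlockSize)
    batched_gemm_entry<Kernel><<<grid, block, 0, stream>>>(kargs); // LDS is static
}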
@@ -115,12 +115,22 @@ struct GemmKernel
}
}();
- auto a_pad_view = pad_tensor_view(
- a_tensor_view,
- make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
- // somehow clang-format is splitting below line into multiple.
- // clang-format off
- sequence<false, GemmPipeline::kPadA>{});
+ auto a_pad_view = [&]() {
+ if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
+ {
+ return pad_tensor_view(
+ a_tensor_view,
+ make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+ sequence<false, GemmPipeline::kPadK>{});
+ }
+ else
+ {
+ return pad_tensor_view(
+ a_tensor_view,
+ make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+ sequence<GemmPipeline::kPadM, false>{});
+ }
+ }();
// clang-format on
auto a_block_window = make_tile_window(
@@ -128,12 +138,22 @@ struct GemmKernel
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
- auto b_pad_view = pad_tensor_view(
- b_tensor_view,
- make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
- // clang-format off
- sequence<false, GemmPipeline::kPadB>{});
- // clang-format on
+ auto b_pad_view = [&]() {
+ if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
+ {
+ return pad_tensor_view(
+ b_tensor_view,
+ make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+ sequence<false, GemmPipeline::kPadK>{});
+ }
+ else
+ {
+ return pad_tensor_view(
+ b_tensor_view,
+ make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+ sequence<GemmPipeline::kPadN, false>{});
+ }
+ }();
auto b_block_window = make_tile_window(
b_pad_view,
@@ -171,18 +191,28 @@ struct GemmKernel
}
}();
- auto c_pad_view = pad_tensor_view(
- c_tensor_view,
- make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
- // clang-format off
- sequence<false, GemmPipeline::kPadC>{});
- // clang-format on
- auto c_block_window = make_tile_window(
+ auto c_pad_view = [&]() {
+ if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+ {
+ return pad_tensor_view(
+ c_tensor_view,
+ make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+ sequence<false, GemmPipeline::kPadN>{});
+ }
+ else
+ {
+ return pad_tensor_view(
+ c_tensor_view,
+ make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+ sequence<GemmPipeline::kPadM, false>{});
+ }
+ }();
+ auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
- EpiloguePipeline{}(c_block_window, c_block_tile);
+ EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
}
};
...
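
One way to read the layout-dependent pad masks in this hunk: each tensor view only needs a padding guard along its contiguous dimension, the one accessed with a vector size greater than one (K for row-major A, M for column-major A, and correspondingly for B and C). A toy, library-free sketch of that selection logic, included purely as an illustration:

#include <cstdio>
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

// For an M x K view of A: row-major is contiguous in K, so only K is a
// padding candidate (kPadK); column-major is contiguous in M (kPadM).
template <typename ALayout>
constexpr const char* a_pad_candidate()
{
    return std::is_same_v<ALayout, RowMajor> ? "K" : "M";
}

int main()
{
    std::printf("row-major A pads:    %s\n", a_pad_candidate<RowMajor>());    // K
    std::printf("column-major A pads: %s\n", a_pad_candidate<ColumnMajor>()); // M
}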
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmPipelineAgBgCrImplBase
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}
template <typename ADramBlockWindowTmp, typename ALdsTensorView>
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const
{
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}
template <typename BDramBlockWindowTmp, typename BLdsTensorView>
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const
{
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),
std::move(b_lds_gemm_window));
}
};
} // namespace ck_tile
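
GetABLdsTensorViews above carves one shared-memory block into an A tile followed by a B tile, rounding the A tile's byte size up to a multiple of 16 so that B starts on a 16-byte boundary. A minimal standalone sketch of just that rounding, under an assumed fp16 A-tile size (the real size comes from the policy's LDS descriptor):

#include <cstddef>

// Mirrors integer_divide_ceil(bytes, 16) * 16 from GetABLdsTensorViews above.
constexpr std::size_t align_up_16(std::size_t bytes) { return (bytes + 15) / 16 * 16; }

// Assumed example: a 128 x 33 element fp16 A tile (32 columns plus 1 padding column).
static_assert(align_up_16(128 * 33 * 2) == 8448, "already 16B aligned, nothing added");
// A size that is not a multiple of 16 bytes gets rounded up to the next boundary:
static_assert(align_up_16(130) == 144, "130 bytes -> next 16B boundary");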
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV3
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
};
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
using Base::PrefetchStages;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read =
//     sizeof(ComputeDataType) / sizeof(ADataType) > sizeof(ComputeDataType) / sizeof(BDataType)
//         ? sizeof(ComputeDataType) / sizeof(ADataType)
//         : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr auto num_mfma_stage1 =
num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A/B tiles in LDS
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_block_tile;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
block_sync_lds();
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// Let's leak the last MFMA block into the epilogue region to cover the potential
// lds-shuffle latency.
// __builtin_amdgcn_sched_barrier(0);
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem);
}
};
} // namespace ck_tile
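
To make the HotLoopScheduler bookkeeping concrete, here is a standalone recomputation of its counts for one assumed configuration: 256 threads, a 128x128x32 block tile, 2x2 waves, a 32x32x8 MFMA warp tile, and 8-element fp16 vector loads. The numbers are illustrative only; nothing in this header mandates that configuration.

#include <cstdio>

int main()
{
    // Assumed configuration (illustrative, not taken from this header).
    constexpr int BlockSize = 256, WaveSize = 64;
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;
    constexpr int MPerXDL = 32, NPerXDL = 32, KPerXDL = 8;
    constexpr int WaveNumN = 2;                 // B mirrors A here, so one side suffices
    constexpr int VectorSizeA = 8, SizeofA = 2; // fp16

    // Same formulas as HotLoopScheduler above (A side; B is symmetric).
    constexpr int buffer_load_a = MPerBlock * KPerBlock / (BlockSize * VectorSizeA);        // 2
    constexpr int lds_write_a   = MPerBlock * KPerBlock / (BlockSize * KPerXDL);            // 2
    constexpr int lds_read_a    = WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); // 4
    constexpr int mfma          = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize)
                                  / (MPerXDL * NPerXDL * KPerXDL);                          // 16

    // 8 fp16 elements per ds_read is 16 bytes: full-rate reads with an 8-cycle issue.
    constexpr int ds_read_a     = KPerXDL * SizeofA == 16 ? lds_read_a : lds_read_a / 2;    // 4
    constexpr int mfma_cycle    = NPerXDL == 16 ? 16 : 32;                                  // 32
    constexpr int issue_cycle_a = KPerXDL * SizeofA == 16 ? 8 : 4;                          // 8
    constexpr int read_rate_a   = (mfma_cycle - 4 + 2 * issue_cycle_a - 1)
                                  / (2 * issue_cycle_a);                                    // 2
    constexpr int dsread_a_mfma = (ds_read_a + read_rate_a - 1) / read_rate_a;              // 2

    // Stage 1: 16 MFMAs minus 2+2 reserved for ds_read pairing leaves 12,
    // i.e. 3 MFMAs per buffer-load issue slot and 1 ds_write per slot.
    constexpr int mfma_stage1    = mfma - 2 * dsread_a_mfma;          // 12
    constexpr int mfma_per_issue = mfma_stage1 / (2 * buffer_load_a); // 3
    constexpr int dswrite_per_a  = lds_write_a / buffer_load_a;       // 1

    std::printf("stage1 mfma=%d, mfma/issue=%d, ds_write/issue=%d\n",
                mfma_stage1, mfma_per_issue, dswrite_per_a);
}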
@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
+ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
@@ -90,7 +91,8 @@ struct BaseGemmPipelineAgBgCrMem
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
+ using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
+ using I1 = number<1>;
+ using I2 = number<2>;
+ static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
@@ -113,9 +116,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
- static constexpr bool kPadA = Problem::kPadA;
- static constexpr bool kPadB = Problem::kPadB;
- static constexpr bool kPadC = Problem::kPadC;
+ static constexpr bool kPadM = Problem::kPadM;
+ static constexpr bool kPadN = Problem::kPadN;
+ static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop;
@@ -124,46 +127,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using Base::PrefetchStages;
- CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize()
- {
- return integer_divide_ceil(
- sizeof(ADataType) *
- Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
- 16) *
- 16 +
- sizeof(BDataType) *
- Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
- }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <GemmPipelineScheduler Scheduler>
- struct PipelineImpl
+ struct PipelineImpl : public PipelineImplBase
{
};
template <>
- struct PipelineImpl<GemmPipelineScheduler::Intrawave>
+ struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
- template <typename DstBlockTile, typename SrcTileWindow>
- CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
- SrcTileWindow& dram_tile_window) const
- {
- load_tile(dst_block_tile, dram_tile_window);
- move_tile_window(dram_tile_window, {0, KPerBlock});
- }
- template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
- CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
- const SrcBlockTile& src_block_tile,
- const ElementFunction& element_func) const
- {
- const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
- store_tile(lds_tile_window, block_tile_tmp);
- }
+ using Base = PipelineImplBase;
template <bool HasHotLoop,
TailNumber TailNum,
@@ -185,70 +162,42 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
- static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
- NPerBlock ==
- BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
- KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
+ static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+ NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+ KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerBlock"
" or KPerBlock!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
- // A tile in LDS
- ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
- constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
- auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
- // TODO: LDS alignment should come from Policy!
- constexpr index_t a_lds_block_space_size_aligned =
- integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
- 16) *
- 16;
- // B tile in LDS
- BDataType* p_b_lds = static_cast<BDataType*>(
- static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
- constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
- auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
+ // A/B tiles in LDS
+ // With C++20 this could be simplified to the line below.
+ // Currently we get: captured structured bindings are a C++20 extension
+ // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
+ auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
+ auto& a_lds_block = ab_lds_blocks.at(I0{});
+ auto& b_lds_block = ab_lds_blocks.at(I1{});
// A DRAM tile window for load
- auto a_copy_dram_window =
- make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
- make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
- a_dram_block_window_tmp.get_window_origin(),
- Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
- auto a_copy_lds_window =
- make_tile_window(a_lds_block,
- make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
- {0, 0},
- a_copy_dram_window.get_tile_distribution());
+ // A LDS tile for block GEMM
+ auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
+ auto& a_copy_dram_window = a_windows.at(I0{});
+ auto& a_copy_lds_window = a_windows.at(I1{});
+ auto& a_lds_gemm_window = a_windows.at(I2{});
// B DRAM tile window for load
- auto b_copy_dram_window =
- make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
- make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
- b_dram_block_window_tmp.get_window_origin(),
- Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
- auto b_copy_lds_window =
- make_tile_window(b_lds_block,
- make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
- {0, 0},
- b_copy_dram_window.get_tile_distribution());
- // A LDS tile for block GEMM
- auto a_lds_gemm_window = make_tile_window(
- a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
- // B LDS tile for block GEMM
- auto b_lds_gemm_window = make_tile_window(
- b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
+ // B LDS tile for block GEMM
+ auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
+ auto& b_copy_dram_window = b_windows.at(I0{});
+ auto& b_copy_lds_window = b_windows.at(I1{});
+ auto& b_lds_gemm_window = b_windows.at(I2{});
// Block GEMM
- constexpr auto block_gemm = BlockGemm();
+ auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
@@ -266,20 +215,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch
// global read 0
- GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
- GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+ Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
+ Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
- LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
- LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+ Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+ Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
- GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
- GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+ Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
+ Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
});
// main body
@@ -290,24 +239,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
- // block_gemm.LocalPrefetch();
+ block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
- LocalPrefill(
+ Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
- LocalPrefill(
+ Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
- GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
- a_copy_dram_window);
- GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
- b_copy_dram_window);
+ Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
+ a_copy_dram_window);
+ Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
+ b_copy_dram_window);
});
i += PrefetchStages;
@@ -318,27 +267,208 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
- // block_gemm.LocalPrefetch();
+ block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
- LocalPrefill(a_copy_lds_window,
- a_block_tiles.get(number<prefetch_idx>{}),
- a_element_func);
- LocalPrefill(b_copy_lds_window,
- b_block_tiles.get(number<prefetch_idx>{}),
- b_element_func);
+ Base::LocalPrefill(a_copy_lds_window,
+ a_block_tiles.get(number<prefetch_idx>{}),
+ a_element_func);
+ Base::LocalPrefill(b_copy_lds_window,
+ b_block_tiles.get(number<prefetch_idx>{}),
+ b_element_func);
});
block_sync_lds();
- // block_gemm.LocalPrefetch();
+ block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
};
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
- // block_gemm.LocalPrefetch();
+ block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else if constexpr(TailNum == TailNumber::Two)
+ {
+ HotLoopTail(number<2>{});
+ }
+ else if constexpr(TailNum == TailNumber::Three)
+ {
+ HotLoopTail(number<3>{});
+ }
+ else if constexpr(TailNum == TailNumber::Four)
+ {
+ HotLoopTail(number<4>{});
+ }
+ else if constexpr(TailNum == TailNumber::Five)
+ {
+ HotLoopTail(number<5>{});
+ }
+ else if constexpr(TailNum == TailNumber::Six)
+ {
+ HotLoopTail(number<6>{});
+ }
+ else if constexpr(TailNum == TailNumber::Seven)
+ {
+ HotLoopTail(number<7>{});
+ }
+ else if constexpr(TailNum == TailNumber::Full)
+ {
+ HotLoopTail(number<PrefetchStages>{});
+ }
+ return c_block_tile;
+ }
+ };
+ template <>
+ struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
+ {
+ using Base = PipelineImplBase;
+ template <bool HasHotLoop,
+ TailNumber TailNum,
+ typename ADramBlockWindowTmp,
+ typename BDramBlockWindowTmp,
+ typename AElementFunction,
+ typename BElementFunction>
+ CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+ const AElementFunction& a_element_func,
+ const BDramBlockWindowTmp& b_dram_block_window_tmp,
+ const BElementFunction& b_element_func,
+ index_t num_loop,
+ void* p_smem) const
+ {
+ static_assert(
+ std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
+ std::is_same_v<BDataType,
+ remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
+ "A/B Dram block window should have the same data type as appropriate "
+ "([A|B]DataType) defined in Problem definition!");
+ static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+ NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+ KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
+ "A/B block window appropriate sizes must be equal to MPerBlock/NPerBlock"
+ " or KPerBlock!");
+ // ------------------------------------------------------------------------------------
+ // Definitions of all needed tiles
+ // A/B tiles in LDS
+ // With C++20 this could be simplified to the line below.
+ // Currently we get: captured structured bindings are a C++20 extension
+ // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
+ auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
+ auto& a_lds_block = ab_lds_blocks.at(I0{});
+ auto& b_lds_block = ab_lds_blocks.at(I1{});
+ // A DRAM tile window for load
+ // A LDS tile window for store
+ // A LDS tile for block GEMM
+ auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
+ auto& a_copy_dram_window = a_windows.at(I0{});
+ auto& a_copy_lds_window = a_windows.at(I1{});
+ auto& a_lds_gemm_window = a_windows.at(I2{});
+ // B DRAM tile window for load
+ // B LDS tile window for store
+ // B LDS tile for block GEMM
+ auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
+ auto& b_copy_dram_window = b_windows.at(I0{});
+ auto& b_copy_lds_window = b_windows.at(I1{});
+ auto& b_lds_gemm_window = b_windows.at(I2{});
+ // Block GEMM
+ auto block_gemm = BlockGemm();
+ auto c_block_tile = block_gemm.MakeCBlockTile();
+ using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
+ using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
+ using ABlockTile =
+ decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
+ using BBlockTile =
+ decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
+ tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
+ tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
+ // -----------------------------------------------------------------------------------------
+ // Gemm pipeline start
+ // prefetch
+ // global read 0
+ Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
+ Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+ // initialize C
+ tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
+ // LDS write 0
+ Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+ Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+ // Global prefetch [1, PrefetchStages]
+ static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
+ Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
+ Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+ });
+ // main body
+ if constexpr(HasHotLoop)
+ {
+ index_t i = 0;
+ do
+ {
+ static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
+ block_sync_lds();
+ block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
+ // no second block_sync_lds because it's interwave
+ Base::LocalPrefill(
+ a_copy_lds_window,
+ a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+ a_element_func);
+ Base::LocalPrefill(
+ b_copy_lds_window,
+ b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+ b_element_func);
+ Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
+ a_copy_dram_window);
+ Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
+ b_copy_dram_window);
+ });
+ i += PrefetchStages;
+ } while(i < (num_loop - PrefetchStages));
+ }
+ auto HotLoopTail = [&](auto tail_num) {
+ static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
+ block_sync_lds();
+ block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
+ // no second block_sync_lds because it's interwave
+ Base::LocalPrefill(a_copy_lds_window,
+ a_block_tiles.get(number<prefetch_idx>{}),
+ a_element_func);
+ Base::LocalPrefill(b_copy_lds_window,
+ b_block_tiles.get(number<prefetch_idx>{}),
+ b_element_func);
+ });
+ block_sync_lds();
+ block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
+ };
+ if constexpr(TailNum == TailNumber::One)
+ {
+ block_sync_lds();
+ block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
+ }
+ else if constexpr(TailNum == TailNumber::Two)
...
@@ -11,6 +11,7 @@ namespace ck_tile {
enum struct GemmPipelineScheduler
{
+ Default,
Intrawave,
Interwave,
};
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
{
switch(s)
{
+ case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
default: os << "";
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
- static constexpr bool kPadA = Problem::kPadA;
- static constexpr bool kPadB = Problem::kPadB;
- static constexpr bool kPadC = Problem::kPadC;
+ static constexpr bool kPadM = Problem::kPadM;
+ static constexpr bool kPadN = Problem::kPadN;
+ static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
- auto a_copy_lds_window =
- make_tile_window(a_lds_block,
- make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
- {0, 0},
- a_copy_dram_window.get_tile_distribution());
+ auto a_copy_lds_window = make_tile_window(
+ a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B DRAM tile window for load
auto b_copy_dram_window =
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
- auto b_copy_lds_window =
- make_tile_window(b_lds_block,
- make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
- {0, 0},
- b_copy_dram_window.get_tile_distribution());
+ auto b_copy_lds_window = make_tile_window(
+ b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
@@ -130,7 +124,7 @@ struct GemmPipelineAGmemBGmemCRegV1
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
- constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
+ auto block_gemm = Policy::template GetBlockGemm<Problem>();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
- const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
- store_tile(a_copy_lds_window, a_block_tile_tmp);
+ if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
+ {
+ auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+ Policy::template MakeShuffledARegBlockDescriptor<Problem>());
+ shuffle_tile(a_shuffle_tmp, a_block_tile);
+ const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
+ store_tile(a_copy_lds_window, a_block_tile_tmp);
+ }
+ else
+ {
+ store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
+ }
// LDS write 0
- const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
- store_tile(b_copy_lds_window, b_block_tile_tmp);
+ if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
+ {
+ auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+ Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+ shuffle_tile(b_shuffle_tmp, b_block_tile);
+ const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
+ store_tile(b_copy_lds_window, b_block_tile_tmp);
+ }
+ else
+ {
+ store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
+ }
}
index_t iCounter = num_loop - 1;
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1
- const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
- store_tile(b_copy_lds_window, b_block_tile_tmp);
+ if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
+ {
+ auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
+ Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+ shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
+ store_tile(b_copy_lds_window,
+ tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
+ }
+ else
+ {
+ const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
+ store_tile(b_copy_lds_window, b_block_tile_tmp);
+ }
iCounter--;
}
...
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -11,6 +12,7 @@ namespace ck_tile { ...@@ -11,6 +12,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead // Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
#if 0 #if 0
// 2d // 2d
template <typename Problem> template <typename Problem>
...@@ -51,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -51,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
...@@ -116,6 +119,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -116,6 +119,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return smem_size; return smem_size;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
#elif 1 #elif 1
// fake XOR // fake XOR
template <typename Problem> template <typename Problem>
...@@ -192,88 +209,307 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -192,88 +209,307 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        using ALayout   = remove_cvref_t<typename Problem::ALayout>;

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
            constexpr index_t M0 = MPerBlock / M1;
            constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
            static_assert(total_pixels % M1 == 0);
            constexpr index_t K3    = total_pixels / M1;
            constexpr index_t KPack = GetSmemPackA<Problem>();
            static_assert(KPack % K3 == 0);
            constexpr index_t K2 = KPack / K3;
            if constexpr(get_warp_size() % (K2 * M0) == 0)
            {
                constexpr index_t K1 = get_warp_size() / (K2 * M0);
                constexpr index_t K0 = BlockSize / get_warp_size();
                static_assert(KPerBlock == K0 * K1 * K2 * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
                                               tuple<sequence<2>, sequence<2, 1, 2>>,
                                               tuple<sequence<0>, sequence<1, 0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
            else
            {
                constexpr index_t K1   = (K2 * M0) / get_warp_size();
                constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
            // coalesce reading for each block
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
    }
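
    // Worked example (editorial; fp16, 128x32 M-by-K tile, kBlockSize == 256, 64-lane wave):
    // the row-major branch gives K1 = 16 / 2 = 8, K0 = 32 / 8 = 4, M2 = 64 / 4 = 16,
    // M1 = 256 / 64 = 4, M0 = 128 / (16 * 4) = 2, so M0 * M1 * M2 == 128 covers MPerBlock
    // exactly and each of the 256 threads issues two 8-element (16-byte) loads.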
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetSmemPackB<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
            // coalesce reading for each block
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
            // coalesce reading for each warp
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
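
    // Worked example for the row-major B branch (editorial; fp16, 128x32 N-by-K tile,
    // 256 threads, 64-lane wave): N1 = 8, N0 = 16, total_pixels = 128 * 32 / 256 = 16,
    // K3 = 16 / 8 = 2, KPack = 8, K2 = 8 / 2 = 4; since 64 % (K2 * N0) == 0 we get
    // K1 = 64 / 64 = 1 and K0 = 256 / 64 = 4, and the static_assert confirms
    // K0 * K1 * K2 * K3 == 4 * 1 * 4 * 2 == 32 == KPerBlock.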
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
        constexpr index_t N0 = kNPerBlock / N1;
        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
        static_assert(total_pixels % N1 == 0);
        constexpr index_t K3     = total_pixels / N1;
        constexpr index_t kKPack = GetSmemPackB<Problem>();
        static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
        constexpr index_t warp_size = get_warp_size();
        if constexpr(warp_size % (K2 * N0) == 0)
        {
            constexpr index_t K1 = warp_size / (K2 * N0);
            constexpr index_t K0 = kBlockSize / warp_size;

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
        else
        {
            constexpr index_t K1   = (K2 * N0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = kBlockSize / get_warp_size() / K1;
            static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
    }
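
    // Editorial note (interpretation, not from the diff): this register descriptor pairs
    // with the row-major B DRAM distribution above. There the vector loads run along N,
    // the contiguous axis, so each thread's data lands N-major in registers; this encoding
    // names the same per-thread elements grouped K-major, so the tile can be transposed in
    // registers before the K-contiguous store to LDS.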
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
    {
        using ALayout   = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = kMPerBlock / M1;
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
constexpr bool TransposeC = false;
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
    }
};
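
// Usage sketch (hypothetical `MyProblem`, for orientation only):
//
//   using Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
//   auto a_dist  = Policy::MakeADramTileDistribution<MyProblem>();
//   auto bgemm   = Policy::GetBlockGemm<MyProblem>();
//
// The pipeline queries the policy for tile distributions and the block-gemm functor; with
// this change the functor is a BlockUniversalGemmAsBsCr whose warp-level MFMA is chosen by
// WarpGemmMfmaDispatcher from the problem's data types and WarpTile shape, replacing the
// fixed BlockGemmASmemBSmemCRegV1 used before.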
...
...@@ -3,40 +3,135 @@

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename TileGemmTraits_>
struct GemmPipelineProblemBase
{
    using GemmTraits = remove_cvref_t<TileGemmTraits_>;

    using ADataType = remove_cvref_t<ADataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using CDataType = remove_cvref_t<CDataType_>;

    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
    using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
    using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;

    static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
    static constexpr index_t kBlockSize     = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadM = GemmTraits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: VectorLoadSize / sizeof(ADataType);
}
else
{
return VectorLoadSize / sizeof(ADataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: VectorLoadSize / sizeof(BDataType);
}
else
{
return VectorLoadSize / sizeof(BDataType);
}
}
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
    {
        if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
constexpr index_t M0 = get_warp_size() / N2;
constexpr index_t M1 = BlockGemmShape::kM / M0;
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
else
{
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = BlockGemmShape::kN / N0;
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
else
{
return kPadM ? 1 : GetAlignmentA();
}
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
else
{
return kPadK ? 1 : GetAlignmentB();
}
}();
static constexpr index_t VectorSizeC = []() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return kPadN ? 1 : GetAlignmentC();
}
else
{
return kPadM ? 1 : GetAlignmentC();
}
}();
};
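
// Worked example (editorial): for row-major A in fp16 with kPadK == false and the default
// 16-byte VectorLoadSize, VectorSizeA == GetAlignmentA() == 16 / sizeof(fp16_t) == 8;
// enabling kPadK (padding A's contiguous K axis) drops VectorSizeA to 1, i.e. scalar loads,
// since a padded fast axis can no longer guarantee vector-aligned accesses.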
// Alias for GemmPipelineProblem
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_>
using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
...@@ -45,30 +140,15 @@ template <typename ADataType_,
          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
          bool HasHotLoop_ = true,
          TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
                                                                     BDataType_,
                                                                     CDataType_,
                                                                     BlockGemmShape_,
                                                                     TileGemmTraits_>
{
    static constexpr auto Scheduler  = Scheduler_;
    static constexpr auto HasHotLoop = HasHotLoop_;
    static constexpr auto TailNum    = TailNum_;
};
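
// Usage sketch (editorial; `Shape` stands in for a hypothetical block-gemm shape type and
// the layout names for tensor_layout::gemm types): the derived problem now only adds
// scheduling knobs on top of the shared base, e.g.
//
//   using Traits  = TileGemmTraits<false, false, false, RowMajor, ColumnMajor, RowMajor>;
//   using Problem = UniversalGemmPipelineProblem<fp16_t, fp16_t, float, Shape, Traits>;
//
// which picks Scheduler == Intrawave, HasHotLoop == true and TailNum == Full by default,
// while layouts, padding flags and vector sizes all come from GemmPipelineProblemBase.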
} // namespace ck_tile
...@@ -9,12 +9,8 @@

namespace ck_tile {

// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
{
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
...@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy
    static constexpr bool TransposeC = true;
template <typename Problem, typename DataType, index_t MNPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
{
return (16 / sizeof(DataType));
}
else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0)
{
return (8 / sizeof(DataType));
}
else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 4)
{
return (4 / sizeof(DataType));
}
else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 2)
{
return (2 / sizeof(DataType));
}
else
{
return 1;
}
}
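
    // Worked example (editorial): with a 128x32 tile and 256 threads,
    // elements_per_thread == 128 * 32 / 256 == 16. For fp16 (sizeof == 2) the first branch
    // matches, since 16 % (16 / 2) == 0, so the function returns 8 elements, i.e. one
    // 16-byte load per thread; narrower tiles fall through to smaller vectors and finally 1.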
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;

        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack     = GetVectorLoadSize<Problem, ADataType, MPerBlock>();

        constexpr auto DataTypeSize = sizeof(ADataType);
        constexpr auto MLdsLayer =
            (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);

        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
                       number<MPerBlock / MLdsLayer>{},
                       number<KPack>{}),
            make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
                                                     number<KPerBlock / KPack * MLdsLayer>{})),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<1, 0>{}, sequence<2>{}),
            make_tuple(sequence<1, 0>{}, sequence<2>{}));

        constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
            a_lds_block_desc_permuted,
            make_tuple(make_unmerge_transform(
                           make_tuple(number<KPerBlock / KPack>{}, number<MLdsLayer>{})),
                       make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));

        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_xk0_mnldslayer_mn_xk1,
            make_tuple(make_merge_transform_v3_division_mod(
                           make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
                       make_merge_transform_v3_division_mod(
                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }
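
    // Editorial note with example: MLdsLayer widens an LDS "row" when KPerBlock is small, so
    // a row spans at least 128 bytes (32 banks x 4 bytes); make_xor_transform then swizzles
    // the M coordinate against K0 * MLdsLayer so that K-major reads hit distinct banks.
    // For fp16 with KPerBlock == 32: MLdsLayer = 32 * 4 / 32 / 2 = 2, i.e. two M-slices
    // share one XOR-permuted 128-byte row instead of relying on padding.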
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;

        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack     = GetVectorLoadSize<Problem, BDataType, NPerBlock>();

        constexpr auto DataTypeSize = sizeof(BDataType);
        constexpr auto NLdsLayer =
            (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);

        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<KPerBlock / KPack * NLdsLayer>{},
                       number<NPerBlock / NLdsLayer>{},
                       number<KPack>{}),
            make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
                                                     number<KPerBlock / KPack * NLdsLayer>{})),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<1, 0>{}, sequence<2>{}),
            make_tuple(sequence<1, 0>{}, sequence<2>{}));

        constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
            b_lds_block_desc_permuted,
            make_tuple(make_unmerge_transform(
                           make_tuple(number<KPerBlock / KPack>{}, number<NLdsLayer>{})),
                       make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_xk0_mnldslayer_mn_xk1,
            make_tuple(make_merge_transform_v3_division_mod(
                           make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
                       make_merge_transform_v3_division_mod(
                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return b_lds_block_desc;
    }
    template <typename Problem>
...@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        using ALayout   = remove_cvref_t<typename Problem::ALayout>;

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
            constexpr index_t M0 = MPerBlock / M1;
            constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
            static_assert(total_pixels % M1 == 0);
            constexpr index_t K3    = total_pixels / M1;
            constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
            static_assert(KPack % K3 == 0);
            constexpr index_t K2 = KPack / K3;
            if constexpr(get_warp_size() % (K2 * M0) == 0)
            {
                constexpr index_t K1 = get_warp_size() / (K2 * M0);
                constexpr index_t K0 = BlockSize / get_warp_size();
                static_assert(KPerBlock == K0 * K1 * K2 * K3);
                return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        using BLayout   = remove_cvref_t<typename Problem::BLayout>;

        constexpr index_t BlockSize = Problem::kBlockSize;

        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
            // coalesce reading for each block
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
            // coalesce reading for each warp
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
        constexpr index_t N0 = NPerBlock / N1;
        constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
        static_assert(total_pixels % N1 == 0);
        constexpr index_t K3     = total_pixels / N1;
        constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
        static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
        constexpr index_t warp_size = get_warp_size();
        if constexpr(warp_size % (K2 * N0) == 0)
        {
            constexpr index_t K1 = warp_size / (K2 * N0);
            constexpr index_t K0 = BlockSize / warp_size;

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
    }
    template <typename Problem>
...
...@@ -3,19 +3,23 @@

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <bool kPadM_,
          bool kPadN_,
          bool kPadK_,
          typename ALayout_,
          typename BLayout_,
          typename CLayout_>
struct TileGemmTraits
{
    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;
    static constexpr bool kPadK = kPadK_;

    static constexpr int _VectorSize = 16;

    using ALayout = ALayout_;
    using BLayout = BLayout_;
...
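
// Example (editorial, not in the diff): the padding flags are now named for the GEMM
// dimensions they guard rather than for the tensors, so an unpadded row/column/row problem
// reads
//
//   using Traits = ck_tile::TileGemmTraits<false, false, false,
//                                          ck_tile::tensor_layout::gemm::RowMajor,
//                                          ck_tile::tensor_layout::gemm::ColumnMajor,
//                                          ck_tile::tensor_layout::gemm::RowMajor>;
//
// and the 16-byte _VectorSize defined here is what GemmPipelineProblemBase exposes as
// VectorLoadSize.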
...@@ -10,114 +10,134 @@ ...@@ -10,114 +10,134 @@
namespace ck_tile { namespace ck_tile {
// fp16 // fp16
using WarpGemmMfmaF16F16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M16N16K16 = using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K32 = using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl< using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>; WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>; WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8, WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16, WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8, WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
// bf16 // bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>; WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 = using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 = using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>; WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>; WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>; 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
        2,
        swizzle_factor>>;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -21,9 +21,12 @@ struct WarpGemmAtrributeMfma
    using BVecType = typename Impl::BVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -51,10 +54,13 @@ struct WarpGemmAtrributeMfma
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
    }
    // c_vec = a_vec * b_vec
@@ -81,9 +87,12 @@ struct WarpGemmAtrributeMfmaIterateK
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -111,8 +120,11 @@ struct WarpGemmAtrributeMfmaIterateK
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -122,10 +134,33 @@ struct WarpGemmAtrributeMfmaIterateK
                       reinterpret_cast<const buf_a&>(a_vec)
                           .template get_as<typename Impl::AVecType>()[iKIter],
                       reinterpret_cast<const buf_b&>(b_vec)
                           .template get_as<typename Impl::BVecType>()[iKIter],
                       bool_constant<post_nop_>{});
        });
    }
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
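
    // Usage sketch (illustrative, caller names hypothetical): a pipeline that unrolls
    // the K sub-iterations itself can issue each partial MFMA separately and request
    // the trailing s_nop only on the last issue, e.g. for kKIter == 2:
    //   attr(c_vec, a_vec, b_vec, number<0>{});
    //   attr(c_vec, a_vec, b_vec, number<1>{}, bool_constant<true>{});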
    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
@@ -164,9 +199,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
    using BVecType = typename Impl::AVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -194,11 +232,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // swap A and B
        Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
    }

    // c_vec = a_vec * b_vec
@@ -222,9 +263,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
    using BVecType = typename Impl::AVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -255,12 +299,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
        sequence<2, 2>,
        sequence<0, 2>>;
template <bool post_nop_ = false>
    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // swap A and B
        Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
    }
    // c_vec = a_vec * b_vec
@@ -287,9 +334,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -316,9 +366,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
        sequence<2, 2>,
        sequence<0, 2>>;
template <bool post_nop_ = false>
    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -328,10 +381,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
                       reinterpret_cast<const buf_b&>(b_vec)
                           .template get_as<typename Impl::BVecType>()[iKIter],
                       reinterpret_cast<const buf_a&>(a_vec)
                           .template get_as<typename Impl::AVecType>()[iKIter],
                       bool_constant<post_nop_>{});
        });
    }
template <index_t iKIter, bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
@@ -372,10 +449,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
    static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -429,8 +509,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
        sequence<0, 2>>;
#endif

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -440,10 +523,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
                       reinterpret_cast<const buf_b&>(b_vec)
                           .template get_as<typename Impl::BVecType>()[iKIter],
                       reinterpret_cast<const buf_a&>(a_vec)
                           .template get_as<typename Impl::AVecType>()[iKIter],
                       bool_constant<post_nop_>{});
        });
    }
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
@@ -483,10 +589,13 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
    static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -518,8 +627,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -529,10 +641,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
                       reinterpret_cast<const buf_a&>(a_vec)
                           .template get_as<typename Impl::AVecType>()[iKIter],
                       reinterpret_cast<const buf_b&>(b_vec)
                           .template get_as<typename Impl::BVecType>()[iKIter],
                       bool_constant<post_nop_>{});
        });
    }
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,12 +7,68 @@
namespace ck_tile {
// TODO: refactor warp-gemm
// currently there is a discrepancy for vav/vva if we need to transpose C/D,
// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrCtlEnum,
// because the _impl code swaps the A/B pointers (that information is not known here)
enum class WGAttrCtlEnum
{
Default_ = 0,
Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
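
// Illustrative example (not part of this change): keeping C in vgpr while both A and B
// come from agpr selects Raw_vaa when instantiating an impl, e.g.
//   using WgImpl = WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Raw_vaa>;
// WGAttrCtlEnum::Default_ keeps the compiler-builtin path instead of raw inline asm.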
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
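
// Sketch of what the dispatch expands to (assuming ctrl_ == Raw_vvv and post_nop_ == false):
//   asm volatile("v_mfma_f32_32x32x8f16 %0, %1, %2, %3\n"
//                : "+v"(c_vec)
//                : "v"(a_vec), "v"(b_vec), "v"(c_vec)
//                :);
// With post_nop_ == true an "s_nop 3" is appended so dependent raw-asm reads of c_vec
// are issued only after the MFMA result has landed.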
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
@@ -33,16 +89,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
    static constexpr index_t kCM1PerLane = 4;
    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }
    // c_vec = a_vec * b_vec
@@ -52,18 +115,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
@@ -84,16 +149,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
    static constexpr index_t kCM1PerLane = 4;
    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }
    // c_vec = a_vec * b_vec
@@ -103,19 +175,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};

// Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
@@ -136,28 +210,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
    static constexpr index_t kCM1PerLane = 4;
    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }
    // c_vec = a_vec * b_vec
@@ -181,18 +262,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
@@ -213,28 +296,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
    static constexpr index_t kCM1PerLane = 4;
    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }
    // c_vec = a_vec * b_vec
@@ -258,20 +347,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
// FP8
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = AType_;
    using BDataType = BType_;
    using CDataType = float;

    using AVecType = ext_vector_t<ADataType, 8>;
    using BVecType = ext_vector_t<BDataType, 8>;
@@ -292,38 +382,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
    static constexpr index_t kCM1PerLane = 4;
    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
}
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
}
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
}
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
}
}
else
{
#if defined(__gfx94__)
            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
            static_for<0, 8, 1>{}([&](auto k) {
                float a_f32 =
                    type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
                                            .template get_as<ADataType>()[number<k>{}]);
                float b_f32 =
                    type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
                                            .template get_as<BDataType>()[number<k>{}]);
                c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }
    // c_vec = a_vec * b_vec
@@ -356,20 +528,97 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;
// int8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int32_t;
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 16;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
else
{
#if defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ADataType>()[number<k>{}]);
float b_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<BDataType>()[number<k>{}]);
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
});
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
CVecType c_vec{0};
operator()(c_vec, a_vec, b_vec);
return c_vec;
}
};
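
// Illustrative only (alias name hypothetical): the int8 impl composes with the same
// attribute wrappers as the f8 impls above, e.g.
//   using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
//       WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;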
#undef DISPATCH_MFMA_
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// clang-format on
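
// Example (illustrative; the boolean parameter names are assumptions): resolving a
// 32x32x8 fp16 warp tile with transposed C distribution through the dispatcher:
//   using WG = WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float,
//                                     32, 32, 8, /*TransposeC=*/true>::Type;
//   static_assert(WG::kM == 32 && WG::kN == 32 && WG::kK == 8);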
} // namespace impl
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,6 +14,11 @@ struct WarpGemmImpl
    static constexpr index_t kM = WarpGemmAttribute::kM;
    static constexpr index_t kN = WarpGemmAttribute::kN;
    static constexpr index_t kK = WarpGemmAttribute::kK;
    /// @brief The number of elements in the K dimension processed by a single thread in a wavefront.
    ///
    /// @note A WarpGemm may run the MFMA instruction multiple times (on different K);
    /// in that case this value accounts for all iterations.
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
    using ADataType = typename WarpGemmAttribute::ADataType;
    using BDataType = typename WarpGemmAttribute::BDataType;
@@ -31,11 +36,21 @@ struct WarpGemmImpl
    using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
    using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
    {
        return WarpGemmAttribute_::get_num_of_access();
    }

    template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
    CK_TILE_DEVICE void
    operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
    {
        static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
                      detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
                      detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;

        constexpr auto I0 = number<0>{};
@@ -44,18 +59,49 @@ struct WarpGemmImpl
        auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];

        // c_vec += a_vec * b_vec
        WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});

        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
    }
    template <typename CTensor,
              typename ATensor,
              typename BTensor,
              index_t i_subk,
              bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
number<i_subk>,
bool_constant<post_nop_> = {}) const
    {
        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
        const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
        const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
        auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
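
    // Usage sketch (illustrative, names hypothetical): a pipeline can interleave the
    // get_num_of_access() partial MFMA issues with other work:
    //   static_for<0, WarpGemm::get_num_of_access(), 1>{}([&](auto i_subk) {
    //       warp_gemm(c, a, b, i_subk); // loads can be scheduled between issues
    //   });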
template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
        constexpr auto I0 = number<0>{};
...
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
    index_t m;
    index_t n;
    index_t x_stride;  // x row stride
    index_t xr_stride; // x residual row stride
    index_t y_stride;  // y row stride
    index_t yr_stride; // y residual row stride
};
// TODO: Extract some type to wrapper class
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
    index_t m;
    index_t n;
    index_t x_stride;  // x row stride
    index_t xr_stride; // x residual row stride
    index_t y_stride;  // y row stride
    index_t yr_stride; // y residual row stride
};
    using Hargs = Layernorm2dFwdHostArgs;
@@ -112,12 +118,15 @@ struct Layernorm2dFwd
                     hargs.epsilon,
                     hargs.m,
                     hargs.n,
                     hargs.x_stride,
                     hargs.xr_stride,
                     hargs.y_stride,
                     hargs.yr_stride};
    }
    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        return dim3(integer_divide_ceil(hargs.m, Block_M));
    }
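
    // Example: m == 1000, Block_M == 256 -> integer_divide_ceil(1000, 256) == 4 blocks,
    // with the last block covering the remaining 232 rows.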
    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
@@ -165,7 +174,7 @@ struct Layernorm2dFwd
            return base_str;
        }();
        return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
               _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
               _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
               _SS_(Pipeline::name) + surfix;
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
        const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
            static_cast<const XDataType*>(kargs.p_x),
            make_tuple(kargs.m, kargs.n),
            make_tuple(kargs.x_stride, 1),
            number<Vector_N>{},
            number<1>{});
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
        const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
            static_cast<const XResidualDataType*>(kargs.p_x_residual),
            make_tuple(kargs.m, kargs.n),
            make_tuple(kargs.xr_stride, 1),
            number<Vector_N>{},
            number<1>{});
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
         auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
             static_cast<YDataType*>(kargs.p_y),
             make_tuple(kargs.m, kargs.n),
-            make_tuple(kargs.stride, 1),
+            make_tuple(kargs.y_stride, 1),
             number<Vector_N>{},
             number<1>{});
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
         auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
             static_cast<YResidualDataType*>(kargs.p_y_residual),
             make_tuple(kargs.m, kargs.n),
-            make_tuple(kargs.stride, 1),
+            make_tuple(kargs.yr_stride, 1),
             number<Vector_N>{},
             number<1>{});
...
@@ -26,6 +26,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
                                             sequence<1, 1, 2, 2>,
                                             sequence<0, 3, 0, 3>>{});
     }
+
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
     {
...@@ -44,9 +45,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -44,9 +45,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford() CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
{ {
using P_ = BlockWelfordProblem<typename Problem::XDataType, using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
return BlockWelford<P_>{}; return BlockWelford<P_>{};
} }
@@ -54,9 +56,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
     {
-        using P_ = BlockWelfordProblem<typename Problem::XDataType,
+        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape>;
+                                       typename Problem::BlockShape,
+                                       Problem::Traits::kFastFDiv>;

         return BlockWelfordSync<P_>{};
     }
@@ -64,9 +67,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
     {
-        using P_ = BlockWelfordProblem<typename Problem::XDataType,
+        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape>;
+                                       typename Problem::BlockShape,
+                                       Problem::Traits::kFastFDiv>;

         return BlockWelfordCrossWarpSync<P_>{};
     }
...@@ -76,13 +80,14 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -76,13 +80,14 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{ {
if constexpr(Problem::kNeedCrossWarpSync) if constexpr(Problem::kNeedCrossWarpSync)
{ {
using P_ = BlockWelfordProblem<typename Problem::XDataType, using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape>; typename Problem::BlockShape,
Problem::Traits::kFastFDiv>;
using block_welford = BlockWelford<P_>; using block_welford = BlockWelford<P_>;
using x_block_tile = using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>( decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>())); MakeXBlockTileDistribution<Problem>()));
using mean_var_block_tile = using mean_var_block_tile =
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>()); decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
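All four policy helpers now run Welford on `ComputeDataType` (typically float) rather than the raw `XDataType`, and thread the new `Problem::Traits::kFastFDiv` flag through `BlockWelfordProblem`. For reference, a scalar sketch of the single-pass Welford update that the block/warp variants parallelize (illustrative only, not the ck_tile implementation):

// Welford's algorithm: running mean and M2 (sum of squared deviations)
// maintained in one pass over the data, numerically stabler than the
// naive sum / sum-of-squares approach.
struct WelfordState
{
    float mean  = 0.f;
    float m2    = 0.f;
    int   count = 0;

    void push(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean); // uses the already-updated mean
    }

    // population variance; the kernel applies this scaling in
    // block_tile_welford_post_scale_var
    float variance() const { return count > 0 ? m2 / count : 0.f; }
};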
...
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineOnePass
     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
     static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
     static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -87,12 +88,9 @@ struct Layernorm2dFwdPipelineOnePass
             x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
         auto y_residual_window = make_tile_window(
             y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
-        const auto x_scale_window = make_tile_window(
-            x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());

         auto x      = load_tile(x_window);
         auto x_resi = load_tile(x_residual_window);
-        auto x_scale = load_tile(x_scale_window);

         int cur_count = 0;
         int max_count =
@@ -106,28 +104,37 @@ struct Layernorm2dFwdPipelineOnePass
         const auto gamma = load_tile(gamma_window);
         const auto beta  = load_tile(beta_window);

+        auto acc = cast_tile<ComputeDataType>(x);
+
         if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                      kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
         {
             sweep_tile(x_resi, [&](auto idx) {
                 // compute x = x_resi + x
-                x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
-                         type_convert<YResidualDataType>(x(idx));
+                acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
             });
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
-                store_tile(y_residual_window, x);
+                store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
         }

         // compute welford each-thread->cross-lane->cross-warp
-        auto [mean, var] = block_welford(x, cur_count, max_count);
+        auto [mean, var] = block_welford(acc, cur_count, max_count);
         block_welford_sync(mean, var, cur_count);
         block_welford_cross_warp_sync(mean, var, cur_count, smem);
-        block_tile_welford_post_scale_var(var, cur_count);
+        block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});

         // compute inv-std
         auto inv_std = tile_elementwise_in(
             [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
+                if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
+                {
+                    return type_convert<ComputeDataType>(1.0f) *
+                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
+                }
+                else
+                {
+                    return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
+                }
             },
             var);
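When `kFastFDiv` is set and the compute type is float, the divide is replaced by `__builtin_amdgcn_rcpf`, the AMDGCN hardware reciprocal approximation (approximately 1 ulp of error; tolerable here because `epsilon` bounds the denominator away from zero). A device-side sketch of the two paths in isolation, assuming HIP:

#include <hip/hip_runtime.h>
#include <cmath>

__device__ float inv_std_exact(float var, float eps)
{
    return 1.0f / sqrtf(var + eps); // full-precision divide
}

__device__ float inv_std_fast(float var, float eps)
{
    // single v_rcp_f32 instruction: trades ~1 ulp of accuracy for
    // substantially lower latency than a correctly-rounded divide
    return __builtin_amdgcn_rcpf(sqrtf(var + eps));
}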
@@ -137,7 +144,7 @@ struct Layernorm2dFwdPipelineOnePass
             store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));

         // layernorm computation
-        auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+        auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());

         sweep_tile(ln, [&, mean_ = mean](auto idx) {
             constexpr auto i_idx = make_tuple(idx[number<0>{}]);
             constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -145,26 +152,15 @@ struct Layernorm2dFwdPipelineOnePass
             const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
             const auto beta_  = type_convert<ComputeDataType>(beta[j_idx]);

-            const auto x_ = type_convert<ComputeDataType>(x[idx]);
-            auto ln_      = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+            auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;

             ln(idx) = ln_;
         });

-        if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
-        {
-            // smooth-quant pre-scale, then run rowwise-quant
-            sweep_tile(ln, [&](auto idx) {
-                constexpr auto j_idx = make_tuple(idx[number<1>{}]);
-                const auto xs_ = type_convert<ComputeDataType>(x_scale[j_idx]);
-                ln(idx) = ln(idx) * xs_;
-            });
-        }
-
         if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
                      kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, y_scale_window, ln, smem);
+            Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
         }
         else
             Epilogue{}(y_window_, ln);
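The smooth-quant pre-scale loop is removed from the pipeline; the epilogue now receives `x_scale_window_` directly and owns both the per-column smooth scaling and the rowwise dynamic quant. A scalar sketch of what that fused step computes for one row (hypothetical helper, not the actual Epilogue):

#include <algorithm>
#include <cmath>
#include <cstdint>

// One row of length n: optional smooth-quant pre-scale by x_scale[j],
// then rowwise dynamic quantization to int8 with a per-row y_scale.
void quant_row(const float* ln, const float* x_scale, int n, bool smooth,
               std::int8_t* y_out, float& y_scale_out)
{
    float absmax = 0.f;
    for(int j = 0; j < n; ++j)
    {
        const float v = smooth ? ln[j] * x_scale[j] : ln[j];
        absmax        = std::max(absmax, std::fabs(v));
    }
    y_scale_out     = absmax / 127.f; // map row absmax onto the int8 range
    const float inv = y_scale_out > 0.f ? 1.f / y_scale_out : 0.f;
    for(int j = 0; j < n; ++j)
    {
        const float v = smooth ? ln[j] * x_scale[j] : ln[j];
        y_out[j]      = static_cast<std::int8_t>(std::lrintf(v * inv));
    }
}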
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...