Unverified Commit 5d671a5f authored by Thomas Ning's avatar Thomas Ning Committed by GitHub
Browse files

CK Tile GEMM CICD fixed & register block method refactor (#1776)

* refactor the block_gemm_areg_breg_creg_v1 and add the v2 policy with 2x2 warp gemm

* Finished the 2x2 warp gemm policy and the block selection mechanism

* Clang format

* address poyen's comment

* Address feedbacks

* Fixed the compilation issue

* Change the function name
parent 0b8f117f
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
......
...@@ -8,6 +8,27 @@ ...@@ -8,6 +8,27 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template <typename DataType> template <typename DataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
......
...@@ -9,18 +9,9 @@ ...@@ -9,18 +9,9 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
...@@ -71,12 +62,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -71,12 +62,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using GemmPipelineProblem =
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -89,26 +79,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -89,26 +79,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) BDataType,
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< AccDataType,
#endif GemmShape,
ck_tile::UniversalGemmPipelineProblem<ADataType, Traits,
BDataType, scheduler,
AccDataType, has_hot_loop_v,
GemmShape, tail_number_v>;
Traits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
ck_tile::GemmPipelineScheduler::Interwave, using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) auto kargs = Kernel::MakeKernelArgs(args);
ck_tile::GemmPipelineScheduler::Intrawave,
#endif
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
......
...@@ -21,35 +21,20 @@ struct BlockGemmARegBRegCRegV1 ...@@ -21,35 +21,20 @@ struct BlockGemmARegBRegCRegV1
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
// C += A * B static constexpr index_t NPerBlock = BlockGemmShape::kN;
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor> static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
const ABlockTensor& a_block_tensor, using WG = remove_cvref_t<decltype(config.template at<0>())>;
const BBlockTensor& b_block_tensor) const static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{ {
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// M->N Warp
constexpr auto a_block_outer_dstr_encoding = constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>, tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
...@@ -57,7 +42,14 @@ struct BlockGemmARegBRegCRegV1 ...@@ -57,7 +42,14 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<1, 0>>, tuple<sequence<1, 0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr auto b_block_outer_dstr_encoding = constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>, tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>, tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
...@@ -65,7 +57,14 @@ struct BlockGemmARegBRegCRegV1 ...@@ -65,7 +57,14 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<0, 1>>, tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>, tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
...@@ -73,15 +72,28 @@ struct BlockGemmARegBRegCRegV1 ...@@ -73,15 +72,28 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<1, 1>>, tuple<sequence<1, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( return c_block_dstr_encode;
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); }
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( // C += A * B
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode();
// check ABC-block-distribution // check ABC-block-distribution
static_assert( static_assert(
...@@ -159,20 +171,6 @@ struct BlockGemmARegBRegCRegV1 ...@@ -159,20 +171,6 @@ struct BlockGemmARegBRegCRegV1
CK_TILE_DEVICE static constexpr auto MakeCBlockTile() CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{ {
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>, tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
......
...@@ -104,9 +104,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -104,9 +104,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using CLayout = remove_cvref_t<typename Problem::CLayout>; using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>; using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>; using I0 = number<0>;
using I2 = number<2>; using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
......
...@@ -23,6 +23,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -23,6 +23,8 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLayout = remove_cvref_t<typename Problem::BLayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>; using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kMPerBlock = BlockGemmShape::kM;
...@@ -126,7 +128,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -126,7 +128,7 @@ struct GemmPipelineAGmemBGmemCRegV1
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}); b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM // Block GEMM
auto block_gemm = Policy::template GetBlockGemm<Problem>(); auto block_gemm = BlockGemm();
// Acc register tile // Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
......
...@@ -12,8 +12,11 @@ namespace ck_tile { ...@@ -12,8 +12,11 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead // Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = false; static constexpr bool TransposeC = true;
#if 0 #if 0
// 2d // 2d
...@@ -491,10 +494,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -491,10 +494,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
using AccDataType = float; using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpTile = typename Problem::BlockGemmShape::WarpTile;
......
...@@ -11,7 +11,6 @@ namespace ck_tile { ...@@ -11,7 +11,6 @@ namespace ck_tile {
// UniversalGemm Policy // UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy struct UniversalGemmPipelineAgBgCrPolicy
{ {
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{}; static constexpr auto I2 = number<2>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment