Unverified Commit f49b595d authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

[CK TILE] Add gemm compute pipeline v3 (#1661)



* [CK TILE] Add gemm compute pipeline v3

* Enable universal gemm compute pipeline.

* Rename example and add compute pipeline.

* Introduce ag bg cr pipeline impl base.

* Refactor to reuse code.

* Cleaning

* Formatting.

---------
Co-authored-by: default avatarAdam Osewski <19374865+aosewski@users.noreply.github.com>
Co-authored-by: default avatarAdam Osewski <Adam.Osewski@amd.com>
parent e7b62864
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp) add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
...@@ -14,10 +14,17 @@ ...@@ -14,10 +14,17 @@
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
#if 1 #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler // Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32; constexpr ck_tile::index_t N_Tile = 32;
...@@ -30,7 +37,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -30,7 +37,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
#else
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler // Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t N_Tile = 256;
...@@ -63,8 +71,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -63,8 +71,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
...@@ -77,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -77,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
#endif
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
GemmShape, GemmShape,
Traits, Traits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile::GemmPipelineScheduler::Interwave, ck_tile::GemmPipelineScheduler::Interwave,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile::GemmPipelineScheduler::Intrawave,
#endif
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
......
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
......
...@@ -41,13 +41,16 @@ struct BlockUniversalGemmAsBsCr ...@@ -41,13 +41,16 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t MWarp = config.template at<1>(); static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>(); static constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == BlockGemmShape::BlockWarps::at(number<0>{}), using I0 = number<0>;
using I1 = number<1>;
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!"); "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(number<1>{}), static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!"); "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(number<0>{}), static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!"); "Error! WarpGemm's M is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(number<1>{}), static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!"); "Error! WarpGemm's N is not consisten with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
...@@ -99,6 +102,9 @@ struct BlockUniversalGemmAsBsCr ...@@ -99,6 +102,9 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto Scheduler = Traits::Scheduler; static constexpr auto Scheduler = Traits::Scheduler;
using I0 = number<0>;
using I1 = number<1>;
private: private:
template <GemmPipelineScheduler Scheduler, typename GemmTraits> template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl struct BlockGemmImpl
...@@ -114,35 +120,31 @@ struct BlockUniversalGemmAsBsCr ...@@ -114,35 +120,31 @@ struct BlockUniversalGemmAsBsCr
const ASmemBlockWindow& a_block_window, const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
static_assert( static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>, "The CDataType as defined in traits should be the same as correspoinding "
"The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!");
"C block tensor data type!"); static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
static_assert(std::is_same_v<typename GemmTraits::ADataType, std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
typename ASmemBlockWindow::DataType> &&
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in " "The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"); "traits should be the same as correspoinding block window data type!");
static_assert( static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
"MPerBlock, NPerBlock, KPerBlock defined in " "MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"); " BlockGemmShape are different from A/B block smem windows apropriate dims!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
// TODO: refactor warp_window tile type to class member as it should be // TODO: refactor warp_window tile type to class member as it should be
// compile-time known information. // compile-time known information.
auto a_warp_window_tmp = make_tile_window( auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}), make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
a_block_window.get_window_origin() + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0}, make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>; using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
...@@ -156,16 +158,15 @@ struct BlockUniversalGemmAsBsCr ...@@ -156,16 +158,15 @@ struct BlockUniversalGemmAsBsCr
statically_indexed_array< statically_indexed_array<
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>, statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::MIterPerWarp> MIterPerWarp>
a_warp_windows; a_warp_windows;
// construct B-warp-window // construct B-warp-window
auto b_warp_window_tmp = make_tile_window( auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}), make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
b_block_window.get_window_origin() + b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0}, make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>; using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
...@@ -179,10 +180,10 @@ struct BlockUniversalGemmAsBsCr ...@@ -179,10 +180,10 @@ struct BlockUniversalGemmAsBsCr
statically_indexed_array< statically_indexed_array<
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>, statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::NIterPerWarp> NIterPerWarp>
b_warp_windows; b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp; a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
...@@ -193,7 +194,7 @@ struct BlockUniversalGemmAsBsCr ...@@ -193,7 +194,7 @@ struct BlockUniversalGemmAsBsCr
}); });
}); });
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp; b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
...@@ -203,8 +204,8 @@ struct BlockUniversalGemmAsBsCr ...@@ -203,8 +204,8 @@ struct BlockUniversalGemmAsBsCr
}); });
}); });
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths = constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
...@@ -212,10 +213,10 @@ struct BlockUniversalGemmAsBsCr ...@@ -212,10 +213,10 @@ struct BlockUniversalGemmAsBsCr
// hot loop: // hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor- // read C warp tensor from C block tensor-
...@@ -226,7 +227,7 @@ struct BlockUniversalGemmAsBsCr ...@@ -226,7 +227,7 @@ struct BlockUniversalGemmAsBsCr
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM // warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile); WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data( c_block_tensor.set_y_sliced_thread_data(
...@@ -243,13 +244,13 @@ struct BlockUniversalGemmAsBsCr ...@@ -243,13 +244,13 @@ struct BlockUniversalGemmAsBsCr
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits> struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{ {
statically_indexed_array< statically_indexed_array<
statically_indexed_array<typename GemmTraits::AWarpTile, GemmTraits::KIterPerWarp>, statically_indexed_array<typename GemmTraits::AWarpTile, KIterPerWarp>,
GemmTraits::MIterPerWarp> MIterPerWarp>
a_warp_tiles_; a_warp_tiles_;
statically_indexed_array< statically_indexed_array<
statically_indexed_array<typename GemmTraits::BWarpTile, GemmTraits::KIterPerWarp>, statically_indexed_array<typename GemmTraits::BWarpTile, KIterPerWarp>,
GemmTraits::NIterPerWarp> NIterPerWarp>
b_warp_tiles_; b_warp_tiles_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow> template <typename ASmemBlockWindow, typename BSmemBlockWindow>
...@@ -257,30 +258,27 @@ struct BlockUniversalGemmAsBsCr ...@@ -257,30 +258,27 @@ struct BlockUniversalGemmAsBsCr
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
static_assert( static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
"MPerBlock, NPerBlock, KPerBlock defined in " "MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"); " BlockGemmShape are different from A/B block smem windows apropriate dims!");
static_assert(std::is_same_v<typename GemmTraits::ADataType, static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
typename ASmemBlockWindow::DataType> && std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in " "The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"); "traits should be the same as correspoinding block window data type!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
// TODO: refactor warp_window tile type to class member as it should be // TODO: refactor warp_window tile type to class member as it should be
// compile-time known information. // compile-time known information.
auto a_warp_window_tmp = make_tile_window( auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}), make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
a_block_window.get_window_origin() + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0}, make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>; using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
...@@ -292,18 +290,16 @@ struct BlockUniversalGemmAsBsCr ...@@ -292,18 +290,16 @@ struct BlockUniversalGemmAsBsCr
AWarpWindow{}.get_window_lengths(), AWarpWindow{}.get_window_lengths(),
"AWarpWindow lengths must be equal to AWarpTile lengths!"); "AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array< statically_indexed_array<statically_indexed_array<AWarpWindow, KIterPerWarp>,
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>, MIterPerWarp>
GemmTraits::MIterPerWarp>
a_warp_windows; a_warp_windows;
// construct B-warp-window // construct B-warp-window
auto b_warp_window_tmp = make_tile_window( auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}), make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
b_block_window.get_window_origin() + b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0}, make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>; using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
...@@ -315,13 +311,12 @@ struct BlockUniversalGemmAsBsCr ...@@ -315,13 +311,12 @@ struct BlockUniversalGemmAsBsCr
BWarpWindow{}.get_window_lengths(), BWarpWindow{}.get_window_lengths(),
"BWarpWindow lengths must be equal to BWarpTile lengths!"); "BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array< statically_indexed_array<statically_indexed_array<BWarpWindow, KIterPerWarp>,
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>, NIterPerWarp>
GemmTraits::NIterPerWarp>
b_warp_windows; b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp; a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// TODO: I don't have to move 0,0 window! // TODO: I don't have to move 0,0 window!
...@@ -331,8 +326,8 @@ struct BlockUniversalGemmAsBsCr ...@@ -331,8 +326,8 @@ struct BlockUniversalGemmAsBsCr
}); });
}); });
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp; b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter), move_tile_window(b_warp_windows(nIter)(kIter),
...@@ -341,12 +336,12 @@ struct BlockUniversalGemmAsBsCr ...@@ -341,12 +336,12 @@ struct BlockUniversalGemmAsBsCr
}); });
}); });
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window // read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
}); });
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window // read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
}); });
...@@ -359,22 +354,21 @@ struct BlockUniversalGemmAsBsCr ...@@ -359,22 +354,21 @@ struct BlockUniversalGemmAsBsCr
[[maybe_unused]] const ASmemBlockWindow& a_block_window, [[maybe_unused]] const ASmemBlockWindow& a_block_window,
[[maybe_unused]] const BSmemBlockWindow& b_block_window) [[maybe_unused]] const BSmemBlockWindow& b_block_window)
{ {
static_assert( static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>, "The CDataType as defined in traits should be the same as correspoinding "
"The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!");
"C block tensor data type!");
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths = constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop: // hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor- // read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor; CWarpTensor c_warp_tensor;
...@@ -383,9 +377,9 @@ struct BlockUniversalGemmAsBsCr ...@@ -383,9 +377,9 @@ struct BlockUniversalGemmAsBsCr
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM // warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor, WarpGemm{}(c_warp_tensor,
a_warp_tiles_[mIter][kIter], a_warp_tiles_[mIter][kIter],
b_warp_tiles_[nIter][kIter]); b_warp_tiles_[nIter][kIter]);
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data( c_block_tensor.set_y_sliced_thread_data(
...@@ -412,12 +406,12 @@ struct BlockUniversalGemmAsBsCr ...@@ -412,12 +406,12 @@ struct BlockUniversalGemmAsBsCr
statically_indexed_array< statically_indexed_array<
statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>, statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
GemmTraits::MIterPerWarp> MIterPerWarp>
a_warp_tiles_; a_warp_tiles_;
statically_indexed_array< statically_indexed_array<
statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>, statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
GemmTraits::NIterPerWarp> NIterPerWarp>
b_warp_tiles_; b_warp_tiles_;
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow> template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
...@@ -425,30 +419,28 @@ struct BlockUniversalGemmAsBsCr ...@@ -425,30 +419,28 @@ struct BlockUniversalGemmAsBsCr
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
static_assert( static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
"MPerBlock, NPerBlock, KPerBlock defined in " "MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"); " BlockGemmShape are different from A/B block smem windows apropriate dims!");
static_assert(std::is_same_v<typename GemmTraits::ADataType, static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
typename ASmemBlockWindow::DataType> && std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in " "The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"); "traits should be the same as correspoinding block window data type!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
// TODO: refactor warp_window tile type to class member as it should be // TODO: refactor warp_window tile type to class member as it should be
// compile-time known information. // compile-time known information.
auto a_warp_window_tmp = make_tile_window( auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}), make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
a_block_window.get_window_origin() + a_block_window.get_window_origin() +
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, KIdx * KPerInnerLoop}, multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop},
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{})); make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>; using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
...@@ -461,16 +453,16 @@ struct BlockUniversalGemmAsBsCr ...@@ -461,16 +453,16 @@ struct BlockUniversalGemmAsBsCr
"AWarpWindow lengths must be equal to AWarpTile lengths!"); "AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>, statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
GemmTraits::MIterPerWarp> MIterPerWarp>
a_warp_windows; a_warp_windows;
// construct B-warp-window // construct B-warp-window
auto b_warp_window_tmp = make_tile_window( auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}), make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
b_block_window.get_window_origin() + b_block_window.get_window_origin() +
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, KIdx * KPerInnerLoop}, multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop},
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{})); make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>; using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
...@@ -483,10 +475,10 @@ struct BlockUniversalGemmAsBsCr ...@@ -483,10 +475,10 @@ struct BlockUniversalGemmAsBsCr
"BWarpWindow lengths must be equal to BWarpTile lengths!"); "BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>, statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
GemmTraits::NIterPerWarp> NIterPerWarp>
b_warp_windows; b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp; a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
...@@ -496,7 +488,7 @@ struct BlockUniversalGemmAsBsCr ...@@ -496,7 +488,7 @@ struct BlockUniversalGemmAsBsCr
}); });
}); });
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp; b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
...@@ -508,11 +500,11 @@ struct BlockUniversalGemmAsBsCr ...@@ -508,11 +500,11 @@ struct BlockUniversalGemmAsBsCr
// TODO check if a_warp_tiles has same desc as a_warp_window // TODO check if a_warp_tiles has same desc as a_warp_window
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window // read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
}); });
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window // read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
}); });
...@@ -525,13 +517,12 @@ struct BlockUniversalGemmAsBsCr ...@@ -525,13 +517,12 @@ struct BlockUniversalGemmAsBsCr
const ASmemBlockWindow& a_block_window, const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
static_assert( static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>, "The CDataType as defined in traits should be the same as correspoinding "
"The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!");
"C block tensor data type!");
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths = constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
...@@ -555,8 +546,8 @@ struct BlockUniversalGemmAsBsCr ...@@ -555,8 +546,8 @@ struct BlockUniversalGemmAsBsCr
} }
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor- // read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor; CWarpTensor c_warp_tensor;
...@@ -573,17 +564,17 @@ struct BlockUniversalGemmAsBsCr ...@@ -573,17 +564,17 @@ struct BlockUniversalGemmAsBsCr
// penalty // penalty
if constexpr(kIter.value == KRepeat - 1 && if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 && kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == GemmTraits::MIterPerWarp - 1 && mIter.value == MIterPerWarp - 1 &&
nIter.value == GemmTraits::NIterPerWarp - 1) nIter.value == NIterPerWarp - 1)
{ {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
block_sync_lds(); block_sync_lds();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
// warp GEMM // warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor, WarpGemm{}(c_warp_tensor,
a_warp_tiles_[mIter][kInnerIter], a_warp_tiles_[mIter][kInnerIter],
b_warp_tiles_[nIter][kInnerIter]); b_warp_tiles_[nIter][kInnerIter]);
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data( c_block_tensor.set_y_sliced_thread_data(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmPipelineAgBgCrImplBase
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}
template <typename ADramBlockWindowTmp, typename ALdsTensorView>
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const
{
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}
template <typename BDramBlockWindowTmp, typename BLdsTensorView>
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const
{
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),
std::move(b_lds_gemm_window));
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV3
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
};
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
using Base::PrefetchStages;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) /
// sizeof(BDataType)
// ? sizeof(ComputeDataType) /
// sizeof(ADataType) : sizeof(ComputeDataType)
// / sizeof(BDataType);
constexpr auto num_mfma_stage1 =
num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A/B tiles in LDS
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_block_tile;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
block_sync_lds();
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem);
}
};
} // namespace ck_tile
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -90,7 +91,8 @@ struct BaseGemmPipelineAgBgCrMem ...@@ -90,7 +91,8 @@ struct BaseGemmPipelineAgBgCrMem
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{ {
using Base = BaseGemmPipelineAgBgCrMem<Problem>; using Base = BaseGemmPipelineAgBgCrMem<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
...@@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>; using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>; using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
...@@ -124,46 +127,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -124,46 +127,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using Base::PrefetchStages; using Base::PrefetchStages;
CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl struct PipelineImpl : public PipelineImplBase
{ {
}; };
template <> template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{ {
template <typename DstBlockTile, typename SrcTileWindow> using Base = PipelineImplBase;
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <bool HasHotLoop, template <bool HasHotLoop,
TailNumber TailNum, TailNumber TailNum,
...@@ -185,66 +162,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -185,66 +162,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate " "A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"); "([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"); " or KPerBlock!");
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// Definitions of all needed tiles // Definitions of all needed tiles
// A tile in LDS // A/B tiles in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem); // With c++20 could simplify to below line.
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); // Currently get error: captured structured bindings are a C++20 extension
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
// TODO: LDS alignment should come from Policy! auto& a_lds_block = ab_lds_blocks.at(I0{});
constexpr index_t a_lds_block_space_size_aligned = auto& b_lds_block = ab_lds_blocks.at(I1{});
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load // A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = // A LDS tile for block GEMM
make_tile_window(a_lds_block, auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), auto& a_copy_dram_window = a_windows.at(I0{});
{0, 0}, auto& a_copy_lds_window = a_windows.at(I1{});
a_copy_dram_window.get_tile_distribution()); auto& a_lds_gemm_window = a_windows.at(I2{});
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B DRAM tile window for load
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
// B LDS tile for block GEMM // B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window( auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});
// Block GEMM // Block GEMM
auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
...@@ -266,20 +215,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -266,20 +215,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch // prefetch
// global read 0 // global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
// initialize C // initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [1, PrefetchStages] // Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
}); });
// main body // main body
...@@ -295,19 +244,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -295,19 +244,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds(); block_sync_lds();
LocalPrefill( Base::LocalPrefill(
a_copy_lds_window, a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func); a_element_func);
LocalPrefill( Base::LocalPrefill(
b_copy_lds_window, b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func); b_element_func);
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window); a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window); b_copy_dram_window);
}); });
i += PrefetchStages; i += PrefetchStages;
...@@ -323,12 +272,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -323,12 +272,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds(); block_sync_lds();
LocalPrefill(a_copy_lds_window, Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}), a_block_tiles.get(number<prefetch_idx>{}),
a_element_func); a_element_func);
LocalPrefill(b_copy_lds_window, Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}), b_block_tiles.get(number<prefetch_idx>{}),
b_element_func); b_element_func);
}); });
block_sync_lds(); block_sync_lds();
...@@ -376,24 +325,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -376,24 +325,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
}; };
template <> template <>
struct PipelineImpl<GemmPipelineScheduler::Interwave> struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
{ {
template <typename DstBlockTile, typename SrcTileWindow> using Base = PipelineImplBase;
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <bool HasHotLoop, template <bool HasHotLoop,
TailNumber TailNum, TailNumber TailNum,
...@@ -415,66 +349,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -415,66 +349,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate " "A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"); "([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"); " or KPerBlock!");
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// Definitions of all needed tiles // Definitions of all needed tiles
// A tile in LDS // A/B tiles in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem); // With c++20 could simplify to below line.
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); // Currently get error: captured structured bindings are a C++20 extension
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
// TODO: LDS alignment should come from Policy! auto& a_lds_block = ab_lds_blocks.at(I0{});
constexpr index_t a_lds_block_space_size_aligned = auto& b_lds_block = ab_lds_blocks.at(I1{});
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load // A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = // A LDS tile for block GEMM
make_tile_window(a_lds_block, auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), auto& a_copy_dram_window = a_windows.at(I0{});
{0, 0}, auto& a_copy_lds_window = a_windows.at(I1{});
a_copy_dram_window.get_tile_distribution()); auto& a_lds_gemm_window = a_windows.at(I2{});
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B DRAM tile window for load
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
// B LDS tile for block GEMM // B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window( auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});
// Block GEMM // Block GEMM
auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
...@@ -496,20 +402,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -496,20 +402,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch // prefetch
// global read 0 // global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
// initialize C // initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [1, PrefetchStages] // Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
}); });
// main body // main body
...@@ -523,19 +429,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -523,19 +429,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave // no second block_sync_lds because it's interwave
LocalPrefill( Base::LocalPrefill(
a_copy_lds_window, a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func); a_element_func);
LocalPrefill( Base::LocalPrefill(
b_copy_lds_window, b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func); b_element_func);
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window); a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window); b_copy_dram_window);
}); });
i += PrefetchStages; i += PrefetchStages;
...@@ -548,12 +454,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -548,12 +454,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave // no second block_sync_lds because it's interwave
LocalPrefill(a_copy_lds_window, Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}), a_block_tiles.get(number<prefetch_idx>{}),
a_element_func); a_element_func);
LocalPrefill(b_copy_lds_window, Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}), b_block_tiles.get(number<prefetch_idx>{}),
b_element_func); b_element_func);
}); });
block_sync_lds(); block_sync_lds();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment