Unverified Commit b6bcd76d authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676)

* Block universal gemm.

* Universal block gemm with interwave scheduler - draft.

* Refactoring

* Move a/b_warp_tiles into BlockGemmImpl
* set BlockGemmImpl as a class member

* Change tile size for more suitable to memory bound cases.

* Introduce kKPerThread to WarpGemm

* Add documentation comment.

* Fix Interwave scheduler block gemm.

* Add compute/memory friendly tile configuration.

* Clean

* New tile configurations in gemm mem example.

* Add more static checks and fix loop order in block gemm.

* Add more static checks and use warp gemm mfma dispatcher.

* Add default scheduler block gemm.

* Remove logging in example.
parent 440e28b0
...@@ -261,7 +261,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -261,7 +261,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
if(config.time_kernel) if(config.time_kernel)
{ {
ave_time = ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4}); invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4});
std::size_t flop = 2_uz * M * N * K; std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -17,9 +17,24 @@ ...@@ -17,9 +17,24 @@
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
// ToDo: This will be modified by the codegen code later. #if 1
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
#else
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
...@@ -28,12 +43,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -28,12 +43,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false;
constexpr bool kPadM = true; constexpr bool kPadN = false;
constexpr bool kPadN = true; constexpr bool kPadK = false;
constexpr bool kPadK = true;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -174,8 +189,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -174,8 +189,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
std::ostringstream err; std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! " << __FILE__ << ":" << __LINE__ << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< ", in function: " << __func__; << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
} }
......
...@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time = gemm_calc<ALayout, BLayout, CLayout>( float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Gemm{MemBoundPipeline}"};
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl; << std::endl;
...@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc, ...@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
f_host_tensor_descriptor(M, N, stride_C, CLayout{})); f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
// TODO: add different init types // TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n); ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
...@@ -202,14 +199,15 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -202,14 +199,15 @@ int run_gemm_example(int argc, char* argv[])
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "C") // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
{ // work. else if(a_layout == "C" && b_layout == "C")
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); // {
} // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
else if(a_layout == "C" && b_layout == "R") // }
{ // else if(a_layout == "C" && b_layout == "R")
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); // {
} // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else else
{ {
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
......
...@@ -247,7 +247,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -247,7 +247,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM // Block GEMM
constexpr auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile(); auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
...@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{ {
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds(); block_sync_lds();
...@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds(); block_sync_lds();
...@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
}); });
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}; };
if constexpr(TailNum == TailNumber::One) if constexpr(TailNum == TailNumber::One)
{ {
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
} }
else if constexpr(TailNum == TailNumber::Two) else if constexpr(TailNum == TailNumber::Two)
......
...@@ -11,6 +11,7 @@ namespace ck_tile { ...@@ -11,6 +11,7 @@ namespace ck_tile {
enum struct GemmPipelineScheduler enum struct GemmPipelineScheduler
{ {
Default,
Intrawave, Intrawave,
Interwave, Interwave,
}; };
...@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch ...@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
{ {
switch(s) switch(s)
{ {
case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break; case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break; case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
default: os << ""; default: os << "";
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -52,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -52,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
...@@ -264,6 +266,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -264,6 +266,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1); constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
...@@ -277,6 +282,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -277,6 +282,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
constexpr index_t M0 = BlockSize / get_warp_size(); constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0); constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
...@@ -350,6 +358,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -350,6 +358,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1); constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
...@@ -364,7 +375,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -364,7 +375,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
constexpr index_t N0 = BlockSize / get_warp_size(); constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0); constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
...@@ -475,9 +488,28 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -475,9 +488,28 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; constexpr bool TransposeC = false;
constexpr auto I0 = number<0>{};
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{}; constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
} }
}; };
......
...@@ -33,6 +33,8 @@ struct GemmPipelineProblemBase ...@@ -33,6 +33,8 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = GemmTraits::kPadN; static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK; static constexpr bool kPadK = GemmTraits::kPadK;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{ {
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -24,6 +24,7 @@ struct WarpGemmAtrributeMfma ...@@ -24,6 +24,7 @@ struct WarpGemmAtrributeMfma
static constexpr index_t kM = Impl::kM; static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN; static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK; static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
...@@ -89,6 +90,7 @@ struct WarpGemmAtrributeMfmaIterateK ...@@ -89,6 +90,7 @@ struct WarpGemmAtrributeMfmaIterateK
static constexpr index_t kM = Impl::kM; static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN; static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter; static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
...@@ -200,6 +202,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution ...@@ -200,6 +202,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
static constexpr index_t kM = Impl::kN; static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM; static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK; static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
...@@ -263,6 +266,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB ...@@ -263,6 +266,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
static constexpr index_t kM = Impl::kN; static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM; static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK; static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
...@@ -333,6 +337,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution ...@@ -333,6 +337,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
static constexpr index_t kM = Impl::kN; static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM; static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter; static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
...@@ -447,6 +452,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB ...@@ -447,6 +452,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
static constexpr index_t kM = Impl::kN; static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM; static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter; static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
...@@ -586,6 +592,7 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA ...@@ -586,6 +592,7 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
static constexpr index_t kM = Impl::kM; static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN; static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter; static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,6 +14,11 @@ struct WarpGemmImpl ...@@ -14,6 +14,11 @@ struct WarpGemmImpl
static constexpr index_t kM = WarpGemmAttribute::kM; static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN; static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK; static constexpr index_t kK = WarpGemmAttribute::kK;
/// @brief The number of elements in K dimension processed by single thread in wavefront.
///
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
/// In such situation this value reflects this fact.
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
using ADataType = typename WarpGemmAttribute::ADataType; using ADataType = typename WarpGemmAttribute::ADataType;
using BDataType = typename WarpGemmAttribute::BDataType; using BDataType = typename WarpGemmAttribute::BDataType;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment