"git@developer.sourcefind.cn:gaoqiong/yaml-cpp.git" did not exist on "1f4d8ee3b465af74b9c2b519601ba4c64d25938c"
Commit 6db81a11 authored by ThomasNing

Address the comments

parent b2c7d774
@@ -11,29 +11,29 @@
 #include "ck_tile/ops/epilogue.hpp"
 #include "ck_tile/ops/gemm.hpp"
-#define CK_TILE_PIPELINE_COMPUTE 1
+#define CK_TILE_PIPELINE_COMPUTE_V3 1
 #define CK_TILE_PIPELINE_MEMORY 2
-#define CK_TILE_PIPELINE_COMPUTE_V2 3
+#define CK_TILE_PIPELINE_COMPUTE_V4 3
 #ifndef CK_TILE_PIPELINE_DEFAULT
-#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
+#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4
 #endif
 #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
 #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
-constexpr bool isDoubleSmemBuffer = false;
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+constexpr bool DoubleSmemBuffer = false;
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
 #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
-constexpr bool isDoubleSmemBuffer = false;
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
+constexpr bool DoubleSmemBuffer = false;
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
 #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
-constexpr bool isDoubleSmemBuffer = true;
+constexpr bool DoubleSmemBuffer = true;
 #else
 #error "unsupported CK_TILE_PIPELINE_DEFAULT value"
 #endif
...
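The macro values above make COMPUTE_V3 = 1, MEMORY = 2 and COMPUTE_V4 = 3, with V4 now the default. As a hedged usage sketch (this override is illustrative, not part of this commit), a consumer can still pin an older pipeline before the #ifndef guard takes effect:

    // Illustrative only: select the memory pipeline instead of the default V4.
    #define CK_TILE_PIPELINE_DEFAULT 2 // == CK_TILE_PIPELINE_MEMORY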
@@ -29,7 +29,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     constexpr ck_tile::index_t N_Warp_Tile = 32;
     constexpr ck_tile::index_t K_Warp_Tile = 8;
 #endif
-#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
     // Compute friendly for Intrawave scheduler
     constexpr ck_tile::index_t M_Tile = 256;
     constexpr ck_tile::index_t N_Tile = 256;
@@ -43,7 +43,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     constexpr ck_tile::index_t N_Warp_Tile = 32;
     constexpr ck_tile::index_t K_Warp_Tile = 16;
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
     // Compute friendly for Intrawave scheduler
     // Uses the ping-pong reader at the LDS level
     constexpr ck_tile::index_t M_Tile = 256;
@@ -78,12 +78,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     using GemmEpilogue = ck_tile::Default2DEpilogue<
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
-    using Traits =
-        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, isDoubleSmemBuffer, ALayout, BLayout, CLayout>;
+    using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                  kPadN,
                                                                  kPadK,
-                                                                 isDoubleSmemBuffer,
+                                                                 DoubleSmemBuffer,
                                                                  ALayout,
                                                                  BLayout,
                                                                  CLayout,
@@ -142,7 +141,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     if(has_hot_loop)
     {
-#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
         if(tail_num == ck_tile::TailNumber::Full)
         {
             Run(ck_tile::bool_constant<true>{},
@@ -217,72 +216,75 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
         }
     }
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
-        if constexpr(BaseGemmPipeline::PrefetchStages > 2)
-        {
-            if(tail_num == ck_tile::TailNumber::Two)
-            {
-                Run(ck_tile::bool_constant<true>{},
-                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
-            }
-        }
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
+        if(tail_num == ck_tile::TailNumber::Three)
+        {
+            Run(ck_tile::bool_constant<true>{},
+                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
+        }
+        else
+        {
+            Run(ck_tile::bool_constant<true>{},
+                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
+        }
 #endif
     }
     else
     {
         // Tail number always Full - #PrefetchStages
         if(tail_num == ck_tile::TailNumber::Full)
         {
             Run(ck_tile::bool_constant<false>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
         }
         else
         {
             std::ostringstream err;
             err << "When there's no hot loop, this tail number \"" << tail_num
                 << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
                 << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
             throw std::runtime_error(err.str());
         }
     }
     return ave_time;
 }
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[]) int run_gemm_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
if(!result) if(!result)
return -1; return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout"); std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout"); std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R") if(a_layout == "R" && b_layout == "R")
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
} }
else if(a_layout == "R" && b_layout == "C") else if(a_layout == "R" && b_layout == "C")
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "C") else if(a_layout == "C" && b_layout == "C")
{ {
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "R") else if(a_layout == "C" && b_layout == "R")
{ {
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
} }
else else
{ {
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); throw std::runtime_error(
"Unsupported data layout configuration for A,B and C tensors!");
}
} }
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
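A quick usage note, hedged: the binary name and flag syntax below are assumptions based on the arg-parser keys above, not shown in this diff. The example dispatches on -a_layout / -b_layout values of "R" or "C", so an invocation like ./tile_example_gemm_universal -a_layout=R -b_layout=C would exercise the Row x Col x Row path.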
@@ -467,21 +467,24 @@ struct GemmKernel
      * @param a_ptr input A pointer
      * @param b_ptr input B pointer
      * @param c_ptr output C pointer
+     * @param smem_ptr_0 The starting pointer of the shared memory block.
      * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset The offset of this workgroup's K batch when the K dimension
+     *                            is split across more than one batch (split-K).
      * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
      * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
      *
      * @tparam DstInMemOp Destination memory operation (default: set).
      */
     template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
-    CK_TILE_DEVICE static void RunGemmSinglePointer(const ADataType* a_ptr,
-                                                    const BDataType* b_ptr,
-                                                    CDataType* c_ptr,
-                                                    void* smem_ptr_0,
-                                                    const GemmKernelArgs& kargs,
-                                                    const SplitKBatchOffset& splitk_batch_offset,
-                                                    const index_t block_idx_m,
-                                                    const index_t block_idx_n)
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       CDataType* c_ptr,
+                                       void* smem_ptr_0,
+                                       const GemmKernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
     {
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
@@ -521,22 +524,26 @@ struct GemmKernel
      * @param a_ptr input A pointer
      * @param b_ptr input B pointer
      * @param c_ptr output C pointer
+     * @param smem_ptr_0 The starting pointer of the 1st shared memory block.
+     * @param smem_ptr_1 The starting pointer of the 2nd shared memory block.
      * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset The offset of this workgroup's K batch when the K dimension
+     *                            is split across more than one batch (split-K).
      * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
      * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
      *
      * @tparam DstInMemOp Destination memory operation (default: set).
      */
     template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
-    CK_TILE_DEVICE static void RunGemmDoublePointer(const ADataType* a_ptr,
-                                                    const BDataType* b_ptr,
-                                                    CDataType* c_ptr,
-                                                    void* smem_ptr_0,
-                                                    void* smem_ptr_1,
-                                                    const GemmKernelArgs& kargs,
-                                                    const SplitKBatchOffset& splitk_batch_offset,
-                                                    const index_t block_idx_m,
-                                                    const index_t block_idx_n)
+    CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
+                                           const BDataType* b_ptr,
+                                           CDataType* c_ptr,
+                                           void* smem_ptr_0,
+                                           void* smem_ptr_1,
+                                           const GemmKernelArgs& kargs,
+                                           const SplitKBatchOffset& splitk_batch_offset,
+                                           const index_t block_idx_m,
+                                           const index_t block_idx_n)
     {
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
@@ -590,41 +597,40 @@ struct GemmKernel
         if(kargs.KBatch == 1)
         {
-            if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
+            if constexpr(GemmPipeline::DoubleSmemBuffer == true)
             {
-                RunGemmDoublePointer(a_ptr,
-                                     b_ptr,
-                                     c_ptr,
-                                     smem_ptr_0,
-                                     smem_ptr_1,
-                                     kargs,
-                                     splitk_batch_offset,
-                                     i_m,
-                                     i_n);
+                RunGemm2LDS(a_ptr,
+                            b_ptr,
+                            c_ptr,
+                            smem_ptr_0,
+                            smem_ptr_1,
+                            kargs,
+                            splitk_batch_offset,
+                            i_m,
+                            i_n);
             }
             else
             {
-                RunGemmSinglePointer(
-                    a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+                RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
             }
         }
         else
         {
-            if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
+            if constexpr(GemmPipeline::DoubleSmemBuffer == true)
             {
-                RunGemmDoublePointer<memory_operation_enum::atomic_add>(a_ptr,
-                                                                        b_ptr,
-                                                                        c_ptr,
-                                                                        smem_ptr_0,
-                                                                        smem_ptr_1,
-                                                                        kargs,
-                                                                        splitk_batch_offset,
-                                                                        i_m,
-                                                                        i_n);
+                RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
+                                                               b_ptr,
+                                                               c_ptr,
+                                                               smem_ptr_0,
+                                                               smem_ptr_1,
+                                                               kargs,
+                                                               splitk_batch_offset,
+                                                               i_m,
+                                                               i_n);
             }
             else
             {
-                RunGemmSinglePointer<memory_operation_enum::atomic_add>(
+                RunGemm<memory_operation_enum::atomic_add>(
                     a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
             }
         }
...
@@ -70,7 +70,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
     static constexpr bool kPadN = Problem::kPadN;
     static constexpr bool kPadK = Problem::kPadK;
-    static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;
+    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
     static constexpr bool HasHotLoop = Problem::HasHotLoop;
     static constexpr auto TailNum = Problem::TailNum;
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
@@ -15,7 +15,7 @@ namespace ck_tile {
 template <typename Problem>
 struct BaseGemmPipelineAgBgCrCompV4
 {
-    static constexpr index_t PrefetchStages = 3;
+    static constexpr index_t PrefetchStages = 2;
     static constexpr index_t PrefillStages = 1;
     static constexpr index_t GlobalBufferNum = 1;
@@ -26,11 +26,23 @@ struct BaseGemmPipelineAgBgCrCompV4
     CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
     {
-        ignore = num_loop;
-        return TailNumber::Two;
+        if(num_loop % PrefetchStages == 1)
+        {
+            return TailNumber::Three;
+        }
+        else
+        {
+            return TailNumber::Two;
+        }
     }
 };
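With PrefetchStages reduced to 2, the tail number now depends only on the parity of the main-loop count, which is why the dispatch in gemm_calc earlier in this commit handles exactly the Three and Two cases. A worked trace with made-up loop counts (values are illustrative, not taken from this commit):

    // Illustrative values only:
    // num_loop = 7  ->  7 % 2 == 1  ->  TailNumber::Three
    // num_loop = 8  ->  8 % 2 == 0  ->  TailNumber::Two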
+// Compute optimized pipeline version 4.
+// It differs from compute version 3 in that it uses two LDS windows as a ping-pong buffer for
+// global memory loads: while one LDS buffer is being filled from global memory, the warps run
+// the MFMA matrix multiplication on the other. For larger matrix shapes this keeps the warps
+// busy and hides the global memory load latency.
 template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy>
 struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
 {
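The ping-pong scheme described in that comment can be summarized in a few lines. This is a hedged sketch only: load_tile_to_lds and block_gemm_from_lds are placeholder names, not ck_tile API, and synchronization is elided.

    // Sketch only: placeholder names, synchronization elided.
    void* lds[2] = {smem_ptr_0, smem_ptr_1};
    int cur = 0;
    load_tile_to_lds(lds[cur]); // prefill the first buffer
    for(int i = 0; i < num_loop - 1; ++i)
    {
        load_tile_to_lds(lds[cur ^ 1]); // fill the other buffer from global memory...
        block_gemm_from_lds(lds[cur]);  // ...while the MFMA warps consume this one
        cur ^= 1;
    }
    block_gemm_from_lds(lds[cur]); // drain the last filled buffer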
@@ -65,7 +77,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
     static constexpr bool kPadN = Problem::kPadN;
     static constexpr bool kPadK = Problem::kPadK;
-    static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;
+    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
     static constexpr bool HasHotLoop = Problem::HasHotLoop;
     static constexpr auto TailNum = Problem::TailNum;
@@ -128,13 +140,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
                 ? B_LDS_Read_Inst_Num
                 : B_LDS_Read_Inst_Num / 2;
         constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
-
         constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
-
         constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
-
         constexpr auto num_issue = num_buffer_load_inst;
         static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
             ignore = i;
@@ -170,7 +179,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
             std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                 std::is_same_v<BDataType,
                                remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
-            "wrong!");
+            "Data Type conflict on A and B matrix input data type.");
         constexpr bool is_a_col_major =
             std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
@@ -476,7 +485,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
                 __builtin_amdgcn_sched_barrier(0);
             }
         }
-        else if(TailNum == TailNumber::Two)
+        else
         {
             // 2
             {
@@ -497,13 +506,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
                 __builtin_amdgcn_sched_barrier(0);
             }
         }
-        else // when tail num is one
-        {
-            {
-                block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
-                __builtin_amdgcn_sched_barrier(0);
-            }
-        }
         return c_block_tile;
     }
 };
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
 namespace ck_tile {
-// Default policy for GemmPipelineAGmemBGmemCRegV1
-// Default policy class should not be templated, put template on member functions instead
-struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
+// Default policy for GemmPipelineAGmemBGmemCregComputeV4. Except for the block gemm method, it
+// shares the same vector-size implementation, SmemSize, and global memory tile distribution as
+// the UniversalGemm pipeline policy.
+// Default policy class should not be templated; put templates on member functions instead.
+struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBasePolicy
 {
-    static constexpr auto I0 = number<0>{};
-    static constexpr auto I1 = number<1>{};
-    static constexpr auto I2 = number<2>{};
-
-    static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked;
-    static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked;
-
-    /**
-     * @brief Get the maximum global memory vector load size.
-     *
-     * @tparam Problem The UniversalGemmPipelineProblem object.
-     * @tparam DataType The tensor data type we're considering.
-     * @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
-     * @tparam XPerTile The contiguous Tile dimension size.
-     * @return Maximum DRAM vector load size.
-     */
-    template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
-    CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
-    {
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
-        // Assume DataType is even!
-        if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
-                     elements_per_thread % (16 / sizeof(DataType)) == 0)
-        {
-            return (16 / sizeof(DataType));
-        }
-        else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (8 / sizeof(DataType)) == 0)
-        {
-            return (8 / sizeof(DataType));
-        }
-        else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (4 / sizeof(DataType)) == 0)
-        {
-            return (4 / sizeof(DataType));
-        }
-        else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (2 / sizeof(DataType)) == 0)
-        {
-            return (2 / sizeof(DataType));
-        }
-        else
-        {
-            return 1;
-        }
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
-    {
-        using ALayout = remove_cvref_t<typename Problem::ALayout>;
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
-        }
-        else
-        {
-            return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
-        }
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
-    {
-        using BLayout = remove_cvref_t<typename Problem::BLayout>;
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
-        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
-        }
-        else
-        {
-            return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
-        }
-    }
-
-    /**
-     * @brief Get the vector store size for C tensor.
-     *
-     * @tparam Problem - Gemm pipeline problem class.
-     *
-     * @note The vector store size for output C tensor would depend on multiple factors
-     *       like its data layout and warp gemm C transposition. In general it would
-     *       be the number of consecutive elements in contiguous C dimension hold by
-     *       single thread.
-     *
-     * @return The vector store size for C tensor.
-     */
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
     {
@@ -125,8 +31,6 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
         {
             if constexpr(TransposeC)
             {
-                // In this case each thread has multiple consecutive elements in
-                // N dimension, however consecutive threads' elements have stride.
                 constexpr index_t NDimY = CWarpDstr::NDimY;
                 constexpr auto c_warp_y_lengths =
                     CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
@@ -136,7 +40,6 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
             }
             else
             {
-                // In this case each thread has just a single item in Ndim
                 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
             }
         }
@@ -145,13 +48,10 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
         {
             if constexpr(TransposeC)
             {
-                // In this case each thread has just a single item in Mdim
                 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
             }
             else
             {
-                // In this case each thread has multiple consecutive elements in
-                // M dimension, however consecutive threads' elements have stride.
                 constexpr index_t NDimY = CWarpDstr::NDimY;
                 constexpr auto c_warp_y_lengths =
                     CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
@@ -189,41 +89,43 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
         constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t KPack = GetSmemPackA<Problem>();

-        // TODO: this 8 is AK1! should be a policy parameter!
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
-            make_tuple(number<kMPerBlock * 8>{}, number<8>{}, number<1>{}),
-            number<8>{},
+            make_tuple(number<kKPerBlock / KPack>{}, number<kMPerBlock>{}, number<KPack>{}),
+            make_tuple(number<kMPerBlock * KPack>{}, number<KPack>{}, number<1>{}),
+            number<KPack>{},
             number<1>{});

         constexpr auto a_lds_block_desc = transform_tensor_descriptor(
             a_lds_block_desc_0,
-            make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
-                       make_merge_transform(make_tuple(number<kKPerBlock>{} / 8, number<8>{}))),
+            make_tuple(
+                make_pass_through_transform(number<kMPerBlock>{}),
+                make_merge_transform(make_tuple(number<kKPerBlock>{} / KPack, number<KPack>{}))),
             make_tuple(sequence<1>{}, sequence<0, 2>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));

         return a_lds_block_desc;
     }

-    // 3d + padding
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
     {
         constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t KPack = GetSmemPackB<Problem>();

         constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
-            make_tuple(number<(kNPerBlock)*8>{}, number<8>{}, number<1>{}),
-            number<8>{},
+            make_tuple(number<kKPerBlock / KPack>{}, number<kNPerBlock>{}, number<KPack>{}),
+            make_tuple(number<(kNPerBlock)*KPack>{}, number<KPack>{}, number<1>{}),
+            number<KPack>{},
             number<1>{});

         constexpr auto b_lds_block_desc = transform_tensor_descriptor(
             b_lds_block_desc_0,
-            make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
-                       make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
+            make_tuple(
+                make_pass_through_transform(number<kNPerBlock>{}),
+                make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
             make_tuple(sequence<1>{}, sequence<0, 2>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));
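To make the KPack parameterization concrete, a hedged numeric example (tile sizes are made up; the old code hardcoded the pack width to 8, per the removed AK1 TODO): with kMPerBlock = 256, kKPerBlock = 32 and KPack = 8, the A block is laid out in LDS as a (kK/KPack, kM, KPack) 3-D buffer, which the transform then merges back into an M x K view.

    // Hypothetical sizes, illustration only:
    // naive descriptor: (32 / 8, 256, 8) = (4, 256, 8), KPack contiguous innermost
    // merged view:      (256, 4 * 8)     = (256, 32)    M x K, as the pipeline reads it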
@@ -259,112 +161,6 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
         return smem_size_a + smem_size_b;
     }

-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
-    {
-        using ALayout = remove_cvref_t<typename Problem::ALayout>;
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
-        // Tile: MPerBlock X KPerBlock
-        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          MPerBlock,
-                                                                          KPerBlock,
-                                                                          VecLoadSize,
-                                                                          ATileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-        // Tile: KPerBlock X MPerBlock
-        else
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          KPerBlock,
-                                                                          MPerBlock,
-                                                                          VecLoadSize,
-                                                                          ATileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
-    {
-        using BLayout = remove_cvref_t<typename Problem::BLayout>;
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
-        // Tile: KPerBlock X NPerBlock
-        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          KPerBlock,
-                                                                          NPerBlock,
-                                                                          VecLoadSize,
-                                                                          BTileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-        // Tile: NPerBlock X KPerBlock
-        else
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          NPerBlock,
-                                                                          KPerBlock,
-                                                                          VecLoadSize,
-                                                                          BTileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
-    {
-        using ALayout = remove_cvref_t<typename Problem::ALayout>;
-        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
-        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                      KPerBlock,
-                                                                      MPerBlock,
-                                                                      VecLoadSize,
-                                                                      ATileAccessPattern>;
-        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
-    {
-        using BLayout = remove_cvref_t<typename Problem::BLayout>;
-        static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
-        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                      KPerBlock,
-                                                                      NPerBlock,
-                                                                      VecLoadSize,
-                                                                      BTileAccessPattern>;
-        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
-    {
-        return Problem::TransposeC;
-    }
-
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
     {
...
@@ -163,7 +163,7 @@ struct UniversalGemmPipelineProblem
     static constexpr bool kPadN = Traits::kPadN;
    static constexpr bool kPadK = Traits::kPadK;
-    static constexpr bool isDoubleSmemBuffer = Traits::isDoubleSmemBuffer;
+    static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
     static constexpr auto Scheduler = Scheduler_;
     static constexpr auto HasHotLoop = HasHotLoop_;
...
@@ -9,8 +9,7 @@
 namespace ck_tile {

-// UniversalGemm Policy
-struct UniversalGemmPipelineAgBgCrPolicy
+struct UniversalGemmBasePolicy
 {
     static constexpr auto I0 = number<0>{};
     static constexpr auto I1 = number<1>{};
@@ -19,15 +18,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
     static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked;
     static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked;

-    /**
-     * @brief Get the maximum global memory vector load size.
-     *
-     * @tparam Problem The UniversalGemmPipelineProblem object.
-     * @tparam DataType The tensor data type we're considering.
-     * @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
-     * @tparam XPerTile The contiguous Tile dimension size.
-     * @return Maximum DRAM vector load size.
-     */
     template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
     CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
     {
@@ -98,18 +88,117 @@ struct UniversalGemmPipelineAgBgCrPolicy
         }
     }

-    /**
-     * @brief Get the vector store size for C tensor.
-     *
-     * @tparam Problem - Gemm pipeline problem class.
-     *
-     * @note The vector store size for output C tensor would depend on multiple factors
-     *       like its data layout and warp gemm C transposition. In general it would
-     *       be the number of consecutive elements in contiguous C dimension hold by
-     *       single thread.
-     *
-     * @return The vector store size for C tensor.
-     */
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
+    {
+        return Problem::TransposeC;
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
+    {
+        using ALayout = remove_cvref_t<typename Problem::ALayout>;
+
+        constexpr index_t BlockSize = Problem::kBlockSize;
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
+
+        // Tile: MPerBlock X KPerBlock
+        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
+                                                                          MPerBlock,
+                                                                          KPerBlock,
+                                                                          VecLoadSize,
+                                                                          ATileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
+        // Tile: KPerBlock X MPerBlock
+        else
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
+                                                                          KPerBlock,
+                                                                          MPerBlock,
+                                                                          VecLoadSize,
+                                                                          ATileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
+    {
+        using BLayout = remove_cvref_t<typename Problem::BLayout>;
+
+        constexpr index_t BlockSize = Problem::kBlockSize;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
+
+        // Tile: KPerBlock X NPerBlock
+        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
+                                                                          KPerBlock,
+                                                                          NPerBlock,
+                                                                          VecLoadSize,
+                                                                          BTileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
+        // Tile: NPerBlock X KPerBlock
+        else
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
+                                                                          NPerBlock,
+                                                                          KPerBlock,
+                                                                          VecLoadSize,
+                                                                          BTileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
+    {
+        using ALayout = remove_cvref_t<typename Problem::ALayout>;
+        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
+
+        constexpr index_t BlockSize = Problem::kBlockSize;
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
+
+        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
+                                                                      KPerBlock,
+                                                                      MPerBlock,
+                                                                      VecLoadSize,
+                                                                      ATileAccessPattern>;
+        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
+    {
+        using BLayout = remove_cvref_t<typename Problem::BLayout>;
+        static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
+
+        constexpr index_t BlockSize = Problem::kBlockSize;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
+
+        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
+                                                                      KPerBlock,
+                                                                      NPerBlock,
+                                                                      VecLoadSize,
+                                                                      BTileAccessPattern>;
+        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
+    }
+};
+
+// UniversalGemm Policy
+struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy
+{
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
     {
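The hunk above is a straightforward pull-up refactor: IsTransposeC and the DRAM/shuffled tile-distribution builders now live once in UniversalGemmBasePolicy, and the concrete policies inherit them. In outline (a sketch with member lists abbreviated, not the full declarations):

    // Outline of the new policy hierarchy (abbreviated).
    struct UniversalGemmBasePolicy
    {
        // shared: GetGlobalVectorLoadSize, GetVectorSizeA/B, IsTransposeC,
        // MakeADramTileDistribution, MakeBDramTileDistribution, shuffled variants
    };
    struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy { /* ... */ };
    struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy : public UniversalGemmBasePolicy { /* ... */ };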
@@ -125,8 +214,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
         {
             if constexpr(TransposeC)
             {
-                // In this case each thread has multiple consecutive elements in
-                // N dimension, however consecutive threads' elements have stride.
                 constexpr index_t NDimY = CWarpDstr::NDimY;
                 constexpr auto c_warp_y_lengths =
                     CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
@@ -136,7 +223,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
             }
             else
             {
-                // In this case each thread has just a single item in Ndim
                 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
             }
         }
@@ -145,13 +231,10 @@ struct UniversalGemmPipelineAgBgCrPolicy
         {
             if constexpr(TransposeC)
             {
-                // In this case each thread has just a single item in Mdim
                 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
             }
             else
             {
-                // In this case each thread has multiple consecutive elements in
-                // M dimension, however consecutive threads' elements have stride.
                 constexpr index_t NDimY = CWarpDstr::NDimY;
                 constexpr auto c_warp_y_lengths =
                     CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
@@ -425,16 +508,20 @@ struct UniversalGemmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
-        constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
-                                        MakeALdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr index_t smem_size_a =
+            integer_least_multiple(sizeof(typename Problem::ADataType) *
+                                       MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
+                                   16);
         return smem_size_a;
     }

     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
     {
-        constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
-                                        MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr index_t smem_size_b =
+            integer_least_multiple(sizeof(typename Problem::BDataType) *
+                                       MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
+                                   16);
         return smem_size_b;
     }
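Rounding each block's byte size up to a multiple of 16 keeps whatever is placed right after it in LDS (the B block, and with the V4 pipeline the second ping-pong buffer) 16-byte aligned. A quick numeric check with made-up sizes:

    // Illustration only, byte counts invented:
    // integer_least_multiple(4096, 16) == 4096   // already aligned, unchanged
    // integer_least_multiple(4100, 16) == 4112   // rounded up to the next multiple of 16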
@@ -443,116 +530,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();

-        index_t smem_size = 0;
-        smem_size += smem_size_a + smem_size_b;
-        return smem_size;
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
-    {
-        using ALayout = remove_cvref_t<typename Problem::ALayout>;
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
-        // Tile: MPerBlock X KPerBlock
-        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          MPerBlock,
-                                                                          KPerBlock,
-                                                                          VecLoadSize,
-                                                                          ATileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-        // Tile: KPerBlock X MPerBlock
-        else
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          KPerBlock,
-                                                                          MPerBlock,
-                                                                          VecLoadSize,
-                                                                          ATileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
-    {
-        using BLayout = remove_cvref_t<typename Problem::BLayout>;
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
-        // Tile: KPerBlock X NPerBlock
-        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          KPerBlock,
-                                                                          NPerBlock,
-                                                                          VecLoadSize,
-                                                                          BTileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-        // Tile: NPerBlock X KPerBlock
-        else
-        {
-            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                          NPerBlock,
-                                                                          KPerBlock,
-                                                                          VecLoadSize,
-                                                                          BTileAccessPattern>;
-            return TileEncodingPattern::Make2DStaticTileDistribution();
-        }
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
-    {
-        using ALayout = remove_cvref_t<typename Problem::ALayout>;
-        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
-        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                      KPerBlock,
-                                                                      MPerBlock,
-                                                                      VecLoadSize,
-                                                                      ATileAccessPattern>;
-        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
-    {
-        using BLayout = remove_cvref_t<typename Problem::BLayout>;
-        static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
-        constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
-        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
-                                                                      KPerBlock,
-                                                                      NPerBlock,
-                                                                      VecLoadSize,
-                                                                      BTileAccessPattern>;
-        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
-    }
-
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
-    {
-        return Problem::TransposeC;
-    }
+        return smem_size_a + smem_size_b;
+    }

     template <typename Problem>
...
@@ -10,7 +10,6 @@ namespace ck_tile {
 template <bool kPadM_,
           bool kPadN_,
           bool kPadK_,
-          bool isDoubleSmemBuffer_,
           typename ALayout_,
           typename BLayout_,
           typename CLayout_>
@@ -20,8 +19,6 @@ struct TileGemmTraits
     static constexpr bool kPadN = kPadN_;
     static constexpr bool kPadK = kPadK_;

-    static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
-
     // TODO this can't be hardcoded here! Should be in policy!
     static constexpr int _VectorSize = 16;
@@ -35,7 +32,7 @@ struct TileGemmTraits
 template <bool kPadM_,
           bool kPadN_,
           bool kPadK_,
-          bool isDoubleSmemBuffer_,
+          bool DoubleSmemBuffer_,
           typename ALayout_,
           typename BLayout_,
           typename CLayout_,
@@ -46,7 +43,7 @@ struct TileGemmUniversalTraits
     static constexpr bool kPadN = kPadN_;
     static constexpr bool kPadK = kPadK_;
-    static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
+    static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
     using ALayout = ALayout_;
     using BLayout = BLayout_;
...
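After this change only TileGemmUniversalTraits carries the double-LDS flag; plain TileGemmTraits loses it entirely. A hedged instantiation sketch (padding and layout values are arbitrary examples; TileGemmUniversalTraits has at least one more template parameter after CLayout_, visible from the trailing comma above, which is elided here):

    // Example only; template argument values are illustrative.
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
    using Traits          = ck_tile::TileGemmTraits<true, true, true, Row, Col, Row>;
    using UniversalTraits = ck_tile::TileGemmUniversalTraits<true, true, true,
                                                             /*DoubleSmemBuffer=*/true,
                                                             Row, Col, Row /*, ... */>;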