"...composable_kernel_rocm.git" did not exist on "c8f3acf9c015fbbba11456df5e829e0e7f57eaf2"
Unverified Commit 781005a5 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into codegen_hiprtc

parents a11cf2c6 39dc25a9
...@@ -70,9 +70,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -70,9 +70,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
...@@ -103,4 +101,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -103,4 +101,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -217,39 +217,3 @@ int run_gemm_example_with_layouts(int argc, ...@@ -217,39 +217,3 @@ int run_gemm_example_with_layouts(int argc,
return pass; return pass;
} }
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work.
// else if(a_layout == "C" && b_layout == "C")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
...@@ -28,8 +28,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -28,8 +28,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
#endif
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler // Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t N_Tile = 256;
...@@ -48,6 +48,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -48,6 +48,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// =============================================== // ===============================================
...@@ -62,7 +64,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -62,7 +64,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
...@@ -85,14 +88,15 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -85,14 +88,15 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BDataType, BDataType,
AccDataType, AccDataType,
GemmShape, GemmShape,
Traits, GemmUniversalTraits,
scheduler, scheduler,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>; tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>; using GemmPipeline =
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; GEMM_PIPELINE<UniversalGemmProblem, ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
auto kargs = Kernel::MakeKernelArgs(args); using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
...@@ -117,6 +121,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -117,6 +121,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if(has_hot_loop) if(has_hot_loop)
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven // Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One) if(tail_num == ck_tile::TailNumber::One)
{ {
...@@ -177,6 +196,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -177,6 +196,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}); ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
} }
} }
#endif
} }
else else
{ {
...@@ -201,4 +221,38 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -201,4 +221,38 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -72,9 +72,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre ...@@ -72,9 +72,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
......
...@@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[]) ...@@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride") .insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride") .insert("stride_c", "0", "Tensor C stride")
.insert("a_layout", "R", "A tensor data layout - Row by default") .insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("batch_stride_a", "32768", "Batch A stride") .insert("batch_stride_a", "32768", "Batch A stride")
.insert("batch_stride_b", "16384", "Batch B stride") .insert("batch_stride_b", "16384", "Batch B stride")
......
...@@ -3,13 +3,6 @@ ...@@ -3,13 +3,6 @@
#pragma once #pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
auto calculate_rtol_atol(const ck_tile::index_t K, auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch, const ck_tile::index_t kbatch,
const float max_accumulated_value) const float max_accumulated_value)
...@@ -113,16 +106,56 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -113,16 +106,56 @@ int run_batched_gemm_example_with_layouts(int argc,
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); using namespace ck_tile::literals;
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout)); auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor( std::size_t col,
batch_count, M, K, stride_A, batch_stride_A, is_row_major(a_layout))); std::size_t stride,
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor( std::size_t batch_stride,
batch_count, K, N, stride_B, batch_stride_B, is_row_major(b_layout))); auto layout) {
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor( if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
batch_count, M, N, stride_C, batch_stride_C, is_row_major(c_layout))); {
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, 1_uz, stride});
}
};
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
stride_A = f_get_default_stride(M, K, stride_A, a_layout);
stride_B = f_get_default_stride(K, N, stride_B, b_layout);
stride_C = f_get_default_stride(M, N, stride_C, c_layout);
ck_tile::HostTensor<ADataType> a_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout));
ck_tile::HostTensor<BDataType> b_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n); ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
...@@ -158,8 +191,8 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -158,8 +191,8 @@ int run_batched_gemm_example_with_layouts(int argc,
if(arg_parser.get_int("v") == 1) if(arg_parser.get_int("v") == 1)
{ {
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor( ck_tile::HostTensor<CDataType> c_m_n_host_ref(
batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){})); f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
c_m_n_host_ref.SetZero(); c_m_n_host_ref.SetZero();
const auto b_n_k = b_k_n.transpose({0, 2, 1}); const auto b_n_k = b_k_n.transpose({0, 2, 1});
...@@ -183,8 +216,8 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -183,8 +216,8 @@ int run_batched_gemm_example_with_layouts(int argc,
} }
else if(arg_parser.get_int("v") == 2) else if(arg_parser.get_int("v") == 2)
{ {
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(ck_tile::host_tensor_descriptor( ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){})); f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero(); c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero();
...@@ -268,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[]) ...@@ -268,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout"); std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout"); std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R") // if(a_layout == "R" && b_layout == "R")
{ // {
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); // return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
} // }
else if(a_layout == "R" && b_layout == "C") if(a_layout == "R" && b_layout == "C")
{ {
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} }
......
...@@ -88,12 +88,9 @@ using CodegenPipelineProblem = ...@@ -88,12 +88,9 @@ using CodegenPipelineProblem =
CodegenGemmShape, CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>; CodegenGemmTraits<ALayout, BLayout, CLayout>>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>, ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
CodegenGemmPolicy>;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
......
...@@ -41,7 +41,7 @@ auto create_args(int argc, char* argv[]) ...@@ -41,7 +41,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.") .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.") .insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.") .insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("warmup", "10", "number of iterations before benchmark the kernel.")
......
...@@ -135,12 +135,9 @@ int run_grouped_gemm_example_with_layouts(int argc, ...@@ -135,12 +135,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
const ck_tile::index_t N = Ns[i]; const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i]; const ck_tile::index_t K = Ks[i];
stride_As[i] = stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout)); stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Bs[i] = stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] =
ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>( a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
...@@ -229,10 +226,10 @@ int run_grouped_gemm_example(int argc, char* argv[]) ...@@ -229,10 +226,10 @@ int run_grouped_gemm_example(int argc, char* argv[])
{ {
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} }
else if(a_layout == "R" && b_layout == "R") // else if(a_layout == "R" && b_layout == "R")
{ // {
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); // return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
} // }
else else
{ {
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp" #include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
...@@ -53,6 +54,7 @@ ...@@ -53,6 +54,7 @@
#include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
/**
* @brief Enumeration describing static tile distribution patterns.
*
*/
enum struct tile_distribution_pattern
{
/**
* @brief Thread raked pattern.
*
*/
thread_raked,
/**
* @brief Warp raked pattern.
*
*/
warp_raked,
/**
* @brief Block raked pattern - aka linear.
*
*/
block_raked,
};
struct TileDistributionEncodingPattern
{
};
/**
* @brief Class creating 2D static tile distribution with different load/store patterns.
*
* @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
* is contiguous and we can do vector load on this dimension.
*
* @tparam BlockSize Number of threads in a workgroup.
* @tparam YPerTile The tile size of outer/leftmost dimension.
* @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
* @tparam VecSize The vector access size.
* @tparam DistributionPattern The enumeration describing used access pattern.
*/
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked>
: public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
// # of rows in Y dim accessed by single wavefront in one iteration
static constexpr index_t Y1 = warp_size / X0;
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
static constexpr index_t Y0 = num_warps;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<1, 2>>{});
}
};
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked>
: public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
static constexpr index_t Y0 = num_warps;
static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked>
: public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
static constexpr index_t Y1 = num_warps;
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
};
} // namespace ck_tile
...@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t) ...@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
using Idx = number<tuple<Ts...>::size() - i - 1>; using Idx = number<tuple<Ts...>::size() - i - 1>;
return t.at(Idx{}); return t.at(Idx{});
}, },
number<tuple<Ts...>::size()()>{}); number<tuple<Ts...>::size()>{});
} }
// Reduce tuple values in specific range using Function // Reduce tuple values in specific range using Function
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -18,8 +18,17 @@ ...@@ -18,8 +18,17 @@
namespace ck_tile { namespace ck_tile {
// Note: this tile window do not support single issue /**
// you need to use tile_window_linear structure for this purpose * @brief This class provides tile (windowed) view and access to the device memory.
*
* @note This tile window does not support single issue you need to use tile_window_linear
* structure for this purpose
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
*/
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename StaticTileDistribution_, typename StaticTileDistribution_,
...@@ -1009,6 +1018,14 @@ CK_TILE_DEVICE void move_tile_window( ...@@ -1009,6 +1018,14 @@ CK_TILE_DEVICE void move_tile_window(
window.move(step); window.move(step);
} }
/**
* @brief This class provides description of tile windowed view on the device memory.
*
* @note This class does not provide any functions to read or modify device memory.
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
*/
template <typename BottomTensorView_, typename WindowLengths_> template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths struct tile_window_with_static_lengths
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
namespace ck_tile {
namespace detail {
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
const InTensor& in_tensor)
{
constexpr auto I0 = number<0>{};
static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
"Data type for InTensor and OutTensor must be the same!");
using DataType = typename InTensor::DataType;
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
// y_dim_out_to_in
// For swapped Hs tile case I need only get_rh_minor_to_y
// since rh_major are already swapped due to swapped Hs.
constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
map<index_t, index_t> rh_minor_to_y_;
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
rh_minor_to_y_(rh_minor) = i;
});
return rh_minor_to_y_;
};
// In swapped Hs case <Y,X> -> <X,Y> tile
// we have same rh_major, but reversed rh_minor!
constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
// Is this really needed?? Should we have simple reverse here??
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
}
return y_dim_out_to_in_;
}();
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
constexpr index_t y_dim_vec_in = NDimY - 1;
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
// vector lengths
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
// # of vectors
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using SFC_Y = space_filling_curve<decltype(y_lengths),
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Y::get_num_of_access();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// in/out vectors to be transposed
thread_buffer<InVec, num_vec_in> in_vectors;
thread_buffer<OutVec, num_vec_out> out_vectors;
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_tuple(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
});
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}
} // namespace detail
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in)
{
using InDataType = typename InTensor::DataType;
using OutDataType = typename OutTensor::DataType;
using InTileDistr = typename InTensor::StaticTileDistribution;
using OutTileDistr = typename OutTensor::StaticTileDistribution;
using InDstrEncode = typename InTileDistr::DstrEncode;
using OutDstrEncode = typename OutTileDistr::DstrEncode;
using InThreadTensorDesc = typename InTensor::ThreadTensorDesc;
using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc;
// Ys:
constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
// type convert
const auto in_tmp = [&]() {
if constexpr(std::is_same_v<OutDataType, InDataType>)
{
return in;
}
else
{
return tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
}
}();
// Scenario where we switch from tile <Y, X> -> <X, Y> - only 2D tiles!
// we preserve Ps but swap Ys: <Y1, Y0> -> <Y0, Y1>
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) &&
InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths))
// Any condition on Ps ??
// InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
// InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
{
detail::transpose_tile2d_impl_in_thread(out, in_tmp);
}
else
{
static_assert(false, "Provided tensors could not be transposed!");
}
}
} // namespace ck_tile
...@@ -80,7 +80,7 @@ struct BlockUniversalGemmAsBsCr ...@@ -80,7 +80,7 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t InterWaveSchedulingMacClusters = 1; static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread; static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack; static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KRepeat = KPerThread / KPack;
}; };
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -69,6 +68,7 @@ struct GemmKernel ...@@ -69,6 +68,7 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>; using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>; using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>; using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>(); static constexpr auto I0 = number<0>();
...@@ -168,6 +168,7 @@ struct GemmKernel ...@@ -168,6 +168,7 @@ struct GemmKernel
{ {
if(kargs.KBatch != 1) if(kargs.KBatch != 1)
{ {
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
return false; return false;
} }
} }
...@@ -176,10 +177,14 @@ struct GemmKernel ...@@ -176,10 +177,14 @@ struct GemmKernel
{ {
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{ {
std::cerr << "Can't support K that is not a multiple of KPerBlock"
" without padding!"
<< std::endl;
return false; return false;
} }
if(kargs.K % GemmPipeline::VectorSizeA != 0) if(kargs.K % GemmPipeline::VectorSizeA != 0)
{ {
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
return false; return false;
} }
} }
...@@ -187,10 +192,14 @@ struct GemmKernel ...@@ -187,10 +192,14 @@ struct GemmKernel
{ {
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{ {
std::cerr << "Can't support M that is not a multiple of MPerBlock"
" without padding!"
<< std::endl;
return false; return false;
} }
if(kargs.M % GemmPipeline::VectorSizeA != 0) if(kargs.M % GemmPipeline::VectorSizeA != 0)
{ {
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
return false; return false;
} }
} }
...@@ -199,10 +208,14 @@ struct GemmKernel ...@@ -199,10 +208,14 @@ struct GemmKernel
{ {
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{ {
std::cerr << "Can't support N that is not a multiple of NPerBlock"
" without padding!"
<< std::endl;
return false; return false;
} }
if(kargs.N % GemmPipeline::VectorSizeB != 0) if(kargs.N % GemmPipeline::VectorSizeB != 0)
{ {
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
return false; return false;
} }
} }
...@@ -210,10 +223,14 @@ struct GemmKernel ...@@ -210,10 +223,14 @@ struct GemmKernel
{ {
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{ {
std::cerr << "Can't support K that is not a multiple of KPerBlock"
" without padding!"
<< std::endl;
return false; return false;
} }
if(kargs.K % GemmPipeline::VectorSizeB != 0) if(kargs.K % GemmPipeline::VectorSizeB != 0)
{ {
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
return false; return false;
} }
} }
...@@ -222,10 +239,14 @@ struct GemmKernel ...@@ -222,10 +239,14 @@ struct GemmKernel
{ {
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{ {
std::cerr << "Can't support N that is not a multiple of NPerBlock"
" without padding!"
<< std::endl;
return false; return false;
} }
if(kargs.N % GemmPipeline::VectorSizeC != 0) if(kargs.N % GemmPipeline::VectorSizeC != 0)
{ {
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
} }
} }
...@@ -233,10 +254,14 @@ struct GemmKernel ...@@ -233,10 +254,14 @@ struct GemmKernel
{ {
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{ {
std::cerr << "Can't support M that is not a multiple of MPerBlock"
" without padding!"
<< std::endl;
return false; return false;
} }
if(kargs.M % GemmPipeline::VectorSizeC != 0) if(kargs.M % GemmPipeline::VectorSizeC != 0)
{ {
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
} }
} }
...@@ -250,6 +275,14 @@ struct GemmKernel ...@@ -250,6 +275,14 @@ struct GemmKernel
const GemmKernelArgs& kargs, const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset) const SplitKBatchOffset& splitk_batch_offset)
{ {
// const auto idxs = TilePartitioner{}();
// const auto i_m = idxs.at(number<0>{});
// const auto i_n = idxs.at(number<1>{});
// // options
// const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
// const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// // Convert pointers to tensor views
// auto a_tensor_view = [&]() {
const auto& a_tensor_view = [&]() { const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -264,9 +297,9 @@ struct GemmKernel ...@@ -264,9 +297,9 @@ struct GemmKernel
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_ptr, a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k), make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(1, kargs.stride_A), make_tuple(kargs.stride_A, 1),
number<1>{}, number<GemmPipeline::VectorSizeA>{},
number<1>{}); number<1>{});
} }
}(); }();
...@@ -276,9 +309,9 @@ struct GemmKernel ...@@ -276,9 +309,9 @@ struct GemmKernel
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_ptr, b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k), make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(1, kargs.stride_B), make_tuple(kargs.stride_B, 1),
number<1>{}, number<GemmPipeline::VectorSizeB>{},
number<1>{}); number<1>{});
} }
else else
...@@ -292,6 +325,7 @@ struct GemmKernel ...@@ -292,6 +325,7 @@ struct GemmKernel
} }
}(); }();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() { const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -331,9 +365,9 @@ struct GemmKernel ...@@ -331,9 +365,9 @@ struct GemmKernel
else else
{ {
return pad_tensor_view(a_tensor_view, return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::KPerBlock>{}), number<TilePartitioner::MPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{}); sequence<false, GemmPipeline::kPadM>{});
} }
}(); }();
...@@ -349,12 +383,13 @@ struct GemmKernel ...@@ -349,12 +383,13 @@ struct GemmKernel
else else
{ {
return pad_tensor_view(b_tensor_view, return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::KPerBlock>{}), number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadN, false>{}); sequence<false, GemmPipeline::kPadN>{});
} }
}(); }();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() { const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2); const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
...@@ -380,20 +415,45 @@ struct GemmKernel ...@@ -380,20 +415,45 @@ struct GemmKernel
CK_TILE_DEVICE static auto CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{ {
const auto& a_pad_view = views.at(I0); const auto& a_pad_view = views.at(I0);
const auto& a_block_window = make_tile_window( const auto& b_pad_view = views.at(I1);
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
const auto& b_pad_view = views.at(I1);
const auto& b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
const auto& c_pad_view = views.at(I2); const auto& c_pad_view = views.at(I2);
auto c_block_window = make_tile_window(
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();
auto c_block_window = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n}); {i_m, i_n});
......
...@@ -50,7 +50,6 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -50,7 +50,6 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using GemmKernelArgs = typename Base::GemmKernelArgs; using GemmKernelArgs = typename Base::GemmKernelArgs;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t KBatch = 1;
struct GemmTransKernelArg struct GemmTransKernelArg
{ {
...@@ -124,7 +123,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -124,7 +123,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
stride_a, stride_a,
stride_b, stride_b,
stride_c, stride_c,
KBatch}; gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -12,18 +13,21 @@ struct GemmPipelineAgBgCrImplBase ...@@ -12,18 +13,21 @@ struct GemmPipelineAgBgCrImplBase
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <typename DstBlockTile, typename SrcTileWindow> template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{ {
load_tile(dst_block_tile, dram_tile_window); load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock}); move_tile_window(dram_tile_window, dram_tile_window_step);
} }
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction> template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
...@@ -60,19 +64,21 @@ struct GemmPipelineAgBgCrImplBase ...@@ -60,19 +64,21 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const const ALdsTensorView& a_lds_block_view) const
{ {
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
// A DRAM tile window for load // A DRAM tile window for load
auto a_copy_dram_window = auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window_tmp.get_window_origin(), a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = auto a_copy_lds_window = make_tile_window(
make_tile_window(a_lds_block_view, a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
auto a_lds_gemm_window = make_tile_window( auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0}); a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
...@@ -86,18 +92,22 @@ struct GemmPipelineAgBgCrImplBase ...@@ -86,18 +92,22 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const const BLdsTensorView& b_lds_block_view) const
{ {
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
auto b_copy_dram_window = auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), make_tuple(YPerTile{}, XPerTile{}),
b_dram_block_window_tmp.get_window_origin(), b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
// TODO: Do we really need those two tile windows???
// They're exactly same...
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window = auto b_copy_lds_window = make_tile_window(
make_tile_window(b_lds_block_view, b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
auto b_lds_gemm_window = make_tile_window( auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
...@@ -37,7 +37,7 @@ struct BaseGemmPipelineAgBgCrCompV3 ...@@ -37,7 +37,7 @@ struct BaseGemmPipelineAgBgCrCompV3
// LocalPreFillStages: 1 // LocalPreFillStages: 1
// LocalPreFetchStages: 1 // LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1 // LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy> template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{ {
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>; using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
...@@ -62,15 +62,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -62,15 +62,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA; static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK; static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum; static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto Scheduler = Problem::Scheduler;
...@@ -82,7 +81,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -82,7 +81,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
...@@ -248,11 +250,22 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -248,11 +250,22 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
"A/B Dram block window should have the same data type as appropriate " "A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"); "([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && constexpr bool is_a_col_major =
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"); static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// Definitions of all needed tiles // Definitions of all needed tiles
...@@ -287,23 +300,51 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -287,23 +300,51 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ABlockTile a_block_tile; ABlockTile a_block_tile;
BBlockTile b_block_tile; BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// ----------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------
// Gemm pipeline start // Gemm pipeline start
// prefetch // prefetch
// global read 0 // global read 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// initialize C // initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds(); block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
...@@ -318,11 +359,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -318,11 +359,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{ {
block_sync_lds(); block_sync_lds();
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); Policy::template MakeShuffledARegTileDistribution<Problem>());
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
......
...@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA; static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
...@@ -133,7 +133,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -133,7 +133,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment