Commit 8bd49370 authored by Adam Osewski

Refactoring & Move Layout info to pipeline problem.

parent d3689b06
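For orientation, a hedged sketch of how the pieces fit together after this change; it assumes the ADataType/BDataType/CDataType/AccDataType, GemmShape, TilePartitioner and GemmEpilogue aliases defined in the example sources below, and is not part of the commit itself:

    // Sketch only; the aliases used here are assumed to be defined as in the examples below.
    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;

    // Layout information now travels through the pipeline problem ...
    using PipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                  BDataType,
                                                                  CDataType,
                                                                  GemmShape,
                                                                  ALayout,
                                                                  BLayout,
                                                                  CLayout>;
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;

    // ... so GemmKernel no longer takes LayoutA/LayoutB/LayoutC template parameters;
    // it reads ALayout/BLayout/CLayout from the pipeline type instead.
    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;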
@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#include "gemm_basic.hpp"
 #include <hip/hip_runtime.h>
 #include <cstring>
@@ -11,6 +10,11 @@
 #include <string>
 #include <tuple>
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/host.hpp"
+#include "gemm_basic.hpp"
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
@@ -22,7 +26,6 @@ auto create_args(int argc, char* argv[])
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("e", "1e-5", "Absolute error tolerance")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "10", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
@@ -51,13 +54,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
-    using Kernel =
-        ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
+    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
     auto kargs = Kernel::MakeKargs(args.p_a,
                                    args.p_b,
                                    args.p_c,
-                                   args.epsilon,
                                    args.M,
                                    args.N,
                                    args.K,
@@ -96,7 +97,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
         return -1; // Or handle the error appropriately
     }
-    float epsilon = arg_parser.get_float("e");
     ck_tile::index_t batch_size = arg_parser.get_int("b");
     ck_tile::index_t M = arg_parser.get_int("m");
     ck_tile::index_t N = arg_parser.get_int("n");
@@ -107,69 +107,37 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
     ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
     gemm_basic_args args;
     args.p_a = a_buf.GetDeviceBuffer();
     args.p_b = b_buf.GetDeviceBuffer();
     args.p_c = c_buf.GetDeviceBuffer();
-    args.epsilon = epsilon;
     args.kbatch = batch_size;
     args.M = M;
     args.N = N;
     args.K = K;
-    // Only set stride_M and stride_N if they are non-zero and not equal to K.
-    if(stride_a != 0)
-    {
-        args.stride_A = stride_a;
-    }
-    else
-    {
-        args.stride_A = [&]() {
-            if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return M;
-            }
-            else
-            {
-                return K;
-            }
-        }();
-    }
-    if(stride_b != 0)
-    {
-        args.stride_B = stride_b;
-    }
-    else
-    {
-        args.stride_B = [&]() {
-            if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return N;
-            }
-            else
-            {
-                return K;
-            }
-        }();
-    }
-    if(stride_c != 0)
-    {
-        args.stride_C = stride_c;
-    }
-    else
-    {
-        args.stride_C = [&]() {
-            if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return M;
-            }
-            else
-            {
-                return N;
-            }
-        }();
-    }
+    auto f_get_default_stride = [](std::size_t row,
+                                   std::size_t col,
+                                   std::size_t stride,
+                                   auto layout) {
+        if(stride == 0)
+        {
+            // give a chance if stride is zero, return a default packed stride
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return col;
+            }
+            else
+            {
+                return row;
+            }
+        }
+        else
+            return stride;
+    };
+    args.stride_A = f_get_default_stride(M, K, stride_a, LayoutA{});
+    args.stride_B = f_get_default_stride(K, N, stride_b, LayoutB{});
+    args.stride_C = f_get_default_stride(M, N, stride_c, LayoutC{});
     float ave_time = gemm_calc<LayoutA, LayoutB, LayoutC, PipelineProblem, GemmPipeline, GemmShape>(
         args, ck_tile::stream_config{nullptr, true});
@@ -197,30 +165,57 @@ int main(int argc, char* argv[])
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");
-    // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
-    using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
-    using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    // host verify
-    std::vector<int> a_dimensions =
-        (std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, K}
-            : std::vector<int>{K, M};
-    std::vector<int> b_dimensions =
-        (std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            ? std::vector<int>{N, K}
-            : std::vector<int>{K, N};
-    std::vector<int> c_dimensions =
-        (std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, N}
-            : std::vector<int>{N, M};
-    ck_tile::HostTensor<ADataType> a_host(a_dimensions);
-    ck_tile::HostTensor<BDataType> b_host(b_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using namespace ck_tile::literals;
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+    auto f_get_default_stride = [](std::size_t row,
+                                   std::size_t col,
+                                   std::size_t stride,
+                                   auto layout) {
+        if(stride == 0)
+        {
+            // give a chance if stride is zero, return a default packed stride
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return col;
+            }
+            else
+            {
+                return row;
+            }
+        }
+        else
+            return stride;
+    };
+    stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
+    stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
+    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
+    ck_tile::HostTensor<ADataType> a_host(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
+    ck_tile::HostTensor<BDataType> b_host(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
+    ck_tile::HostTensor<CDataType> c_host_ref(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+    ck_tile::HostTensor<CDataType> c_host_dev(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
     ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
     ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
@@ -259,6 +254,9 @@ int main(int argc, char* argv[])
                                      BDataType,
                                      AccDataType,
                                      CodegenGemmShape,
+                                     ALayout,
+                                     BLayout,
+                                     CLayout,
                                      kPadA,
                                      kPadB,
                                      kPadC>;
@@ -266,9 +264,9 @@ int main(int argc, char* argv[])
     using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
     invoke_gemm<ck_tile::half_t,
-                matrix_a_layout,
-                matrix_b_layout,
-                matrix_c_layout,
+                ALayout,
+                BLayout,
+                CLayout,
                 CodegenPipelineProblem,
                 CodegenGemmPipeline,
                 CodegenGemmShape>(a_buf, b_buf, c_buf, arg_parser);
@@ -280,17 +278,12 @@ int main(int argc, char* argv[])
     if(arg_parser.get_int("v") == 1)
     {
         // ToDo: Will Add the Element Op (bias) verification in the future.
-        ck_tile::reference_gemm<ADataType,
-                                BDataType,
-                                AccDataType,
-                                CDataType,
-                                matrix_a_layout,
-                                matrix_b_layout,
-                                matrix_c_layout>(a_host, b_host, c_host_ref);
+        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+            a_host, b_host, c_host_ref);
         pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref);
-        std::cout << "The CPU veification result is:" << (pass_cpu ? "correct" : "fail")
+        std::cout << "The CPU verification result is:" << (pass_cpu ? "correct" : "fail")
                   << std::flush;
     }
@@ -298,57 +291,19 @@ int main(int argc, char* argv[])
     if(arg_parser.get_int("v") == 2)
     {
-        ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
-        ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
-        ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
-        if(stride_a == 0)
-        {
-            if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                stride_a = M;
-            }
-            else
-            {
-                stride_a = K;
-            }
-        }
-        if(stride_b == 0)
-        {
-            if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_b = N;
-            }
-            else
-            {
-                stride_b = K;
-            }
-        }
-        if(stride_c == 0)
-        {
-            if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                stride_c = M;
-            }
-            else
-            {
-                stride_c = N;
-            }
-        }
-        ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
+        ck_tile::HostTensor<CDataType> c_host_gpu_ref(
+            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
         ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
+        c_gpu_buf.SetZero();
         ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
-            a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
+            a_buf, b_buf, c_gpu_buf, M, N, K, stride_A, stride_B, stride_C);
-        c_buf.FromDevice(c_host_gpu_ref.data());
+        c_gpu_buf.FromDevice(c_host_gpu_ref.data());
         pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
-        std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
+        std::cout << "The GPU verification result is: " << (pass_gpu ? "correct" : "fail")
                   << std::flush;
     }
...
@@ -4,12 +4,10 @@
 #pragma once
-#include <string>
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
-#include "ck_tile/ops/epilogue.hpp"
-#include "ck_tile/ops/gemm.hpp"
-#include "ck_tile/host.hpp"
+#include <string>
 template <typename DataType>
 struct GemmBasicTypeConfig;
@@ -58,7 +56,6 @@ struct gemm_basic_args
     const void* p_a;
     const void* p_b;
     void* p_c;
-    float epsilon;
     ck_tile::index_t kbatch;
     ck_tile::index_t M;
     ck_tile::index_t N;
...
@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#include "gemm_basic.hpp"
 #include <hip/hip_runtime.h>
 #include <cstring>
@@ -11,20 +10,24 @@
 #include <string>
 #include <tuple>
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/host.hpp"
+#include "gemm_basic.hpp"
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("b", "1", "batch size")
-        .insert("m", "1024", "m dimension")
-        .insert("n", "2048", "n dimension")
-        .insert("k", "64", "k dimension")
+        .insert("m", "3840", "m dimension")
+        .insert("n", "4096", "n dimension")
+        .insert("k", "4096", "k dimension")
         .insert("stride_a", "0", "Tensor A stride")
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("e", "1e-5", "Absolute error tolerance")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
-        .insert("warmup", "10", "number of iterations before benchmark the kernel")
+        .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
@@ -32,7 +35,7 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }
-template <typename LayoutA, typename LayoutB, typename LayoutC>
+template <typename ALayout, typename BLayout, typename CLayout>
 float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 {
     // ToDo: This will be modified by the codegen code later.
@@ -62,139 +65,180 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                        ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                        ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
     using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
-    using PipelineProblem =
-        ck_tile::BlockGemmUniversalPipelineProblem<ADataType,
-                                                   BDataType,
-                                                   CDataType,
-                                                   GemmShape,
-                                                   kPadA,
-                                                   kPadB,
-                                                   kPadC,
-                                                   ck_tile::BlockGemmPipelineScheduler::Intrawave>;
-    // The GemmPipeline should also come from the Codegen.
-    using GemmPipeline = ck_tile::BlockGemmPipelineAgBgCrMem<PipelineProblem>;
-    using GemmEpilogue = ck_tile::Default2DEpilogue<
-        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
-    using Kernel =
-        ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
-    auto kargs = Kernel::MakeKargs(args.p_a,
-                                   args.p_b,
-                                   args.p_c,
-                                   args.M,
-                                   args.N,
-                                   args.K,
-                                   args.stride_A,
-                                   args.stride_B,
-                                   args.stride_C);
-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
-    constexpr dim3 blocks = Kernel::BlockSize();
-    float ave_time = ck_tile::launch_kernel(
-        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-    return ave_time;
-}
-template <typename DataType, typename LayoutA, typename LayoutB, typename LayoutC>
-float invoke_gemm(ck_tile::DeviceMem& a_buf,
-                  ck_tile::DeviceMem& b_buf,
-                  ck_tile::DeviceMem& c_buf,
-                  const ck_tile::ArgParser& arg_parser)
-{
-    std::string data_type = arg_parser.get_str("prec");
-    if(data_type != DataTypeTraits<DataType>::name)
-    {
-        std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
-                  << data_type << std::endl;
-        return -1; // Or handle the error appropriately
-    }
-    float epsilon = arg_parser.get_float("e");
-    ck_tile::index_t batch_size = arg_parser.get_int("b");
-    ck_tile::index_t M = arg_parser.get_int("m");
-    ck_tile::index_t N = arg_parser.get_int("n");
-    ck_tile::index_t K = arg_parser.get_int("k");
-    ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
-    ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
-    ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
-    gemm_basic_args args;
-    args.p_a = a_buf.GetDeviceBuffer();
-    args.p_b = b_buf.GetDeviceBuffer();
-    args.p_c = c_buf.GetDeviceBuffer();
-    args.epsilon = epsilon;
-    args.kbatch = batch_size;
-    args.M = M;
-    args.N = N;
-    args.K = K;
-    // Only set stride_M and stride_N if they are non-zero and not equal to K.
-    if(stride_a != 0)
-    {
-        args.stride_A = stride_a;
-    }
-    else
-    {
-        if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            args.stride_A = K;
-        }
-        else
-        {
-            args.stride_A = M;
-        }
-    }
-    if(stride_b != 0)
-    {
-        args.stride_B = stride_b;
-    }
-    else
-    {
-        if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            args.stride_B = N;
-        }
-        else
-        {
-            args.stride_B = K;
-        }
-    }
-    if(stride_c != 0)
-    {
-        args.stride_C = stride_c;
-    }
-    else
-    {
-        if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            args.stride_C = N;
-        }
-        else
-        {
-            args.stride_C = M;
-        }
-    }
-    float ave_time =
-        gemm_calc<LayoutA, LayoutB, LayoutC>(args, ck_tile::stream_config{nullptr, true});
-    std::size_t num_byte =
-        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
-    float gb_per_sec = num_byte / 1.E6 / ave_time;
-    std::cout << "The overall perfomance of the GEMM with "
-              << "[" << data_type << "]"
-              << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
-              << "is: \n";
-    std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
-              << std::flush;
+    using GemmEpilogue = ck_tile::Default2DEpilogue<
+        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
+    using BaseGemmPipeline =
+        ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
+                                                                             BDataType,
+                                                                             CDataType,
+                                                                             GemmShape,
+                                                                             ALayout,
+                                                                             BLayout,
+                                                                             CLayout>>;
+    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
+    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
+    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+    float ave_time{0};
+    const auto Run = [&](const auto& kernel) {
+        using GemmKernel = ck_tile::remove_cvref_t<decltype(kernel)>;
+        auto kargs = GemmKernel::MakeKargs(args.p_a,
+                                           args.p_b,
+                                           args.p_c,
+                                           args.M,
+                                           args.N,
+                                           args.K,
+                                           args.stride_A,
+                                           args.stride_B,
+                                           args.stride_C);
+        const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.kbatch);
+        constexpr dim3 blocks = GemmKernel::BlockSize();
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Lunching kernel with args:"
+                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+        ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(kernel, grids, blocks, 0, kargs));
+    };
+#define RUN_KERNEL_(has_hot_loop_, tail_number_)                                          \
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<                                  \
+        ck_tile::UniversalGemmPipelineProblem<ADataType,                                  \
+                                              BDataType,                                  \
+                                              CDataType,                                  \
+                                              GemmShape,                                  \
+                                              ALayout,                                    \
+                                              BLayout,                                    \
+                                              CLayout,                                    \
+                                              kPadA,                                      \
+                                              kPadB,                                      \
+                                              kPadC,                                      \
+                                              ck_tile::GemmPipelineScheduler::Intrawave,  \
+                                              has_hot_loop_,                              \
+                                              tail_number_>>;                             \
+    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;      \
+    Run(Kernel{});
+    if(has_hot_loop)
+    {
+        // Tail pipeline One to Seven
+        if(tail_num == ck_tile::TailNumber::One)
+        {
+            RUN_KERNEL_(true, ck_tile::TailNumber::One);
+        }
+        else if(tail_num == ck_tile::TailNumber::Full)
+        {
+            RUN_KERNEL_(true, ck_tile::TailNumber::Full);
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 2)
+        {
+            if(tail_num == ck_tile::TailNumber::Two)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Two);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 3)
+        {
+            if(tail_num == ck_tile::TailNumber::Three)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Three);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 4)
+        {
+            if(tail_num == ck_tile::TailNumber::Four)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Four);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 5)
+        {
+            if(tail_num == ck_tile::TailNumber::Five)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Five);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 6)
+        {
+            if(tail_num == ck_tile::TailNumber::Six)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Six);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 7)
+        {
+            if(tail_num == ck_tile::TailNumber::Seven)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Seven);
+            }
+        }
+    }
+    else
+    {
+        // Tail number always 1
+        if(tail_num == ck_tile::TailNumber::One)
+        {
+            RUN_KERNEL_(false, ck_tile::TailNumber::One);
+        }
+    }
+#undef RUN_KERNEL_
+    return ave_time;
+}
+template <typename ALayout, typename BLayout, typename CLayout>
+float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
+                  ck_tile::DeviceMem& b_k_n_dev_buf,
+                  ck_tile::DeviceMem& c_m_n_dev_buf,
+                  ck_tile::index_t M,
+                  ck_tile::index_t N,
+                  ck_tile::index_t K,
+                  ck_tile::index_t stride_A,
+                  ck_tile::index_t stride_B,
+                  ck_tile::index_t stride_C,
+                  ck_tile::index_t kbatch,
+                  int n_warmup,
+                  int n_repeat)
+{
+    gemm_basic_args args;
+    args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
+    args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
+    args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
+    args.kbatch = kbatch;
+    args.M = M;
+    args.N = N;
+    args.K = K;
+    args.stride_A = stride_A;
+    args.stride_B = stride_B;
+    args.stride_C = stride_C;
+    float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
+        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
+    std::string op_name{"Gemm{MemBoundPipeline}"};
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_byte =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+    std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
+              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
+              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << std::endl;
     return ave_time;
 }
@@ -209,118 +253,120 @@ int main(int argc, char* argv[])
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");
-    // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
-    using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
-    using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    // host verify
-    std::vector<int> a_dimensions =
-        (std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, K}
-            : std::vector<int>{K, M};
-    std::vector<int> b_dimensions =
-        (std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            ? std::vector<int>{N, K}
-            : std::vector<int>{K, N};
-    std::vector<int> c_dimensions =
-        (std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, N}
-            : std::vector<int>{N, M};
-    ck_tile::HostTensor<ADataType> a_host(a_dimensions);
-    ck_tile::HostTensor<BDataType> b_host(b_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
-    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
-    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
-    ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
-    a_buf.ToDevice(a_host.data());
-    b_buf.ToDevice(b_host.data());
-    invoke_gemm<ck_tile::half_t, matrix_a_layout, matrix_b_layout, matrix_c_layout>(
-        a_buf, b_buf, c_buf, arg_parser);
-    c_buf.FromDevice(c_host_dev.data());
-    bool pass = true;
-    if(arg_parser.get_int("v") == 1)
-    {
-        // ToDo: Will Add the Element Op (bias) verification in the future.
-        ck_tile::reference_gemm<ADataType,
-                                BDataType,
-                                AccDataType,
-                                CDataType,
-                                matrix_a_layout,
-                                matrix_b_layout,
-                                matrix_c_layout>(a_host, b_host, c_host_ref);
-        pass = ck_tile::check_err(c_host_dev, c_host_ref);
-        std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::flush;
-    }
-    else if(arg_parser.get_int("v") == 2)
-    {
-        ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
-        ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
-        ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
-        if(stride_a == 0)
-        {
-            if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_a = K;
-            }
-            else
-            {
-                stride_a = M;
-            }
-        }
-        if(stride_b == 0)
-        {
-            if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_b = N;
-            }
-            else
-            {
-                stride_b = K;
-            }
-        }
-        if(stride_c == 0)
-        {
-            if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_c = N;
-            }
-            else
-            {
-                stride_c = M;
-            }
-        }
-        ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
-        ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
-        ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
-            a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
-        c_buf.FromDevice(c_host_gpu_ref.data());
-        pass = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
-        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::flush;
-    }
-    std::cout << std::endl << std::flush;
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+    ck_tile::index_t batch_size = arg_parser.get_int("b");
+    int n_warmup = arg_parser.get_int("warmup");
+    int n_repeat = arg_parser.get_int("repeat");
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using namespace ck_tile::literals;
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+    auto f_get_default_stride = [](std::size_t row,
+                                   std::size_t col,
+                                   std::size_t stride,
+                                   auto layout) {
+        if(stride == 0)
+        {
+            // give a chance if stride is zero, return a default packed stride
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return col;
+            }
+            else
+            {
+                return row;
+            }
+        }
+        else
+            return stride;
+    };
+    stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
+    stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
+    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
+    ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
+    ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
+    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
+        f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+    // TODO: add different init types
+    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
+    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
+    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
+    a_m_k_dev_buf.ToDevice(a_m_k.data());
+    b_k_n_dev_buf.ToDevice(b_k_n.data());
+    c_m_n_dev_buf.SetZero();
+    c_m_n_dev_result.SetZero();
+    invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
+                                           b_k_n_dev_buf,
+                                           c_m_n_dev_buf,
+                                           M,
+                                           N,
+                                           K,
+                                           stride_A,
+                                           stride_B,
+                                           stride_C,
+                                           batch_size,
+                                           n_warmup,
+                                           n_repeat);
+    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
+    bool pass = true;
+    if(arg_parser.get_int("v") == 1)
+    {
+        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
+            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+        c_m_n_host_ref.SetZero();
+        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+            a_m_k, b_k_n, c_m_n_host_ref);
+        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
+        std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
+    }
+    else if(arg_parser.get_int("v") == 2)
+    {
+        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
+            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
+        c_m_n_gpu_ref.SetZero();
+        c_m_n_gpu_buf_ref.SetZero();
+        ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
+            a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
+        c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
+        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
+        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
+    }
     return pass;
 }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
namespace ck_tile {
namespace literals {
// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
inline constexpr std::size_t operator""_uz(unsigned long long size)
{
return static_cast<std::size_t>(size);
}
inline constexpr std::size_t operator""_zu(unsigned long long size)
{
return static_cast<std::size_t>(size);
}
} // namespace literals
} // namespace ck_tile
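The `_uz` suffix defined above is what the host-descriptor lambdas in the examples rely on when they write strides such as `{stride, 1_uz}`. A minimal usage sketch (variable names are illustrative only, and the header above is assumed to be included):

    #include <cstddef>
    #include <type_traits>

    int main()
    {
        using namespace ck_tile::literals;

        // 1_uz is a std::size_t literal, so brace-initialized stride lists
        // such as {stride, 1_uz} deduce to a single unsigned type.
        static_assert(std::is_same_v<decltype(1_uz), std::size_t>, "literal yields std::size_t");
        constexpr std::size_t packed_row_stride = 4096_uz; // illustrative value
        return packed_row_stride == 4096 ? 0 : 1;
    }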
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#include <cstdlib>
+#include <thread>
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
-#include "ck_tile/ops/common/tensor_layout.hpp"
-#include <thread>
 namespace ck_tile {
@@ -14,48 +15,36 @@ template <typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename CDataType,
-          typename LayoutA,
-          typename LayoutB,
-          typename LayoutC,
           typename AElementOp   = ck_tile::identity,
           typename BElementOp   = ck_tile::identity,
           typename ACCElementOp = ck_tile::identity>
 CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
-                                 const HostTensor<BDataType>& b_n_k,
+                                 const HostTensor<BDataType>& b_k_n,
                                  HostTensor<CDataType>& c_m_n,
                                  const AElementOp& a_element_op     = {},
                                  const BElementOp& b_element_op     = {},
                                  const ACCElementOp& acc_element_op = {})
 {
-    const int N = b_n_k.mDesc.get_lengths()[0];
-    const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                      ? a_m_k.mDesc.get_lengths()[1]
-                      : a_m_k.mDesc.get_lengths()[0];
-    const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                      ? a_m_k.mDesc.get_lengths()[0]
-                      : a_m_k.mDesc.get_lengths()[1];
-    auto f = [&](auto m) {
-        for(int n = 0; n < N; ++n)
-        {
-            AccDataType v_acc = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                                    ? a_element_op(a_m_k(m, k))
-                                    : a_element_op(a_m_k(k, m));
-                BDataType v_b = b_element_op(b_n_k(n, k));
-                v_acc += ck_tile::type_convert<AccDataType>(v_a) *
-                         ck_tile::type_convert<AccDataType>(v_b);
-            }
-            c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
-        }
-    };
-    make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
+    const std::size_t M = a_m_k.get_length(0);
+    const std::size_t N = b_k_n.get_length(1);
+    const std::size_t K = a_m_k.get_length(1);
+    auto f_mn = [&](auto m, auto n) {
+        AccDataType v_acc = 0;
+        for(std::size_t k = 0; k < K; ++k)
+        {
+            ADataType v_a = a_element_op(a_m_k(m, k));
+            BDataType v_b = b_element_op(b_k_n(k, n));
+            v_acc +=
+                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
+        }
+        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
+    };
+    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
 }
 template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
...
@@ -8,34 +8,29 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 namespace ck_tile {
-template <typename TilePartitioner_,
-          typename GemmPipeline_,
-          typename EpiloguePipeline_,
-          typename LayoutA_,
-          typename LayoutB_,
-          typename LayoutC_>
+template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct GemmKernel
 {
     using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
-    using LayoutA          = remove_cvref_t<LayoutA_>;
-    using LayoutB          = remove_cvref_t<LayoutB_>;
-    using LayoutC          = remove_cvref_t<LayoutC_>;
+    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
+    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
+    using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;
     static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
     using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
     using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
-    using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
-    using CODataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;
+    // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+    using CDataType = remove_cvref_t<typename EpiloguePipeline::CDataType>;
-    __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
+    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
-        return TilePartitioner::GridSize(M_size, N_size, Batch_size);
+        return TilePartitioner::GridSize(M, N, KBatch);
     }
     __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
@@ -45,30 +40,30 @@ struct GemmKernel
         const void* a_ptr;
         const void* b_ptr;
         void* c_ptr;
-        ck_tile::index_t M;
-        ck_tile::index_t N;
-        ck_tile::index_t K;
-        ck_tile::index_t stride_A;
-        ck_tile::index_t stride_B;
-        ck_tile::index_t stride_C;
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t stride_A;
+        index_t stride_B;
+        index_t stride_C;
     };
     CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
                                                             const void* b_ptr,
                                                             void* c_ptr,
-                                                            ck_tile::index_t M,
-                                                            ck_tile::index_t N,
-                                                            ck_tile::index_t K,
-                                                            ck_tile::index_t stride_A,
-                                                            ck_tile::index_t stride_B,
-                                                            ck_tile::index_t stride_C)
+                                                            index_t M,
+                                                            index_t N,
+                                                            index_t K,
+                                                            index_t stride_A,
+                                                            index_t stride_B,
+                                                            index_t stride_C)
     {
         return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
     }
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
+        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }
     CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
@@ -79,12 +74,12 @@ struct GemmKernel
         const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
         // Convert pointers to tensor views
         auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(1, kargs.stride_A),
+                    make_tuple(kargs.stride_A, 1),
                     number<GemmPipeline::AlignmentA>{},
                     number<1>{});
             }
@@ -93,14 +88,14 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(kargs.stride_A, 1),
+                    make_tuple(1, kargs.stride_A),
                     number<GemmPipeline::AlignmentA>{},
                     number<1>{});
             }
         }();
         auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
@@ -110,7 +105,7 @@ struct GemmKernel
                     number<1>{});
             }
             else
-            { // Default NK layout
+            {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
                     make_tuple(kargs.N, kargs.K),
@@ -123,8 +118,8 @@ struct GemmKernel
         auto a_pad_view = pad_tensor_view(
             a_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-            sequence < 0,
-            GemmPipeline::kPadA ? 1 : 0 > {});
+            sequence < false,
+            GemmPipeline::kPadA ? true : false > {});
         auto ABlockWindow = make_tile_window(
             a_pad_view,
@@ -134,8 +129,8 @@ struct GemmKernel
         auto b_pad_view = pad_tensor_view(
             b_tensor_view,
             make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-            sequence < 0,
-            GemmPipeline::kPadB ? 1 : 0 > {});
+            sequence < false,
+            GemmPipeline::kPadB ? true : false > {});
         auto BBlockWindow = make_tile_window(
             b_pad_view,
@@ -225,15 +220,15 @@ struct GemmKernel
             }
         }
-        CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
         auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
+                    make_tuple(kargs.stride_C, 1),
                     number<GemmPipeline::AlignmentC>{},
                     number<1>{});
             }
@@ -242,7 +237,7 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
+                    make_tuple(1, kargs.stride_C),
                     number<GemmPipeline::AlignmentC>{},
                     number<1>{});
             }
@@ -251,13 +246,13 @@ struct GemmKernel
         auto c_pad_view = pad_tensor_view(
             c_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-            sequence < 0,
-            GemmPipeline::kPadC ? 1 : 0 > {});
-        auto CBlockWindow_pad = make_tile_window(
+            sequence < false,
+            GemmPipeline::kPadC ? true : false > {});
+        auto CBlockWindow = make_tile_window(
             c_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
             {i_m, i_n});
-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
+        EpiloguePipeline{}(CBlockWindow, c_block_tile);
     }
 };
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = BlockGemmPipelineAgBgCrDefaultPolicy>
struct BlockGemmPipelineAgBgCrMem
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using CBlockTile = typename BlockGemm::CBlockTile;
using I0 = number<0>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AlignmentA = Problem::AlignmentA;
static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC;
static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t WgpPerCU =
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2
? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
: 2;
static constexpr index_t LocalPrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
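    // Worked example (illustrative numbers, not taken from this diff): with fp16 A/B
    // (2 bytes each), a 256-thread block (WgpPerCU = 4 * 64 / 256 = 1) and a
    // 128x128x32 block tile, one prefetch stage moves
    // (128 * 2 + 128 * 2) * 32 = 16384 bytes, so
    // FullMemBandPrefetchStages = ceil(32768 / 16384) = 2 and PrefetchStages = 2.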
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
return TailNumber::One;
}
else if(num_loop % PrefetchStages == 2)
{
return TailNumber::Two;
}
else if(num_loop % PrefetchStages == 3)
{
return TailNumber::Three;
}
else if(num_loop % PrefetchStages == 4)
{
return TailNumber::Four;
}
else if(num_loop % PrefetchStages == 5)
{
return TailNumber::Five;
}
else if(num_loop % PrefetchStages == 6)
{
return TailNumber::Six;
}
else if(num_loop % PrefetchStages == 7)
{
return TailNumber::Seven;
}
else
{
return TailNumber::Full;
}
}
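    // Example (illustrative): with PrefetchStages == 2, num_loop == 7 gives
    // num_loop % PrefetchStages == 1, so the hot loop runs and the tail is
    // TailNumber::One; num_loop == 8 gives remainder 0 and thus TailNumber::Full.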
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <BlockGemmPipelineScheduler Scheduler>
struct PipelineImpl
{
};
template <>
struct PipelineImpl<BlockGemmPipelineScheduler::Intrawave>
{
template <typename BlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(BlockTile& block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile_raw(block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem,
CBlockTile& c_block_tile) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
NPerBlock ==
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = BlockGemm();
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
// prefetch
// global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [2, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
});
// main body
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window);
});
i += PrefetchStages;
} while(i < (num_loop - PrefetchStages));
}
auto HotLoopTail = [&](auto tail_num) {
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
});
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
};
// TODO: TailNumber2Number
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else if constexpr(TailNum == TailNumber::Two)
{
HotLoopTail(number<2>{});
}
else if constexpr(TailNum == TailNumber::Three)
{
HotLoopTail(number<3>{});
}
else if constexpr(TailNum == TailNumber::Four)
{
HotLoopTail(number<4>{});
}
else if constexpr(TailNum == TailNumber::Five)
{
HotLoopTail(number<5>{});
}
else if constexpr(TailNum == TailNumber::Six)
{
HotLoopTail(number<6>{});
}
else if constexpr(TailNum == TailNumber::Seven)
{
HotLoopTail(number<7>{});
}
else if constexpr(TailNum == TailNumber::Full)
{
HotLoopTail(number<PrefetchStages>{});
}
return c_block_tile;
}
};
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem,
CBlockTile& c_block_tile) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem,
c_block_tile);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem,
CBlockTile& c_block_tile) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem,
c_block_tile);
}
};
} // namespace ck_tile
@@ -13,6 +13,9 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
           bool kPadA_ = false,
           bool kPadB_ = false,
           bool kPadC_ = false>
@@ -23,6 +26,10 @@ struct BlockGemmPipelineProblem
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using ALayout = remove_cvref_t<ALayout_>;
+    using BLayout = remove_cvref_t<BLayout_>;
+    using CLayout = remove_cvref_t<CLayout_>;
     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
     static constexpr bool kPadA = kPadA_;
     static constexpr bool kPadB = kPadB_;
@@ -37,18 +44,29 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          bool kPadA_ = false,
-          bool kPadB_ = false,
-          bool kPadC_ = false,
-          BlockGemmPipelineScheduler Scheduler_ = BlockGemmPipelineScheduler::Intrawave>
-struct BlockGemmUniversalPipelineProblem
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
+          bool kPadA_ = false,
+          bool kPadB_ = false,
+          bool kPadC_ = false,
+          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
+          bool HasHotLoop_ = false,
+          TailNumber TailNum_ = TailNumber::Full>
+struct UniversalGemmPipelineProblem
 {
     using ADataType      = remove_cvref_t<ADataType_>;
     using BDataType      = remove_cvref_t<BDataType_>;
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using ALayout = remove_cvref_t<ALayout_>;
+    using BLayout = remove_cvref_t<BLayout_>;
+    using CLayout = remove_cvref_t<CLayout_>;
     static constexpr auto Scheduler = Scheduler_;
+    static constexpr auto HasHotLoop = HasHotLoop_;
+    static constexpr auto TailNum    = TailNum_;
     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
     static constexpr bool kPadA = kPadA_;
...
@@ -4,12 +4,12 @@
 #pragma once
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"
 namespace ck_tile {
-// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
+// Default policy for GemmPipelineAGmemBGmemCRegV1
 // Default policy class should not be templated, put template on member functions instead
-using BlockGemmPipelineAgBgCrDefaultPolicy = BlockGemmPipelineAgBgCrMemCustomPolicy;
+using GemmPipelineAgBgCrDefaultPolicy = GemmPipelineAgBgCrMemCustomPolicy;
 } // namespace ck_tile
@@ -9,14 +9,14 @@
 namespace ck_tile {
-// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
+// Default policy for GemmPipelineAGmemBGmemCRegV1
 // Maximum Global Memory throughput pipeline with >=32KB data in fly
 // GlobalPrefetchStages: >=2
 // LocalPreFillStages: 1
 // LocalPreFetchStages: 0
 // LocalSharedMemoryBuffer: 1
-struct BlockGemmPipelineAgBgCrMemCustomPolicy
+struct GemmPipelineAgBgCrMemCustomPolicy
 {
     // 3d + padding
     template <typename Problem>
@@ -47,8 +47,6 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
     {
-        using namespace ck_tile;
         constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
@@ -69,7 +67,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
         constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
                                         MakeALdsBlockDescriptor<Problem>().get_element_space_size();
@@ -77,7 +75,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
     {
         constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
                                         MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
@@ -85,7 +83,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
...
@@ -3,9 +3,11 @@
 #pragma once
+#include "ck_tile/core.hpp"
 namespace ck_tile {
-enum struct BlockGemmPipelineScheduler
+enum struct GemmPipelineScheduler
 {
     Intrawave,
     Interwave,
...