Commit 04006d5f authored by ThomasNing

Fix: clang-format; API fixes from fmha

parent c2b7f8df
......@@ -14,7 +14,8 @@
/*
create_args builds the command-line argument parser for this GEMM example
*/
auto create_args(int argc, char* argv[]) {
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size")
.insert("m", "1024", "m dimension")
......@@ -24,9 +25,12 @@ auto create_args(int argc, char* argv[]) {
.insert("stride_b", "0", "stride on apply the n,k B block")
.insert("stride_c", "0", "stride on apply the m,n C block")
.insert("grouped", "0", "bool condition on whether it is a grouped gemm")
.insert("grouped_dimension_m", "0", "Fill in the desired dimension when enable grouped gemm")
.insert("grouped_dimension_n", "0", "Fill in the desired dimension when enable grouped gemm")
.insert("grouped_dimension_k", "0", "Fill in the desired dimension when enable grouped gemm")
.insert(
    "grouped_dimension_m", "0", "desired m dimension when grouped GEMM is enabled")
.insert(
    "grouped_dimension_n", "0", "desired n dimension when grouped GEMM is enabled")
.insert(
    "grouped_dimension_k", "0", "desired k dimension when grouped GEMM is enabled")
.insert("v", "1", "cpu validation or not")
.insert("e", "1e-5", "epsilon")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
......@@ -40,7 +44,8 @@ auto create_args(int argc, char* argv[]) {
}
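A sketch of how the parser built above is consumed (get_int is assumed to exist alongside the get_str/get_bool accessors that appear later in this diff; the helper name is illustrative, not part of this commit):
// Editorial sketch: reading back the options registered in create_args().
inline void print_parsed_args_sketch(const ck_tile::ArgParser& arg_parser)
{
    const int M            = arg_parser.get_int("m");    // defaults to 1024
    const std::string prec = arg_parser.get_str("prec"); // defaults to "fp16"
    const bool validate    = arg_parser.get_bool("v");   // CPU validation flag
    std::cout << "m=" << M << " prec=" << prec << " validate=" << validate << "\n";
}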
template <typename Layouts>
float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
{
// ToDo: This will be modified by the codegen code later.
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
......@@ -61,28 +66,34 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
constexpr ck_tile::index_t kBlockPerCu = 1;
// ===============================================
using Shape = ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
using Shape = ck_tile::TileGemmShapeNewGemm<
ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>
>;
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<Shape>;
using PipelineProblem = ck_tile::BlockGemmPipelineProblem<XDataType, YDataType, AccDataType, Shape,
kPadA, kPadB, kPadC>;
using PipelineProblem = ck_tile::
BlockGemmPipelineProblem<XDataType, YDataType, AccDataType, Shape, kPadA, kPadB, kPadC>;
// The GemmPipeline should also come from the Codegen.
using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
using GemmEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType,
ODataType, kPadA, kPadB>>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, ODataType, kPadA, kPadB>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>;
auto kargs = Kernel::MakeKargs(
args.p_x, args.p_y, args.p_z, args.batch_size, args.epsilon, args.M, args.N,
args.K, args.stride_A, args.stride_B, args.stride_C
);
auto kargs = Kernel::MakeKargs(args.p_x,
args.p_y,
args.p_z,
args.batch_size,
args.epsilon,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_size);
constexpr dim3 blocks = Kernel::BlockSize();
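// Editorial worked example: with M_Tile = N_Tile = 128 and the default
// 1024 x 1024 problem, GridSize yields
//   GridDimX = (1024 + 127) / 128 = 8, GridDimY = 8, GridDimZ = batch_size,
// i.e. an 8 x 8 grid of workgroups per batch entry, one per 128 x 128 C tile.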
......@@ -94,13 +105,16 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
}
template <typename DataType, typename Layouts>
float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
float OperatorExecution(ck_tile::DeviceMem& x_buf,
ck_tile::DeviceMem& y_buf,
ck_tile::DeviceMem& z_buf,
const ck_tile::ArgParser& arg_parser){
const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
if (data_type != DataTypeTraits<DataType>::name) {
if(data_type != DataTypeTraits<DataType>::name)
{
std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
<< data_type << std::endl;
return -1; // Or handle the error appropriately
......@@ -127,48 +141,67 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
args.K = K;
// Use the caller-provided strides when non-zero; otherwise derive the default
// stride from the matrix layout.
if (stride_a != 0) {
if(stride_a != 0)
{
args.stride_A = stride_a;
} else {
args.stride_A = [&](){
if constexpr (Layouts::LayoutA == ck_tile::MatrixALayout::KM) {
}
else
{
args.stride_A = [&]() {
if constexpr(Layouts::LayoutA == ck_tile::MatrixALayout::KM)
{
return M;
} else {
}
else
{
return K;
}
}();
}
if (stride_b != 0) {
if(stride_b != 0)
{
args.stride_B = stride_b;
} else {
args.stride_B = [&](){
if constexpr (Layouts::LayoutB == ck_tile::MatrixBLayout::KN) {
}
else
{
args.stride_B = [&]() {
if constexpr(Layouts::LayoutB == ck_tile::MatrixBLayout::KN)
{
return N;
} else {
}
else
{
return K;
}
}();
}
if(stride_c != 0) {
if(stride_c != 0)
{
args.stride_C = stride_c;
} else {
args.stride_C = [&](){
if constexpr (Layouts::LayoutC == ck_tile::MatrixCLayout::NM) {
}
else
{
args.stride_C = [&]() {
if constexpr(Layouts::LayoutC == ck_tile::MatrixCLayout::NM)
{
return M;
} else {
}
else
{
return N;
}
}();
}
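// Editorial sketch (not part of this commit): the three stride defaults above
// share one rule: a caller-supplied stride wins, otherwise use the extent of
// the layout's fastest-varying dimension. A hypothetical helper:
//
//   constexpr ck_tile::index_t default_stride(bool is_col_major,
//                                             ck_tile::index_t col_major_stride,
//                                             ck_tile::index_t row_major_stride)
//   {
//       return is_col_major ? col_major_stride : row_major_stride;
//   }
//
//   args.stride_A = stride_a != 0 ? stride_a
//       : default_stride(Layouts::LayoutA == ck_tile::MatrixALayout::KM, M, K);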
float ave_time = gemm_calc<Layouts>(args, ck_tile::stream_config{nullptr, true});
std::size_t num_byte = sizeof(XDataType) * M * K + sizeof(YDataType) * N * K+
sizeof(ODataType) * M * N;
std::size_t num_byte =
sizeof(XDataType) * M * K + sizeof(YDataType) * N * K + sizeof(ODataType) * M * N;
float gb_per_sec = num_byte / 1.E6 / ave_time;
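// Editorial unit check: num_byte / 1.E6 is megabytes and ave_time is in
// milliseconds, so the quotient is MB/ms == GB/s. For M = N = K = 1024 in
// fp16, num_byte = 3 * 2 * 1024 * 1024 = 6'291'456 bytes; an ave_time of
// 0.05 ms would therefore report ~125.8 GB/s.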
std::cout << "The overall perfomance of the GEMM with " << "[" << data_type << "]"
std::cout << "The overall perfomance of the GEMM with "
<< "[" << data_type << "]"
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
<< "is: \n";
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
......@@ -177,7 +210,8 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
return ave_time;
}
int main(int argc, char* argv[]) {
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
......@@ -194,12 +228,15 @@ int main(int argc, char* argv[]) {
using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>;
// host verify
std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK) ?
std::vector<int>{M, K} : std::vector<int>{K, M};
std::vector<int> y_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK) ?
std::vector<int>{N, K} : std::vector<int>{K, N};
std::vector<int> z_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN) ?
std::vector<int>{M, N} : std::vector<int>{N, M};
std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK)
? std::vector<int>{M, K}
: std::vector<int>{K, M};
std::vector<int> y_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK)
? std::vector<int>{N, K}
: std::vector<int>{K, N};
std::vector<int> z_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN)
? std::vector<int>{M, N}
: std::vector<int>{N, M};
ck_tile::HostTensor<XDataType> x_host(x_dimensions);
ck_tile::HostTensor<YDataType> y_host(y_dimensions);
......@@ -217,7 +254,8 @@ int main(int argc, char* argv[]) {
x_buf.ToDevice(x_host.data());
y_buf.ToDevice(y_host.data());
if(grouped_enable || following_op_descrp != "no") {
if(grouped_enable || following_op_descrp != "no")
{
std::cerr << "Other category of the GEMM is unsupported for now!" << std::endl;
return -1;
}
......@@ -226,7 +264,8 @@ int main(int argc, char* argv[]) {
bool pass = true;
if(arg_parser.get_bool("v")) {
if(arg_parser.get_bool("v"))
{
// ToDo: Will add the element-op (bias) verification in the future.
ck_tile::reference_gemm<XDataType, YDataType, AccDataType, ODataType>(
x_host, y_host, z_host_ref, matrix_a_layout);
......@@ -240,6 +279,5 @@ int main(int argc, char* argv[]) {
std::cout << std::endl << std::flush;
return !pass;
}
......@@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
......@@ -15,37 +14,41 @@ template <typename DataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t> {
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t; //type convert
using ODataType = ck_tile::half_t; // type convert
// ToDo: Add more bias config to support different categories of GEMM.
};
template<ck_tile::MatrixALayout A, ck_tile::MatrixBLayout B,
ck_tile::MatrixCLayout C>
struct LayoutConfig {
template <ck_tile::MatrixALayout A, ck_tile::MatrixBLayout B, ck_tile::MatrixCLayout C>
struct LayoutConfig
{
static constexpr ck_tile::MatrixALayout LayoutA = A;
static constexpr ck_tile::MatrixBLayout LayoutB = B;
static constexpr ck_tile::MatrixCLayout LayoutC = C;
};
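A usage sketch of LayoutConfig for the common all-row-major case (the alias name is illustrative, not from this commit):
using RowMajorLayouts = LayoutConfig<ck_tile::MatrixALayout::MK,
                                     ck_tile::MatrixBLayout::NK,
                                     ck_tile::MatrixCLayout::MN>;
static_assert(RowMajorLayouts::LayoutA == ck_tile::MatrixALayout::MK, "layout sketch");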
template<typename T>
template <typename T>
struct DataTypeTraits;
template<>
struct DataTypeTraits<float> {
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "float";
};
template<>
struct DataTypeTraits<double> {
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "double";
};
template<>
struct DataTypeTraits<ck_tile::half_t> {
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
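Supporting another precision only needs one more specialization; a hypothetical bf16 entry (assuming ck_tile::bf16_t exists, mirroring half_t):
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16"; // must match the "prec" option string
};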
......@@ -57,7 +60,8 @@ using YDataType = Types::YDataType;
using AccDataType = Types::AccDataType;
using ODataType = Types::ODataType;
struct gemm_basic_args {
struct gemm_basic_args
{
const void* p_x;
const void* p_y;
void* p_z;
......
......@@ -67,7 +67,7 @@ check_err(const Range& out,
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<double>::min();
for(std::size_t i = 4190; i < ref.size(); ++i)
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = *std::next(std::begin(out), i);
const double r = *std::next(std::begin(ref), i);
......@@ -127,7 +127,7 @@ check_err(const Range& out,
double err = 0;
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 4190; i < ref.size(); ++i)
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
......@@ -186,7 +186,7 @@ check_err(const Range& out,
int err_count = 0;
double err = 0;
double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
for(std::size_t i = 4190; i < ref.size(); ++i)
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
......@@ -314,7 +314,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 4190; i < ref.size(); ++i)
for(std::size_t i = 0; i < ref.size(); ++i)
{
const fp8_t o_fp8 = *std::next(std::begin(out), i);
const fp8_t r_fp8 = *std::next(std::begin(ref), i);
......@@ -372,7 +372,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 4190; i < ref.size(); ++i)
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
......
......@@ -1144,17 +1144,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
typename Problem::KDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0
>,
Problem::BlockFmhaShape::Gemm0BlockWarps_,
Problem::BlockFmhaShape::Gemm0WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -1184,18 +1178,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kN0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
typename Problem::OGradDataType, typename Problem::AccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kVHeaddim,
Problem::BlockFmhaShape::BlockTile::kK1
>,
Problem::BlockFmhaShape::Gemm1BlockWarps_,
Problem::BlockFmhaShape::Gemm1WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK1>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -1217,18 +1204,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::OGradDataType,
typename Problem::VDataType, typename Problem::AccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK2
>,
Problem::BlockFmhaShape::Gemm2BlockWarps_,
Problem::BlockFmhaShape::Gemm2WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK2>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
......@@ -1295,18 +1275,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kN0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
typename Problem::QDataType, typename Problem::AccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::BlockFmhaShape::BlockTile::kK3
>,
Problem::BlockFmhaShape::Gemm3BlockWarps_,
Problem::BlockFmhaShape::Gemm3WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK3>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -1328,18 +1301,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
typename Problem::KDataType, typename Problem::AccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::BlockFmhaShape::BlockTile::kK4
>,
Problem::BlockFmhaShape::Gemm4BlockWarps_,
Problem::BlockFmhaShape::Gemm4WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK4>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......
......@@ -75,18 +75,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0
>,
Problem::BlockFmhaShape::Gemm0BlockWarps_,
Problem::BlockFmhaShape::Gemm0WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -202,18 +195,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0
>,
Problem::BlockFmhaShape::Gemm0BlockWarps_,
Problem::BlockFmhaShape::Gemm0WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -950,18 +936,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::PDataType,
typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN1,
Problem::BlockFmhaShape::BlockTile::kK1
>,
Problem::BlockFmhaShape::Gemm1BlockWarps_,
Problem::BlockFmhaShape::Gemm1WarpTile_>>;
Problem::BlockFmhaShape::BlockTile::kK1>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
......@@ -48,7 +48,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
} else {
}
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
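// Editorial note: a bare static_assert(false, ...) in the branch above is
// ill-formed even when the constexpr branch is discarded (fixed only in C++23
// by P2593). The common workaround is a dependent condition, e.g. a
// dependent_false<Problem>::value trait that is always false but only
// evaluated on instantiation.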
......
......@@ -12,8 +12,12 @@
namespace ck_tile {
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, typename Layouts_>
struct GemmKernel {
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
typename Layouts_>
struct GemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
......@@ -25,16 +29,17 @@ struct GemmKernel {
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) {
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
{
auto x = TilePartitioner::GridSize(M_size, N_size, Batch_size);
printf("GridDimX: %u, GridDimY: %u, GridDimZ: %u\n", x.x, x.y, x.z);
return x;
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
struct GemmCommonKargs {
struct GemmCommonKargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
......@@ -60,15 +65,19 @@ struct GemmKernel {
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C) {
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, batch_size, M, N, K, stride_A, stride_B, stride_C};
ck_tile::index_t stride_C)
{
return GemmCommonKargs{
a_ptr, b_ptr, c_ptr, epsilon, batch_size, M, N, K, stride_A, stride_B, stride_C};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() {
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
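// Editorial note: the GEMM pipeline and the epilogue run back-to-back and can
// reuse the same LDS allocation, which is why the kernel reserves the larger
// of the two footprints rather than their sum.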
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const {
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{
const auto [i_tile_m, i_tile_n, i_batch] = TilePartitioner{}();
const index_t i_m = __builtin_amdgcn_readfirstlane(i_tile_m * TilePartitioner::kM);
const index_t i_n = __builtin_amdgcn_readfirstlane(i_tile_n * TilePartitioner::kN);
......@@ -76,66 +85,96 @@ struct GemmKernel {
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&](){
if constexpr (Layouts::LayoutA == ck_tile::MatrixALayout::KM) {
auto a_tensor_view = [&]() {
if constexpr(Layouts::LayoutA == ck_tile::MatrixALayout::KM)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start, make_tuple(kargs.M, kargs.K), make_tuple(1, kargs.stride_A),
number<GemmPipeline::AlignmentA>{}, number<1>{});
} else {
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1),
number<GemmPipeline::AlignmentA>{}, number<1>{});
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
}();
auto b_tensor_view = [&](){
if constexpr (Layouts::LayoutB == ck_tile::MatrixBLayout::KN) {
auto b_tensor_view = [&]() {
if constexpr(Layouts::LayoutB == ck_tile::MatrixBLayout::KN)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start, make_tuple(kargs.N, kargs.K), make_tuple(1, kargs.stride_B),
number<GemmPipeline::AlignmentB>{}, number<1>{});
} else { // Default NK layout
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<GemmPipeline::AlignmentB>{},
number<1>{});
}
else
{ // Default NK layout
return make_naive_tensor_view<address_space_enum::global>(
b_start, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1),
number<GemmPipeline::AlignmentB>{}, number<1>{});
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::AlignmentB>{},
number<1>{});
}
}();
auto ABlockWindow = make_tile_window(a_tensor_view, make_tuple(number<TilePartitioner::kM>{},
number<TilePartitioner::kK>{}), {i_m, 0});
auto ABlockWindow = make_tile_window(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto BBlockWindow = make_tile_window(b_tensor_view, make_tuple(number<TilePartitioner::kN>{},
number<TilePartitioner::kK>{}), {i_n, 0});
auto BBlockWindow = make_tile_window(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
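// Editorial note: ceil-division keeps a final, possibly partial K tile in the
// loop count; e.g. K = 1000 with kK = 32 gives num_loop = 32, whose last
// iteration covers only the remaining 8 elements of K.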
auto acc = BlockGemmPipelineAGmemBGmemCRegV1<GemmPipeline>{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
auto acc = BlockGemmPipelineAGmemBGmemCRegV1<GemmPipeline>{}(
ABlockWindow, BBlockWindow, num_loop, smem_ptr);
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
auto c_tensor_view = [&](){
if constexpr (Layouts::LayoutC == ck_tile::MatrixCLayout::NM){
auto c_tensor_view = [&]() {
if constexpr(Layouts::LayoutC == ck_tile::MatrixCLayout::NM)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start, make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_C),
number<GemmPipeline::AlignmentC>{}, number<1>{});
} else {
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1),
number<GemmPipeline::AlignmentC>{}, number<1>{});
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
}();
auto CBlockWindow = make_tile_window(c_tensor_view, make_tuple(number<TilePartitioner::kM>{},
number<TilePartitioner::kN>{}), {i_m, i_n});
auto CBlockWindow = make_tile_window(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
// epilogue.
EpiloguePipeline{}(CBlockWindow, acc);
}
};
}
} // namespace ck_tile
......@@ -4,36 +4,21 @@
#pragma once
namespace ck_tile {
enum struct MatrixALayout {
enum struct MatrixALayout
{
MK, // Row-major layout for matrix A (default)
KM // Column-major layout for matrix A
};
enum struct MatrixBLayout {
enum struct MatrixBLayout
{
NK, // Row-major layout for matrix B (default)
KN // Column-major layout for matrix B
};
enum struct MatrixCLayout {
enum struct MatrixCLayout
{
MN, // Row-major layout for matrix C (default)
NM // Column-major layout for matrix C
};
// Function to convert string to MatrixALayout
inline MatrixALayout parse_layout_a(const std::string& layout) {
if (layout == "KM") return MatrixALayout::KM;
return MatrixALayout::MK; // Default to MK if not specified as KM
}
// Function to convert string to MatrixBLayout
inline MatrixBLayout parse_layout_b(const std::string& layout) {
if (layout == "KN") return MatrixBLayout::KN;
return MatrixBLayout::NK; // Default to NK if not specified as KN
}
// Function to convert string to MatrixCLayout
inline MatrixCLayout parse_layout_c(const std::string& layout) {
if (layout == "NM") return MatrixCLayout::NM;
return MatrixCLayout::MN; // Default to MN if not specified as NM
}
};
} // namespace ck_tile
......@@ -6,27 +6,30 @@
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockGemmShape_>
struct GemmTilePartitioner {
template <typename BlockGemmShape_>
struct GemmTilePartitioner
{
using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;
static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
static constexpr ck_tile::index_t kK = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M, ck_tile::index_t N,
ck_tile::index_t batch_size) {
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
{
ck_tile::index_t GridDimX = (M + kM - 1) / kM;
ck_tile::index_t GridDimY = (N + kN - 1) / kN;
ck_tile::index_t GridDimZ = batch_size;
return dim3(GridDimX, GridDimY, GridDimZ);
}
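// Editorial note: the ceil-division lets the last workgroup along M or N map
// past the problem size when it is not a tile multiple; the kPadA/kPadB/kPadC
// flags on the pipeline problem presumably exist to keep such accesses legal.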
CK_TILE_DEVICE auto operator()() {
CK_TILE_DEVICE auto operator()()
{
const index_t i_GridDimX = blockIdx.x;
const index_t i_GridDimY = blockIdx.y;
const index_t i_GridDimZ = blockIdx.z;
return ck_tile::make_tuple(i_GridDimX, i_GridDimY, i_GridDimZ);
}
};
};
} // namespace ck_tile
......@@ -40,7 +40,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() {
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
......@@ -149,7 +150,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
}
index_t iCounter = num_loop - 1;
while(iCounter > 0) {
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
......@@ -174,7 +176,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
store_tile(b_copy_lds_window, b_block_tile_tmp);
iCounter--;
}
// tail
......
......@@ -93,21 +93,24 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() {
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB() {
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() {
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
index_t smem_size = 0;
......
......@@ -29,7 +29,6 @@ struct BlockGemmPipelineProblem
static constexpr index_t AlignmentA = kPadA ? 16 / sizeof(ADataType) : 1;
static constexpr index_t AlignmentB = kPadB ? 16 / sizeof(BDataType) : 1;
static constexpr index_t AlignmentC = kPadC ? 16 / sizeof(CDataType) : 1;
};
} // namespace ck_tile
......@@ -7,17 +7,21 @@
namespace ck_tile {
template <typename BlockTile_,
typename BlockWarps_,
typename WarpTile_>
struct TileGemmShape
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape {
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
};
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
struct TileGemmShapeNewGemm
{
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{});
......
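To make the split concrete, here is a sketch instantiating both forms (tile values are illustrative, echoing the example earlier in this commit):
using ShapeOld = ck_tile::TileGemmShape<128, 128, 32>; // plain M/N/K extents
using ShapeNew = ck_tile::TileGemmShapeNewGemm<
    ck_tile::sequence<128, 128, 32>, // block tile M, N, K
    ck_tile::sequence<4, 1, 1>,      // warps per block
    ck_tile::sequence<32, 32, 8>>;   // per-warp tile M, N, K
static_assert(ShapeNew::kM == 128 && ShapeNew::NumWarps == 4, "shape sketch");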