Commit 04006d5f authored by ThomasNing

Fix: Clang Format, API fixed from fmha

parent c2b7f8df
@@ -14,7 +14,8 @@
/*
create_args is a function
*/
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("b", "1", "batch size")
        .insert("m", "1024", "m dimension")
@@ -24,9 +25,12 @@ auto create_args(int argc, char* argv[]) {
        .insert("stride_b", "0", "stride on apply the n,k B block")
        .insert("stride_c", "0", "stride on apply the m,n C block")
        .insert("grouped", "0", "bool condition on whether it is a grouped gemm")
        .insert(
            "grouped_dimension_m", "0", "Fill in the desired dimension when enable grouped gemm")
        .insert(
            "grouped_dimension_n", "0", "Fill in the desired dimension when enable grouped gemm")
        .insert(
            "grouped_dimension_k", "0", "Fill in the desired dimension when enable grouped gemm")
        .insert("v", "1", "cpu validation or not")
        .insert("e", "1e-5", "epsilon")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
@@ -40,7 +44,8 @@ auto create_args(int argc, char* argv[]) {
}

template <typename Layouts>
float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
{
    // ToDo: This will be modified by the codegen code later.
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
@@ -61,28 +66,34 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
    constexpr ck_tile::index_t kBlockPerCu = 1;
    // ===============================================

    using Shape = ck_tile::TileGemmShapeNewGemm<
        ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
        ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
        ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
    using TilePartitioner = ck_tile::GemmTilePartitioner<Shape>;
    using PipelineProblem = ck_tile::
        BlockGemmPipelineProblem<XDataType, YDataType, AccDataType, Shape, kPadA, kPadB, kPadC>;
    // The GemmPipeline should also come from the Codegen.
    using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
    using GemmEpilogue = ck_tile::Default2DEpilogue<
        ck_tile::Default2DEpilogueProblem<AccDataType, ODataType, kPadA, kPadB>>;
    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>;

    auto kargs = Kernel::MakeKargs(args.p_x,
                                   args.p_y,
                                   args.p_z,
                                   args.batch_size,
                                   args.epsilon,
                                   args.M,
                                   args.N,
                                   args.K,
                                   args.stride_A,
                                   args.stride_B,
                                   args.stride_C);

    const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_size);
    constexpr dim3 blocks = Kernel::BlockSize();
@@ -94,13 +105,16 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
}

template <typename DataType, typename Layouts>
float OperatorExecution(ck_tile::DeviceMem& x_buf,
                        ck_tile::DeviceMem& y_buf,
                        ck_tile::DeviceMem& z_buf,
                        const ck_tile::ArgParser& arg_parser)
{
    std::string data_type = arg_parser.get_str("prec");
    if(data_type != DataTypeTraits<DataType>::name)
    {
        std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
                  << data_type << std::endl;
        return -1; // Or handle the error appropriately
@@ -127,48 +141,67 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
    args.K = K;
    // Only set stride_M and stride_N if they are non-zero and not equal to K.
    if(stride_a != 0)
    {
        args.stride_A = stride_a;
    }
    else
    {
        args.stride_A = [&]() {
            if constexpr(Layouts::LayoutA == ck_tile::MatrixALayout::KM)
            {
                return M;
            }
            else
            {
                return K;
            }
        }();
    }

    if(stride_b != 0)
    {
        args.stride_B = stride_b;
    }
    else
    {
        args.stride_B = [&]() {
            if constexpr(Layouts::LayoutB == ck_tile::MatrixBLayout::KN)
            {
                return N;
            }
            else
            {
                return K;
            }
        }();
    }

    if(stride_c != 0)
    {
        args.stride_C = stride_c;
    }
    else
    {
        args.stride_C = [&]() {
            if constexpr(Layouts::LayoutC == ck_tile::MatrixCLayout::NM)
            {
                return M;
            }
            else
            {
                return N;
            }
        }();
    }

    float ave_time = gemm_calc<Layouts>(args, ck_tile::stream_config{nullptr, true});

    std::size_t num_byte =
        sizeof(XDataType) * M * K + sizeof(YDataType) * N * K + sizeof(ODataType) * M * N;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "The overall perfomance of the GEMM with "
              << "[" << data_type << "]"
              << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
              << "is: \n";
    std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
@@ -177,7 +210,8 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
    return ave_time;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;
@@ -194,12 +228,15 @@ int main(int argc, char* argv[]) {
    using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>;

    // host verify
    std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK)
                                        ? std::vector<int>{M, K}
                                        : std::vector<int>{K, M};
    std::vector<int> y_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK)
                                        ? std::vector<int>{N, K}
                                        : std::vector<int>{K, N};
    std::vector<int> z_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN)
                                        ? std::vector<int>{M, N}
                                        : std::vector<int>{N, M};

    ck_tile::HostTensor<XDataType> x_host(x_dimensions);
    ck_tile::HostTensor<YDataType> y_host(y_dimensions);
@@ -217,7 +254,8 @@ int main(int argc, char* argv[]) {
    x_buf.ToDevice(x_host.data());
    y_buf.ToDevice(y_host.data());

    if(grouped_enable || following_op_descrp != "no")
    {
        std::cerr << "Other category of the GEMM is unsupported for now!" << std::endl;
        return -1;
    }
@@ -226,7 +264,8 @@ int main(int argc, char* argv[]) {
    bool pass = true;

    if(arg_parser.get_bool("v"))
    {
        // ToDo: Will Add the Element Op (bias) verification in the future.
        ck_tile::reference_gemm<XDataType, YDataType, AccDataType, ODataType>(
            x_host, y_host, z_host_ref, matrix_a_layout);
@@ -240,6 +279,5 @@ int main(int argc, char* argv[]) {
    std::cout << std::endl << std::flush;
    return !pass;
}
@@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include "ck_tile/core.hpp"
@@ -15,37 +14,41 @@ template <typename DataType>
struct GemmBasicTypeConfig;

template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
    using XDataType = ck_tile::half_t;
    using YDataType = ck_tile::half_t;
    using AccDataType = float;
    using ODataType = ck_tile::half_t; // type convert
    // ToDo: Add more bias config to support different categories of GEMM.
};

template <ck_tile::MatrixALayout A, ck_tile::MatrixBLayout B, ck_tile::MatrixCLayout C>
struct LayoutConfig
{
    static constexpr ck_tile::MatrixALayout LayoutA = A;
    static constexpr ck_tile::MatrixBLayout LayoutB = B;
    static constexpr ck_tile::MatrixCLayout LayoutC = C;
};

template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "float";
};

template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "double";
};

template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};
@@ -57,7 +60,8 @@ using YDataType = Types::YDataType;
using AccDataType = Types::AccDataType;
using ODataType = Types::ODataType;

struct gemm_basic_args
{
    const void* p_x;
    const void* p_y;
    void* p_z;
......
@@ -67,7 +67,7 @@ check_err(const Range& out,
    int err_count = 0;
    double err = 0;
    double max_err = std::numeric_limits<double>::min();
-   for(std::size_t i = 4190; i < ref.size(); ++i)
+   for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = *std::next(std::begin(out), i);
        const double r = *std::next(std::begin(ref), i);
@@ -127,7 +127,7 @@ check_err(const Range& out,
    double err = 0;
    // TODO: This is a hack. We should have proper specialization for bf16_t data type.
    double max_err = std::numeric_limits<float>::min();
-   for(std::size_t i = 4190; i < ref.size(); ++i)
+   for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
@@ -186,7 +186,7 @@ check_err(const Range& out,
    int err_count = 0;
    double err = 0;
    double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
-   for(std::size_t i = 4190; i < ref.size(); ++i)
+   for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
@@ -314,7 +314,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
    int err_count = 0;
    double err = 0;
    double max_err = std::numeric_limits<float>::min();
-   for(std::size_t i = 4190; i < ref.size(); ++i)
+   for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const fp8_t o_fp8 = *std::next(std::begin(out), i);
        const fp8_t r_fp8 = *std::next(std::begin(ref), i);
@@ -372,7 +372,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
    int err_count = 0;
    double err = 0;
    double max_err = std::numeric_limits<float>::min();
-   for(std::size_t i = 4190; i < ref.size(); ++i)
+   for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
......
@@ -1144,17 +1144,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::QDataType,
-                                    typename Problem::KDataType,
-                                    typename Problem::AccDataType,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kM0,
-                                        Problem::BlockFmhaShape::BlockTile::kN0,
-                                        Problem::BlockFmhaShape::BlockTile::kK0
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm0BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm0WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
+             typename Problem::KDataType, typename Problem::AccDataType,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
+                                    Problem::BlockFmhaShape::BlockTile::kN0,
+                                    Problem::BlockFmhaShape::BlockTile::kK0>>;
        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -1184,18 +1178,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::GemmDataType,
-                                    typename Problem::OGradDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kN0,
-                                        Problem::BlockFmhaShape::BlockTile::kVHeaddim,
-                                        Problem::BlockFmhaShape::BlockTile::kK1
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm1BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm1WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
+             typename Problem::OGradDataType, typename Problem::AccDataType, Problem::kBlockSize,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
+                                    Problem::BlockFmhaShape::BlockTile::kVHeaddim,
+                                    Problem::BlockFmhaShape::BlockTile::kK1>>;
        using WarpGemm =
            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -1217,18 +1204,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::OGradDataType,
-                                    typename Problem::VDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kM0,
-                                        Problem::BlockFmhaShape::BlockTile::kN0,
-                                        Problem::BlockFmhaShape::BlockTile::kK2
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm2BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm2WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::OGradDataType,
+             typename Problem::VDataType, typename Problem::AccDataType, Problem::kBlockSize,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
+                                    Problem::BlockFmhaShape::BlockTile::kN0,
+                                    Problem::BlockFmhaShape::BlockTile::kK2>>;
        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
@@ -1295,18 +1275,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::GemmDataType,
-                                    typename Problem::QDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kN0,
-                                        Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
-                                        Problem::BlockFmhaShape::BlockTile::kK3
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm3BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm3WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
+             typename Problem::QDataType, typename Problem::AccDataType, Problem::kBlockSize,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
+                                    Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
+                                    Problem::BlockFmhaShape::BlockTile::kK3>>;
        using WarpGemm =
            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -1328,18 +1301,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::GemmDataType,
-                                    typename Problem::KDataType,
-                                    typename Problem::AccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kM0,
-                                        Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
-                                        Problem::BlockFmhaShape::BlockTile::kK4
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm4BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm4WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
+             typename Problem::KDataType, typename Problem::AccDataType, Problem::kBlockSize,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
+                                    Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
+                                    Problem::BlockFmhaShape::BlockTile::kK4>>;
        using WarpGemm =
            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......
@@ -75,18 +75,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::QDataType,
-                                    typename Problem::KDataType,
-                                    typename Problem::SaccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kM0,
-                                        Problem::BlockFmhaShape::BlockTile::kN0,
-                                        Problem::BlockFmhaShape::BlockTile::kK0
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm0BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm0WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
+             typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
+                                    Problem::BlockFmhaShape::BlockTile::kN0,
+                                    Problem::BlockFmhaShape::BlockTile::kK0>>;
        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -202,18 +195,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::QDataType,
-                                    typename Problem::KDataType,
-                                    typename Problem::SaccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kM0,
-                                        Problem::BlockFmhaShape::BlockTile::kN0,
-                                        Problem::BlockFmhaShape::BlockTile::kK0
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm0BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm0WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
+             typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
+                                    Problem::BlockFmhaShape::BlockTile::kN0,
+                                    Problem::BlockFmhaShape::BlockTile::kK0>>;
        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -950,18 +936,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
    {
-       using BlockGemmProblem =
-           BlockGemmPipelineProblem<typename Problem::PDataType,
-                                    typename Problem::VDataType,
-                                    typename Problem::OaccDataType,
-                                    Problem::kBlockSize,
-                                    TileGemmShape<sequence<
-                                        Problem::BlockFmhaShape::BlockTile::kM0,
-                                        Problem::BlockFmhaShape::BlockTile::kN1,
-                                        Problem::BlockFmhaShape::BlockTile::kK1
-                                        >,
-                                    Problem::BlockFmhaShape::Gemm1BlockWarps_,
-                                    Problem::BlockFmhaShape::Gemm1WarpTile_>>;
+       using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::PDataType,
+             typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize,
+             TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
+                                    Problem::BlockFmhaShape::BlockTile::kN1,
+                                    Problem::BlockFmhaShape::BlockTile::kK1>>;
        auto warp_gemm = [&]() {
            if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
@@ -48,7 +48,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
                     std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
        }
        else
        {
            static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
        }
    }
......
@@ -12,8 +12,12 @@
namespace ck_tile {

template <typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_,
          typename Layouts_>
struct GemmKernel
{
    using TilePartitioner = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
@@ -25,16 +29,17 @@ struct GemmKernel {
    using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
    using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
    {
        auto x = TilePartitioner::GridSize(M_size, N_size, Batch_size);
        printf("GridDimX: %d, GridDimY: %d, %d", x.x, x.y, x.z);
        return TilePartitioner::GridSize(M_size, N_size, Batch_size);
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    struct GemmCommonKargs
    {
        const void* a_ptr;
        const void* b_ptr;
        void* c_ptr;
@@ -60,15 +65,19 @@ struct GemmKernel {
                                      ck_tile::index_t K,
                                      ck_tile::index_t stride_A,
                                      ck_tile::index_t stride_B,
                                      ck_tile::index_t stride_C)
    {
        return GemmCommonKargs{
            a_ptr, b_ptr, c_ptr, epsilon, batch_size, M, N, K, stride_A, stride_B, stride_C};
    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
    {
        const auto [i_tile_m, i_tile_n, i_batch] = TilePartitioner{}();
        const index_t i_m = __builtin_amdgcn_readfirstlane(i_tile_m * TilePartitioner::kM);
        const index_t i_n = __builtin_amdgcn_readfirstlane(i_tile_n * TilePartitioner::kN);
@@ -76,66 +85,96 @@ struct GemmKernel {
        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
        // Convert pointers to tensor views
        auto a_tensor_view = [&]() {
            if constexpr(Layouts::LayoutA == ck_tile::MatrixALayout::KM)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(1, kargs.stride_A),
                    number<GemmPipeline::AlignmentA>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::AlignmentA>{},
                    number<1>{});
            }
        }();

        auto b_tensor_view = [&]() {
            if constexpr(Layouts::LayoutB == ck_tile::MatrixBLayout::KN)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(1, kargs.stride_B),
                    number<GemmPipeline::AlignmentB>{},
                    number<1>{});
            }
            else
            { // Default NK layout
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::AlignmentB>{},
                    number<1>{});
            }
        }();

        auto ABlockWindow = make_tile_window(
            a_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            {i_m, 0});

        auto BBlockWindow = make_tile_window(
            b_tensor_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            {i_n, 0});

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;

        auto acc = BlockGemmPipelineAGmemBGmemCRegV1<GemmPipeline>{}(
            ABlockWindow, BBlockWindow, num_loop, smem_ptr);

        CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
        auto c_tensor_view = [&]() {
            if constexpr(Layouts::LayoutC == ck_tile::MatrixCLayout::NM)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<GemmPipeline::AlignmentC>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<GemmPipeline::AlignmentC>{},
                    number<1>{});
            }
        }();

        auto CBlockWindow = make_tile_window(
            c_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            {i_m, i_n});

        // epilogue.
        EpiloguePipeline{}(CBlockWindow, acc);
    }
};
} // namespace ck_tile
@@ -4,36 +4,21 @@
#pragma once

namespace ck_tile {

enum struct MatrixALayout
{
    MK, // Row-major layout for matrix A (default)
    KM  // Column-major layout for matrix A
};

enum struct MatrixBLayout
{
    NK, // Row-major layout for matrix B (default)
    KN  // Column-major layout for matrix B
};

enum struct MatrixCLayout
{
    MN, // Row-major layout for matrix C (default)
    NM  // Column-major layout for matrix C
};

-// Function to convert string to MatrixALayout
-inline MatrixALayout parse_layout_a(const std::string& layout) {
-    if (layout == "KM") return MatrixALayout::KM;
-    return MatrixALayout::MK; // Default to MK if not specified as KM
-}
-// Function to convert string to MatrixBLayout
-inline MatrixBLayout parse_layout_b(const std::string& layout) {
-    if (layout == "KN") return MatrixBLayout::KN;
-    return MatrixBLayout::NK; // Default to NK if not specified as KN
-}
-// Function to convert string to MatrixBLayout
-inline MatrixCLayout parse_layout_c(const std::string& layout) {
-    if (layout == "NM") return MatrixCLayout::NM;
-    return MatrixCLayout::MN; // Default to MN if not specified as NM
-}

} // namespace ck_tile
@@ -6,27 +6,30 @@
#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockGemmShape_>
struct GemmTilePartitioner
{
    using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;

    static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
    static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
    static constexpr ck_tile::index_t kK = BlockGemmShape::kK;

    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
    {
        ck_tile::index_t GridDimX = (M + kM - 1) / kM;
        ck_tile::index_t GridDimY = (N + kN - 1) / kN;
        ck_tile::index_t GridDimZ = batch_size;
        return dim3(GridDimX, GridDimY, GridDimZ);
    }

    CK_TILE_DEVICE auto operator()()
    {
        const index_t i_GridDimX = blockIdx.x;
        const index_t i_GridDimY = blockIdx.y;
        const index_t i_GridDimZ = blockIdx.z;
        return ck_tile::make_tuple(i_GridDimX, i_GridDimY, i_GridDimZ);
    }
};
} // namespace ck_tile
@@ -40,7 +40,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
               Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }
@@ -149,7 +150,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
        }

        index_t iCounter = num_loop - 1;
        while(iCounter > 0)
        {
            // global read i + 1
            a_block_tile = load_tile(a_copy_dram_window);
            b_block_tile = load_tile(b_copy_dram_window);
@@ -174,7 +176,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
            store_tile(b_copy_lds_window, b_block_tile_tmp);
            iCounter--;
        }
        // tail
......
@@ -93,21 +93,24 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
    {
        constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
                                        MakeALdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_a;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
    {
        constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
                                        MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_b;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
        constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
        index_t smem_size = 0;
......
@@ -29,7 +29,6 @@ struct BlockGemmPipelineProblem
    static constexpr index_t AlignmentA = kPadA ? 16 / sizeof(ADataType) : 1;
    static constexpr index_t AlignmentB = kPadB ? 16 / sizeof(BDataType) : 1;
    static constexpr index_t AlignmentC = kPadC ? 16 / sizeof(CDataType) : 1;
};
} // namespace ck_tile } // namespace ck_tile
@@ -7,17 +7,21 @@
namespace ck_tile {

-template <typename BlockTile_,
-          typename BlockWarps_,
-          typename WarpTile_>
-struct TileGemmShape
+template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
+struct TileGemmShape {
+    static constexpr index_t kM = kMPerTile;
+    static constexpr index_t kN = kNPerTile;
+    static constexpr index_t kK = kKPerTile;
+};
+
+template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
+struct TileGemmShapeNewGemm
{
    using BlockTile = remove_cvref_t<BlockTile_>;
    using BlockWarps = remove_cvref_t<BlockWarps_>;
    using WarpTile = remove_cvref_t<WarpTile_>;

    static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});

    static constexpr index_t kM = BlockTile::at(number<0>{});
    static constexpr index_t kN = BlockTile::at(number<1>{});
......