Commit 04006d5f authored by ThomasNing's avatar ThomasNing
Browse files

Fix: Clang Format, API fixed from fmha

parent c2b7f8df
...@@ -12,80 +12,91 @@ ...@@ -12,80 +12,91 @@
#include <tuple> #include <tuple>
/* /*
create_args is a function create_args is a function
*/ */
auto create_args(int argc, char* argv[]) { auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size") arg_parser.insert("b", "1", "batch size")
.insert("m", "1024", "m dimension") .insert("m", "1024", "m dimension")
.insert("n", "2048", "n dimension") .insert("n", "2048", "n dimension")
.insert("k", "32", "k dimension") .insert("k", "32", "k dimension")
.insert("stride_a", "0", "stride on apply the m,k A block") .insert("stride_a", "0", "stride on apply the m,k A block")
.insert("stride_b", "0", "stride on apply the n,k B block") .insert("stride_b", "0", "stride on apply the n,k B block")
.insert("stride_c", "0", "stride on apply the m,n C block") .insert("stride_c", "0", "stride on apply the m,n C block")
.insert("grouped", "0", "bool condition on whether it is a grouped gemm") .insert("grouped", "0", "bool condition on whether it is a grouped gemm")
.insert("grouped_dimension_m", "0", "Fill in the desired dimension when enable grouped gemm") .insert(
.insert("grouped_dimension_n", "0", "Fill in the desired dimension when enable grouped gemm") "grouped_dimension_m", "0", "Fill in the desired dimension when enable grouped gemm")
.insert("grouped_dimension_k", "0", "Fill in the desired dimension when enable grouped gemm") .insert(
.insert("v", "1", "cpu validation or not") "grouped_dimension_n", "0", "Fill in the desired dimension when enable grouped gemm")
.insert("e", "1e-5", "epsilon") .insert(
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") "grouped_dimension_k", "0", "Fill in the desired dimension when enable grouped gemm")
.insert("following_op", "no", "combined_op. bias/relu/gelu...") .insert("v", "1", "cpu validation or not")
.insert("warmup", "10", "number of iterations before benchmark the kernel") .insert("e", "1e-5", "epsilon")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("following_op", "no", "combined_op. bias/relu/gelu...")
.insert("warmup", "10", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
template <typename Layouts> template <typename Layouts>
float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) { float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
{
// ToDo: This will be modified by the codegen code later. // ToDo: This will be modified by the codegen code later.
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadA = true;
constexpr bool kPadB = true; constexpr bool kPadB = true;
constexpr bool kPadC = false; constexpr bool kPadC = false;
constexpr ck_tile::index_t kBlockPerCu = 1; constexpr ck_tile::index_t kBlockPerCu = 1;
// =============================================== // ===============================================
using Shape = ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, using Shape = ck_tile::TileGemmShapeNewGemm<
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile> ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<Shape>; using TilePartitioner = ck_tile::GemmTilePartitioner<Shape>;
using PipelineProblem = ck_tile::BlockGemmPipelineProblem<XDataType, YDataType, AccDataType, Shape, using PipelineProblem = ck_tile::
kPadA, kPadB, kPadC>; BlockGemmPipelineProblem<XDataType, YDataType, AccDataType, Shape, kPadA, kPadB, kPadC>;
// The GemmPipeline should also come from the Codegen. // The GemmPipeline should also come from the Codegen.
using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<PipelineProblem>; using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
using GemmEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType, using GemmEpilogue = ck_tile::Default2DEpilogue<
ODataType, kPadA, kPadB>>; ck_tile::Default2DEpilogueProblem<AccDataType, ODataType, kPadA, kPadB>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>;
auto kargs = Kernel::MakeKargs( auto kargs = Kernel::MakeKargs(args.p_x,
args.p_x, args.p_y, args.p_z, args.batch_size, args.epsilon, args.M, args.N, args.p_y,
args.K, args.stride_A, args.stride_B, args.stride_C args.p_z,
); args.batch_size,
args.epsilon,
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_size); args.M,
constexpr dim3 blocks = Kernel::BlockSize(); args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_size);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel( float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs)); s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
...@@ -94,81 +105,103 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) { ...@@ -94,81 +105,103 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
} }
template <typename DataType, typename Layouts> template <typename DataType, typename Layouts>
float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf, float OperatorExecution(ck_tile::DeviceMem& x_buf,
ck_tile::DeviceMem& z_buf, ck_tile::DeviceMem& y_buf,
const ck_tile::ArgParser& arg_parser){ ck_tile::DeviceMem& z_buf,
const ck_tile::ArgParser& arg_parser)
std::string data_type = arg_parser.get_str("prec"); {
if (data_type != DataTypeTraits<DataType>::name) { std::string data_type = arg_parser.get_str("prec");
std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
if(data_type != DataTypeTraits<DataType>::name)
{
std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
<< data_type << std::endl; << data_type << std::endl;
return -1; // Or handle the error appropriately return -1; // Or handle the error appropriately
} }
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
ck_tile::index_t batch_size = arg_parser.get_int("b"); ck_tile::index_t batch_size = arg_parser.get_int("b");
ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k"); ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
gemm_basic_args args; gemm_basic_args args;
args.p_x = x_buf.GetDeviceBuffer(); args.p_x = x_buf.GetDeviceBuffer();
args.p_y = y_buf.GetDeviceBuffer(); args.p_y = y_buf.GetDeviceBuffer();
args.p_z = z_buf.GetDeviceBuffer(); args.p_z = z_buf.GetDeviceBuffer();
args.epsilon = epsilon; args.epsilon = epsilon;
args.batch_size = batch_size; args.batch_size = batch_size;
args.M = M; args.M = M;
args.N = N; args.N = N;
args.K = K; args.K = K;
// Only set stride_M and stride_N if they are non-zero and not equal to K. // Only set stride_M and stride_N if they are non-zero and not equal to K.
if (stride_a != 0) { if(stride_a != 0)
{
args.stride_A = stride_a; args.stride_A = stride_a;
} else { }
args.stride_A = [&](){ else
if constexpr (Layouts::LayoutA == ck_tile::MatrixALayout::KM) { {
args.stride_A = [&]() {
if constexpr(Layouts::LayoutA == ck_tile::MatrixALayout::KM)
{
return M; return M;
} else { }
else
{
return K; return K;
} }
}(); }();
} }
if (stride_b != 0) { if(stride_b != 0)
{
args.stride_B = stride_b; args.stride_B = stride_b;
} else { }
args.stride_B = [&](){ else
if constexpr (Layouts::LayoutB == ck_tile::MatrixBLayout::KN) { {
args.stride_B = [&]() {
if constexpr(Layouts::LayoutB == ck_tile::MatrixBLayout::KN)
{
return N; return N;
} else { }
else
{
return K; return K;
} }
}(); }();
} }
if(stride_c != 0) { if(stride_c != 0)
{
args.stride_C = stride_c; args.stride_C = stride_c;
} else { }
args.stride_C = [&](){ else
if constexpr (Layouts::LayoutC == ck_tile::MatrixCLayout::NM) { {
args.stride_C = [&]() {
if constexpr(Layouts::LayoutC == ck_tile::MatrixCLayout::NM)
{
return M; return M;
} else { }
else
{
return N; return N;
} }
}(); }();
} }
float ave_time = gemm_calc<Layouts>(args, ck_tile::stream_config{nullptr, true}); float ave_time = gemm_calc<Layouts>(args, ck_tile::stream_config{nullptr, true});
std::size_t num_byte = sizeof(XDataType) * M * K + sizeof(YDataType) * N * K+ std::size_t num_byte =
sizeof(ODataType) * M * N; sizeof(XDataType) * M * K + sizeof(YDataType) * N * K + sizeof(ODataType) * M * N;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "The overall perfomance of the GEMM with " << "[" << data_type << "]" std::cout << "The overall perfomance of the GEMM with "
<< "[" << data_type << "]"
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
<< "is: \n"; << "is: \n";
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n" std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
...@@ -177,16 +210,17 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf, ...@@ -177,16 +210,17 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
return ave_time; return ave_time;
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
if(!result) if(!result)
return -1; return -1;
bool grouped_enable = arg_parser.get_bool("grouped"); bool grouped_enable = arg_parser.get_bool("grouped");
std::string following_op_descrp = arg_parser.get_str("following_op"); std::string following_op_descrp = arg_parser.get_str("following_op");
ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k"); ck_tile::index_t K = arg_parser.get_int("k");
constexpr ck_tile::MatrixALayout matrix_a_layout = ck_tile::MatrixALayout::MK; constexpr ck_tile::MatrixALayout matrix_a_layout = ck_tile::MatrixALayout::MK;
constexpr ck_tile::MatrixBLayout matrix_b_layout = ck_tile::MatrixBLayout::NK; constexpr ck_tile::MatrixBLayout matrix_b_layout = ck_tile::MatrixBLayout::NK;
...@@ -194,12 +228,15 @@ int main(int argc, char* argv[]) { ...@@ -194,12 +228,15 @@ int main(int argc, char* argv[]) {
using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>; using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>;
// host verify // host verify
std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK) ? std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK)
std::vector<int>{M, K} : std::vector<int>{K, M}; ? std::vector<int>{M, K}
std::vector<int> y_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK) ? : std::vector<int>{K, M};
std::vector<int>{N, K} : std::vector<int>{K, N}; std::vector<int> y_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK)
std::vector<int> z_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN) ? ? std::vector<int>{N, K}
std::vector<int>{M, N} : std::vector<int>{N, M}; : std::vector<int>{K, N};
std::vector<int> z_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN)
? std::vector<int>{M, N}
: std::vector<int>{N, M};
ck_tile::HostTensor<XDataType> x_host(x_dimensions); ck_tile::HostTensor<XDataType> x_host(x_dimensions);
ck_tile::HostTensor<YDataType> y_host(y_dimensions); ck_tile::HostTensor<YDataType> y_host(y_dimensions);
...@@ -209,7 +246,7 @@ int main(int argc, char* argv[]) { ...@@ -209,7 +246,7 @@ int main(int argc, char* argv[]) {
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
ck_tile::FillUniformDistribution<YDataType>{-5.f, 5.f}(y_host); ck_tile::FillUniformDistribution<YDataType>{-5.f, 5.f}(y_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem z_buf(z_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem z_buf(z_host_dev.get_element_space_size_in_bytes());
...@@ -217,16 +254,18 @@ int main(int argc, char* argv[]) { ...@@ -217,16 +254,18 @@ int main(int argc, char* argv[]) {
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
y_buf.ToDevice(y_host.data()); y_buf.ToDevice(y_host.data());
if(grouped_enable || following_op_descrp != "no") { if(grouped_enable || following_op_descrp != "no")
{
std::cerr << "Other category of the GEMM is unsupported for now!" << std::endl; std::cerr << "Other category of the GEMM is unsupported for now!" << std::endl;
return -1; return -1;
} }
OperatorExecution<ck_tile::half_t, Layouts>(x_buf, y_buf, z_buf, arg_parser); OperatorExecution<ck_tile::half_t, Layouts>(x_buf, y_buf, z_buf, arg_parser);
bool pass = true; bool pass = true;
if(arg_parser.get_bool("v")) { if(arg_parser.get_bool("v"))
{
// ToDo: Will Add the Element Op (bias) verification in the future. // ToDo: Will Add the Element Op (bias) verification in the future.
ck_tile::reference_gemm<XDataType, YDataType, AccDataType, ODataType>( ck_tile::reference_gemm<XDataType, YDataType, AccDataType, ODataType>(
x_host, y_host, z_host_ref, matrix_a_layout); x_host, y_host, z_host_ref, matrix_a_layout);
...@@ -239,7 +278,6 @@ int main(int argc, char* argv[]) { ...@@ -239,7 +278,6 @@ int main(int argc, char* argv[]) {
} }
std::cout << std::endl << std::flush; std::cout << std::endl << std::flush;
return !pass; return !pass;
} }
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
...@@ -15,49 +14,54 @@ template <typename DataType> ...@@ -15,49 +14,54 @@ template <typename DataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
template <> template <>
struct GemmBasicTypeConfig<ck_tile::half_t> { struct GemmBasicTypeConfig<ck_tile::half_t>
using XDataType = ck_tile::half_t; {
using YDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using AccDataType = float; using YDataType = ck_tile::half_t;
using ODataType = ck_tile::half_t; //type convert using AccDataType = float;
using ODataType = ck_tile::half_t; // type convert
// ToDo: Add more bias config to support different categories of GEMM. // ToDo: Add more bias config to support different categories of GEMM.
}; };
template<ck_tile::MatrixALayout A, ck_tile::MatrixBLayout B, template <ck_tile::MatrixALayout A, ck_tile::MatrixBLayout B, ck_tile::MatrixCLayout C>
ck_tile::MatrixCLayout C> struct LayoutConfig
struct LayoutConfig { {
static constexpr ck_tile::MatrixALayout LayoutA = A; static constexpr ck_tile::MatrixALayout LayoutA = A;
static constexpr ck_tile::MatrixBLayout LayoutB = B; static constexpr ck_tile::MatrixBLayout LayoutB = B;
static constexpr ck_tile::MatrixCLayout LayoutC = C; static constexpr ck_tile::MatrixCLayout LayoutC = C;
}; };
template<typename T> template <typename T>
struct DataTypeTraits; struct DataTypeTraits;
template<> template <>
struct DataTypeTraits<float> { struct DataTypeTraits<float>
{
static constexpr const char* name = "float"; static constexpr const char* name = "float";
}; };
template<> template <>
struct DataTypeTraits<double> { struct DataTypeTraits<double>
{
static constexpr const char* name = "double"; static constexpr const char* name = "double";
}; };
template<> template <>
struct DataTypeTraits<ck_tile::half_t> { struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16"; static constexpr const char* name = "fp16";
}; };
using Types = GemmBasicTypeConfig<ck_tile::half_t>; using Types = GemmBasicTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access // Specific type aliases for easy access
using XDataType = Types::XDataType; using XDataType = Types::XDataType;
using YDataType = Types::YDataType; using YDataType = Types::YDataType;
using AccDataType = Types::AccDataType; using AccDataType = Types::AccDataType;
using ODataType = Types::ODataType; using ODataType = Types::ODataType;
struct gemm_basic_args { struct gemm_basic_args
{
const void* p_x; const void* p_x;
const void* p_y; const void* p_y;
void* p_z; void* p_z;
......
...@@ -67,7 +67,7 @@ check_err(const Range& out, ...@@ -67,7 +67,7 @@ check_err(const Range& out,
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = std::numeric_limits<double>::min(); double max_err = std::numeric_limits<double>::min();
for(std::size_t i = 4190; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = *std::next(std::begin(out), i); const double o = *std::next(std::begin(out), i);
const double r = *std::next(std::begin(ref), i); const double r = *std::next(std::begin(ref), i);
...@@ -127,7 +127,7 @@ check_err(const Range& out, ...@@ -127,7 +127,7 @@ check_err(const Range& out,
double err = 0; double err = 0;
// TODO: This is a hack. We should have proper specialization for bf16_t data type. // TODO: This is a hack. We should have proper specialization for bf16_t data type.
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 4190; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = type_convert<float>(*std::next(std::begin(out), i)); const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i)); const double r = type_convert<float>(*std::next(std::begin(ref), i));
...@@ -186,7 +186,7 @@ check_err(const Range& out, ...@@ -186,7 +186,7 @@ check_err(const Range& out,
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min()); double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
for(std::size_t i = 4190; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = type_convert<float>(*std::next(std::begin(out), i)); const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i)); const double r = type_convert<float>(*std::next(std::begin(ref), i));
...@@ -314,7 +314,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -314,7 +314,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 4190; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const fp8_t o_fp8 = *std::next(std::begin(out), i); const fp8_t o_fp8 = *std::next(std::begin(out), i);
const fp8_t r_fp8 = *std::next(std::begin(ref), i); const fp8_t r_fp8 = *std::next(std::begin(ref), i);
...@@ -372,7 +372,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -372,7 +372,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 4190; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = type_convert<float>(*std::next(std::begin(out), i)); const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i)); const double r = type_convert<float>(*std::next(std::begin(ref), i));
......
...@@ -1144,17 +1144,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1144,17 +1144,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
BlockGemmPipelineProblem<typename Problem::QDataType, typename Problem::KDataType, typename Problem::AccDataType,
typename Problem::KDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
typename Problem::AccDataType, Problem::BlockFmhaShape::BlockTile::kN0,
TileGemmShape<sequence< Problem::BlockFmhaShape::BlockTile::kK0>>;
Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0
>,
Problem::BlockFmhaShape::Gemm0BlockWarps_,
Problem::BlockFmhaShape::Gemm0WarpTile_>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -1184,18 +1178,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1184,18 +1178,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::OGradDataType, typename Problem::AccDataType, Problem::kBlockSize,
typename Problem::OGradDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
typename Problem::AccDataType, Problem::BlockFmhaShape::BlockTile::kVHeaddim,
Problem::kBlockSize, Problem::BlockFmhaShape::BlockTile::kK1>>;
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kVHeaddim,
Problem::BlockFmhaShape::BlockTile::kK1
>,
Problem::BlockFmhaShape::Gemm1BlockWarps_,
Problem::BlockFmhaShape::Gemm1WarpTile_>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -1217,18 +1204,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1217,18 +1204,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::OGradDataType,
BlockGemmPipelineProblem<typename Problem::OGradDataType, typename Problem::VDataType, typename Problem::AccDataType, Problem::kBlockSize,
typename Problem::VDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
typename Problem::AccDataType, Problem::BlockFmhaShape::BlockTile::kN0,
Problem::kBlockSize, Problem::BlockFmhaShape::BlockTile::kK2>>;
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK2
>,
Problem::BlockFmhaShape::Gemm2BlockWarps_,
Problem::BlockFmhaShape::Gemm2WarpTile_>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> && if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
...@@ -1295,18 +1275,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1295,18 +1275,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::QDataType, typename Problem::AccDataType, Problem::kBlockSize,
typename Problem::QDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
typename Problem::AccDataType, Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::kBlockSize, Problem::BlockFmhaShape::BlockTile::kK3>>;
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::BlockFmhaShape::BlockTile::kK3
>,
Problem::BlockFmhaShape::Gemm3BlockWarps_,
Problem::BlockFmhaShape::Gemm3WarpTile_>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -1328,18 +1301,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1328,18 +1301,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::GemmDataType,
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::KDataType, typename Problem::AccDataType, Problem::kBlockSize,
typename Problem::KDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
typename Problem::AccDataType, Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::kBlockSize, Problem::BlockFmhaShape::BlockTile::kK4>>;
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::BlockFmhaShape::BlockTile::kK4
>,
Problem::BlockFmhaShape::Gemm4BlockWarps_,
Problem::BlockFmhaShape::Gemm4WarpTile_>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......
...@@ -75,18 +75,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -75,18 +75,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
BlockGemmPipelineProblem<typename Problem::QDataType, typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
typename Problem::KDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
typename Problem::SaccDataType, Problem::BlockFmhaShape::BlockTile::kN0,
Problem::kBlockSize, Problem::BlockFmhaShape::BlockTile::kK0>>;
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0
>,
Problem::BlockFmhaShape::Gemm0BlockWarps_,
Problem::BlockFmhaShape::Gemm0WarpTile_>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -202,18 +195,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -202,18 +195,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
BlockGemmPipelineProblem<typename Problem::QDataType, typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
typename Problem::KDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
typename Problem::SaccDataType, Problem::BlockFmhaShape::BlockTile::kN0,
Problem::kBlockSize, Problem::BlockFmhaShape::BlockTile::kK0>>;
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0
>,
Problem::BlockFmhaShape::Gemm0BlockWarps_,
Problem::BlockFmhaShape::Gemm0WarpTile_>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -500,7 +486,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -500,7 +486,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem, index_t IBuf = 0> template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{}) MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
...@@ -555,7 +541,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -555,7 +541,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM #if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template <typename Problem, index_t IBuf = 0> template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{}) MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
...@@ -950,18 +936,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -950,18 +936,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::PDataType,
BlockGemmPipelineProblem<typename Problem::PDataType, typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize,
typename Problem::VDataType, TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
typename Problem::OaccDataType, Problem::BlockFmhaShape::BlockTile::kN1,
Problem::kBlockSize, Problem::BlockFmhaShape::BlockTile::kK1>>;
TileGemmShape<sequence<
Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN1,
Problem::BlockFmhaShape::BlockTile::kK1
>,
Problem::BlockFmhaShape::Gemm1BlockWarps_,
Problem::BlockFmhaShape::Gemm1WarpTile_>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> && if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
...@@ -48,7 +48,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy ...@@ -48,7 +48,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
std::is_same_v<typename Problem::CDataType, float>) std::is_same_v<typename Problem::CDataType, float>)
{ {
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
} else { }
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution."); static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
} }
} }
......
...@@ -12,29 +12,34 @@ ...@@ -12,29 +12,34 @@
namespace ck_tile { namespace ck_tile {
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, typename Layouts_> template <typename TilePartitioner_,
struct GemmKernel { typename GemmPipeline_,
using TilePartitioner = remove_cvref_t<TilePartitioner_>; typename EpiloguePipeline_,
using GemmPipeline = remove_cvref_t<GemmPipeline_>; typename Layouts_>
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>; struct GemmKernel
using Layouts = remove_cvref_t<Layouts_>; {
static constexpr index_t kBlockSize = GemmPipeline::kBlockSize; using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>; using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>; using Layouts = remove_cvref_t<Layouts_>;
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>; static constexpr index_t kBlockSize = GemmPipeline::kBlockSize;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) { using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
{
auto x = TilePartitioner::GridSize(M_size, N_size, Batch_size); auto x = TilePartitioner::GridSize(M_size, N_size, Batch_size);
printf("GridDimX: %d, GridDimY: %d, %d", x.x, x.y, x.z); printf("GridDimX: %d, GridDimY: %d, %d", x.x, x.y, x.z);
return TilePartitioner::GridSize(M_size, N_size, Batch_size); return TilePartitioner::GridSize(M_size, N_size, Batch_size);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
struct GemmCommonKargs { struct GemmCommonKargs
{
const void* a_ptr; const void* a_ptr;
const void* b_ptr; const void* b_ptr;
void* c_ptr; void* c_ptr;
...@@ -60,15 +65,19 @@ struct GemmKernel { ...@@ -60,15 +65,19 @@ struct GemmKernel {
ck_tile::index_t K, ck_tile::index_t K,
ck_tile::index_t stride_A, ck_tile::index_t stride_A,
ck_tile::index_t stride_B, ck_tile::index_t stride_B,
ck_tile::index_t stride_C) { ck_tile::index_t stride_C)
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, batch_size, M, N, K, stride_A, stride_B, stride_C}; {
return GemmCommonKargs{
a_ptr, b_ptr, c_ptr, epsilon, batch_size, M, N, K, stride_A, stride_B, stride_C};
} }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
} }
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const { CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{
const auto [i_tile_m, i_tile_n, i_batch] = TilePartitioner{}(); const auto [i_tile_m, i_tile_n, i_batch] = TilePartitioner{}();
const index_t i_m = __builtin_amdgcn_readfirstlane(i_tile_m * TilePartitioner::kM); const index_t i_m = __builtin_amdgcn_readfirstlane(i_tile_m * TilePartitioner::kM);
const index_t i_n = __builtin_amdgcn_readfirstlane(i_tile_n * TilePartitioner::kN); const index_t i_n = __builtin_amdgcn_readfirstlane(i_tile_n * TilePartitioner::kN);
...@@ -76,66 +85,96 @@ struct GemmKernel { ...@@ -76,66 +85,96 @@ struct GemmKernel {
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr); const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr); const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views // Convert pointers to tensor views
auto a_tensor_view = [&](){ auto a_tensor_view = [&]() {
if constexpr (Layouts::LayoutA == ck_tile::MatrixALayout::KM) { if constexpr(Layouts::LayoutA == ck_tile::MatrixALayout::KM)
{
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_start, make_tuple(kargs.M, kargs.K), make_tuple(1, kargs.stride_A), a_start,
number<GemmPipeline::AlignmentA>{}, number<1>{}); make_tuple(kargs.M, kargs.K),
} else { make_tuple(1, kargs.stride_A),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_start, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1), a_start,
number<GemmPipeline::AlignmentA>{}, number<1>{}); make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::AlignmentA>{},
number<1>{});
} }
}(); }();
auto b_tensor_view = [&](){ auto b_tensor_view = [&]() {
if constexpr (Layouts::LayoutB == ck_tile::MatrixBLayout::KN) { if constexpr(Layouts::LayoutB == ck_tile::MatrixBLayout::KN)
{
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_start, make_tuple(kargs.N, kargs.K), make_tuple(1, kargs.stride_B), b_start,
number<GemmPipeline::AlignmentB>{}, number<1>{}); make_tuple(kargs.N, kargs.K),
} else { // Default NK layout make_tuple(1, kargs.stride_B),
number<GemmPipeline::AlignmentB>{},
number<1>{});
}
else
{ // Default NK layout
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_start, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1), b_start,
number<GemmPipeline::AlignmentB>{}, number<1>{}); make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::AlignmentB>{},
number<1>{});
} }
}(); }();
auto ABlockWindow = make_tile_window(a_tensor_view, make_tuple(number<TilePartitioner::kM>{},
number<TilePartitioner::kK>{}), {i_m, 0});
auto ABlockWindow = make_tile_window(
auto BBlockWindow = make_tile_window(b_tensor_view, make_tuple(number<TilePartitioner::kN>{}, a_tensor_view,
number<TilePartitioner::kK>{}), {i_n, 0}); make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto BBlockWindow = make_tile_window(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK; const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
auto acc = BlockGemmPipelineAGmemBGmemCRegV1<GemmPipeline>{}(
auto acc = BlockGemmPipelineAGmemBGmemCRegV1<GemmPipeline>{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr); ABlockWindow, BBlockWindow, num_loop, smem_ptr);
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr); CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
auto c_tensor_view = [&](){ auto c_tensor_view = [&]() {
if constexpr (Layouts::LayoutC == ck_tile::MatrixCLayout::NM){ if constexpr(Layouts::LayoutC == ck_tile::MatrixCLayout::NM)
{
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
c_start, make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_C), c_start,
number<GemmPipeline::AlignmentC>{}, number<1>{}); make_tuple(kargs.M, kargs.N),
} else { make_tuple(1, kargs.stride_C),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
c_start, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), c_start,
number<GemmPipeline::AlignmentC>{}, number<1>{}); make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::AlignmentC>{},
number<1>{});
} }
}(); }();
auto CBlockWindow = make_tile_window(c_tensor_view, make_tuple(number<TilePartitioner::kM>{}, auto CBlockWindow = make_tile_window(
number<TilePartitioner::kN>{}), {i_m, i_n}); c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
// epilogue. // epilogue.
EpiloguePipeline{}(CBlockWindow, acc); EpiloguePipeline{}(CBlockWindow, acc);
} }
}; };
} } // namespace ck_tile
...@@ -4,36 +4,21 @@ ...@@ -4,36 +4,21 @@
#pragma once #pragma once
namespace ck_tile { namespace ck_tile {
enum struct MatrixALayout { enum struct MatrixALayout
MK, // Row-major layout for matrix A (default) {
KM // Column-major layout for matrix A MK, // Row-major layout for matrix A (default)
}; KM // Column-major layout for matrix A
};
enum struct MatrixBLayout { enum struct MatrixBLayout
NK, // Row-major layout for matrix B (default) {
KN // Column-major layout for matrix B NK, // Row-major layout for matrix B (default)
}; KN // Column-major layout for matrix B
};
enum struct MatrixCLayout { enum struct MatrixCLayout
MN, // Row-major layout for matrix C (default) {
NM // Column-major layout for matrix C MN, // Row-major layout for matrix C (default)
}; NM // Column-major layout for matrix C
};
// Function to convert string to MatrixALayout } // namespace ck_tile
inline MatrixALayout parse_layout_a(const std::string& layout) {
if (layout == "KM") return MatrixALayout::KM;
return MatrixALayout::MK; // Default to MK if not specified as KM
}
// Function to convert string to MatrixBLayout
inline MatrixBLayout parse_layout_b(const std::string& layout) {
if (layout == "KN") return MatrixBLayout::KN;
return MatrixBLayout::NK; // Default to NK if not specified as KN
}
// Function to convert string to MatrixBLayout
inline MatrixCLayout parse_layout_c(const std::string& layout) {
if (layout == "NM") return MatrixCLayout::NM;
return MatrixCLayout::MN; // Default to MN if not specified as NM
}
} // namespace ck_tile
...@@ -6,27 +6,30 @@ ...@@ -6,27 +6,30 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
namespace ck_tile { namespace ck_tile {
template <typename BlockGemmShape_> template <typename BlockGemmShape_>
struct GemmTilePartitioner { struct GemmTilePartitioner
using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>; {
using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;
static constexpr ck_tile::index_t kM = BlockGemmShape::kM; static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
static constexpr ck_tile::index_t kN = BlockGemmShape::kN; static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
static constexpr ck_tile::index_t kK = BlockGemmShape::kK; static constexpr ck_tile::index_t kK = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M, ck_tile::index_t N, CK_TILE_HOST static constexpr auto
ck_tile::index_t batch_size) { GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
ck_tile::index_t GridDimX = (M + kM - 1) / kM; {
ck_tile::index_t GridDimY = (N + kN - 1) / kN; ck_tile::index_t GridDimX = (M + kM - 1) / kM;
ck_tile::index_t GridDimZ = batch_size; ck_tile::index_t GridDimY = (N + kN - 1) / kN;
return dim3(GridDimX, GridDimY, GridDimZ); ck_tile::index_t GridDimZ = batch_size;
} return dim3(GridDimX, GridDimY, GridDimZ);
}
CK_TILE_DEVICE auto operator()() { CK_TILE_DEVICE auto operator()()
const index_t i_GridDimX = blockIdx.x; {
const index_t i_GridDimY = blockIdx.y; const index_t i_GridDimX = blockIdx.x;
const index_t i_GridDimZ = blockIdx.z; const index_t i_GridDimY = blockIdx.y;
return ck_tile::make_tuple(i_GridDimX, i_GridDimY, i_GridDimZ); const index_t i_GridDimZ = blockIdx.z;
} return ck_tile::make_tuple(i_GridDimX, i_GridDimY, i_GridDimZ);
}; }
};
} // namespace ck_tile } // namespace ck_tile
...@@ -40,7 +40,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -40,7 +40,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
} }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
...@@ -149,7 +150,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -149,7 +150,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
} }
index_t iCounter = num_loop - 1; index_t iCounter = num_loop - 1;
while(iCounter > 0) { while(iCounter > 0)
{
// global read i + 1 // global read i + 1
a_block_tile = load_tile(a_copy_dram_window); a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window); b_block_tile = load_tile(b_copy_dram_window);
...@@ -174,7 +176,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -174,7 +176,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
store_tile(b_copy_lds_window, b_block_tile_tmp); store_tile(b_copy_lds_window, b_block_tile_tmp);
iCounter--; iCounter--;
} }
// tail // tail
......
...@@ -91,26 +91,29 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -91,26 +91,29 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
return b_lds_block_desc; return b_lds_block_desc;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() { CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(); MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a; return smem_size_a;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB() { CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_b; return smem_size_b;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
constexpr index_t smem_size_a = GetSmemSizeA<Problem>(); {
constexpr index_t smem_size_b = GetSmemSizeB<Problem>(); constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
index_t smem_size = 0; constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
index_t smem_size = 0;
smem_size += smem_size_a + smem_size_b; smem_size += smem_size_a + smem_size_b;
return smem_size; return smem_size;
......
...@@ -22,14 +22,13 @@ struct BlockGemmPipelineProblem ...@@ -22,14 +22,13 @@ struct BlockGemmPipelineProblem
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * 64; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * 64;
static constexpr bool kPadA = kPadA_; static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_; static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_; static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? 16 / sizeof(ADataType) : 1; static constexpr index_t AlignmentA = kPadA ? 16 / sizeof(ADataType) : 1;
static constexpr index_t AlignmentB = kPadB ? 16 / sizeof(BDataType) : 1; static constexpr index_t AlignmentB = kPadB ? 16 / sizeof(BDataType) : 1;
static constexpr index_t AlignmentC = kPadC ? 16 / sizeof(CDataType) : 1; static constexpr index_t AlignmentC = kPadC ? 16 / sizeof(CDataType) : 1;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -7,17 +7,21 @@ ...@@ -7,17 +7,21 @@
namespace ck_tile { namespace ck_tile {
template <typename BlockTile_, template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
typename BlockWarps_, struct TileGemmShape {
typename WarpTile_> static constexpr index_t kM = kMPerTile;
struct TileGemmShape static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
};
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
struct TileGemmShapeNewGemm
{ {
using BlockTile = remove_cvref_t<BlockTile_>; using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>; using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>; using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM = BlockTile::at(number<0>{}); static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{}); static constexpr index_t kN = BlockTile::at(number<1>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment