Commit c2b7f8df authored by ThomasNing

Finished the Matrix Layout feature setup. Note: need to modify the inner block to solve the shuffle problem in the future.
parent 1b61d467
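
This change removes the runtime layoutA/layoutB/layoutC command-line options and instead threads a compile-time LayoutConfig bundle through GemmKernel as an additional template parameter, so each layout decision becomes an if constexpr branch instead of a runtime check. A minimal standalone sketch of the idea, using simplified stand-in types rather than the real ck_tile headers:

    enum class MatrixALayout { MK, KM };
    enum class MatrixBLayout { NK, KN };
    enum class MatrixCLayout { MN, NM };

    // Compile-time bundle of the three layouts, mirroring the LayoutConfig
    // struct added in this commit.
    template <MatrixALayout A, MatrixBLayout B, MatrixCLayout C>
    struct LayoutConfig {
        static constexpr MatrixALayout LayoutA = A;
        static constexpr MatrixBLayout LayoutB = B;
        static constexpr MatrixCLayout LayoutC = C;
    };

    // The kernel takes the whole bundle as one template parameter, so layout
    // decisions can be resolved at compile time inside the kernel.
    template <typename Layouts>
    struct GemmKernelSketch {
        static constexpr bool kRowMajorA = (Layouts::LayoutA == MatrixALayout::MK);
    };

    using Layouts = LayoutConfig<MatrixALayout::MK, MatrixBLayout::NK, MatrixCLayout::MN>;
    static_assert(GemmKernelSketch<Layouts>::kRowMajorA, "A defaults to row-major MK");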
@@ -20,9 +20,6 @@ auto create_args(int argc, char* argv[]) {
         .insert("m", "1024", "m dimension")
         .insert("n", "2048", "n dimension")
         .insert("k", "32", "k dimension")
-        .insert("layoutA", "MK", "matrix A layout")
-        .insert("layoutB", "NK", "matrix B layout")
-        .insert("layoutC", "MN", "matrix C layout")
         .insert("stride_a", "0", "stride applied to the m,k A block")
         .insert("stride_b", "0", "stride applied to the n,k B block")
         .insert("stride_c", "0", "stride applied to the m,n C block")
@@ -34,15 +31,16 @@ auto create_args(int argc, char* argv[]) {
         .insert("e", "1e-5", "epsilon")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("following_op", "no", "combined_op. bias/relu/gelu...")
-        .insert("warmup", "5", "number of iterations before benchmarking the kernel")
-        .insert("repeat", "20", "number of iterations to benchmark the kernel")
+        .insert("warmup", "10", "number of iterations before benchmarking the kernel")
+        .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
 }

-float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s) {
+template <typename Layouts>
+float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
     // ToDo: This will be modified by the codegen code later.
     constexpr ck_tile::index_t M_Tile = 128;
     constexpr ck_tile::index_t N_Tile = 128;
@@ -79,11 +77,11 @@ float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s) {
                            ODataType, kPadA, kPadB>>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
-    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>;

     auto kargs = Kernel::MakeKargs(
         args.p_x, args.p_y, args.p_z, args.batch_size, args.epsilon, args.M, args.N,
-        args.K, args.stride_A, args.stride_B, args.stride_C, args.layout_a
+        args.K, args.stride_A, args.stride_B, args.stride_C
     );

     const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_size);
@@ -95,11 +93,10 @@ float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s) {
     return ave_time;
 }

-template <typename DataType>
+template <typename DataType, typename Layouts>
 float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
                         ck_tile::DeviceMem& z_buf,
-                        const ck_tile::ArgParser& arg_parser,
-                        const ck_tile::MatrixALayout matrix_a_layout){
+                        const ck_tile::ArgParser& arg_parser){
     std::string data_type = arg_parser.get_str("prec");
@@ -128,30 +125,45 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
     args.M = M;
     args.N = N;
     args.K = K;
-    args.layout_a = matrix_a_layout;
-    // args.layout_b = layout_b;
-    // args.layout_c = layout_c;
-    // Only set stride_M and stride_N if they are non-zero and not equal to K
+    // Use the given strides when non-zero; otherwise derive the default
+    // leading dimension from the compile-time layout.
     if (stride_a != 0) {
         args.stride_A = stride_a;
     } else {
-        args.stride_A = K;
+        args.stride_A = [&](){
+            if constexpr (Layouts::LayoutA == ck_tile::MatrixALayout::KM) {
+                return M;
+            } else {
+                return K;
+            }
+        }();
     }

     if (stride_b != 0) {
         args.stride_B = stride_b;
     } else {
-        args.stride_B = K;
+        args.stride_B = [&](){
+            if constexpr (Layouts::LayoutB == ck_tile::MatrixBLayout::KN) {
+                return N;
+            } else {
+                return K;
+            }
+        }();
     }

     if (stride_c != 0) {
         args.stride_C = stride_c;
     } else {
-        args.stride_C = N;
+        args.stride_C = [&](){
+            if constexpr (Layouts::LayoutC == ck_tile::MatrixCLayout::NM) {
+                return M;
+            } else {
+                return N;
+            }
+        }();
     }

-    float ave_time = gemm_calc(args, ck_tile::stream_config{nullptr, true});
+    float ave_time = gemm_calc<Layouts>(args, ck_tile::stream_config{nullptr, true});

     std::size_t num_byte = sizeof(XDataType) * M * K + sizeof(YDataType) * N * K +
                            sizeof(ODataType) * M * N;
     float gb_per_sec = num_byte / 1.E6 / ave_time;
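
When no stride is given, the default leading dimension follows the layout: a row-major MK A matrix stores K contiguous elements per row, so stride_A = K, while a column-major KM A matrix stores M contiguous elements per column, so stride_A = M (analogously N/K for B and M/N for C). With the option defaults m=1024, n=2048, k=32, an MK-layout A gets stride_A = 32 and a KM-layout A would get stride_A = 1024. The immediately-invoked lambda with if constexpr keeps this selection compile-time; the same idiom in isolation, as a hypothetical standalone helper that is not part of the commit:

    #include <cstdio>

    enum class MatrixALayout { MK, KM };

    // Hypothetical helper (not part of the commit) showing the
    // immediately-invoked lambda + if constexpr idiom used above to pick
    // the default leading dimension at compile time.
    template <MatrixALayout LayoutA>
    int default_stride_a(int M, int K) {
        return [&]() {
            if constexpr (LayoutA == MatrixALayout::KM) {
                return M; // column-major A: a column holds M contiguous elements
            } else {
                return K; // row-major A (MK): a row holds K contiguous elements
            }
        }();
    }

    int main() {
        // With the example defaults M = 1024, K = 32:
        std::printf("MK: %d, KM: %d\n",
                    default_stride_a<MatrixALayout::MK>(1024, 32),   // 32
                    default_stride_a<MatrixALayout::KM>(1024, 32));  // 1024
    }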
@@ -170,19 +182,17 @@ int main(int argc, char* argv[]) {
     if(!result)
         return -1;

-    std::string layout_a = arg_parser.get_str("layoutA");
-    std::string layout_b = arg_parser.get_str("layoutB");
-    std::string layout_c = arg_parser.get_str("layoutC");
-    ck_tile::MatrixALayout matrix_a_layout = ck_tile::parse_layout_a(layout_a);
-    ck_tile::MatrixBLayout matrix_b_layout = ck_tile::parse_layout_b(layout_b);
-    ck_tile::MatrixCLayout matrix_c_layout = ck_tile::parse_layout_c(layout_c);
     bool grouped_enable = arg_parser.get_bool("grouped");
     std::string following_op_descrp = arg_parser.get_str("following_op");
     ck_tile::index_t M = arg_parser.get_int("m");
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");
+    constexpr ck_tile::MatrixALayout matrix_a_layout = ck_tile::MatrixALayout::MK;
+    constexpr ck_tile::MatrixBLayout matrix_b_layout = ck_tile::MatrixBLayout::NK;
+    constexpr ck_tile::MatrixCLayout matrix_c_layout = ck_tile::MatrixCLayout::MN;
+    using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>;

     // host verify
     std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK) ?
                                     std::vector<int>{M, K} : std::vector<int>{K, M};
@@ -197,8 +207,6 @@ int main(int argc, char* argv[]) {
     ck_tile::HostTensor<ODataType> z_host_ref(z_dimensions);
     ck_tile::HostTensor<ODataType> z_host_dev(z_dimensions);

-    // ck_tile::FillConstant<XDataType>{1.f}(x_host);
-    // ck_tile::FillConstant<YDataType>{1.f}(y_host);
     ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
     ck_tile::FillUniformDistribution<YDataType>{-5.f, 5.f}(y_host);
@@ -214,7 +222,7 @@ int main(int argc, char* argv[]) {
         return -1;
     }

-    OperatorExecution<ck_tile::half_t>(x_buf, y_buf, z_buf, arg_parser, matrix_a_layout);
+    OperatorExecution<ck_tile::half_t, Layouts>(x_buf, y_buf, z_buf, arg_parser);

     bool pass = true;
...

@@ -23,6 +23,14 @@ struct GemmBasicTypeConfig<ck_tile::half_t> {
     // ToDo: Add more bias config to support different categories of GEMM.
 };

+template<ck_tile::MatrixALayout A, ck_tile::MatrixBLayout B,
+         ck_tile::MatrixCLayout C>
+struct LayoutConfig {
+    static constexpr ck_tile::MatrixALayout LayoutA = A;
+    static constexpr ck_tile::MatrixBLayout LayoutB = B;
+    static constexpr ck_tile::MatrixCLayout LayoutC = C;
+};
+
 template<typename T>
 struct DataTypeTraits;
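
LayoutConfig carries no data members; it only re-exposes its non-type template arguments as named constants, so a single type parameter can carry all three layouts. A small compile-time usage check, mirroring the instantiation in main() (illustrative, not part of the commit):

    using RowMajorDefaults = LayoutConfig<ck_tile::MatrixALayout::MK,
                                          ck_tile::MatrixBLayout::NK,
                                          ck_tile::MatrixCLayout::MN>;
    static_assert(RowMajorDefaults::LayoutA == ck_tile::MatrixALayout::MK, "layout A is MK");
    static_assert(RowMajorDefaults::LayoutB == ck_tile::MatrixBLayout::NK, "layout B is NK");
    static_assert(RowMajorDefaults::LayoutC == ck_tile::MatrixCLayout::MN, "layout C is MN");

Because the members are constexpr, these checks cost nothing at runtime, and a future layout knob can be added to the bundle without widening GemmKernel's parameter list again.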
@@ -54,9 +62,6 @@ struct gemm_basic_args {
     const void* p_y;
     void* p_z;
     float epsilon;
-    ck_tile::MatrixALayout layout_a;
-    // std::string layout_b;
-    // std::string layout_c;
     ck_tile::index_t batch_size;
     ck_tile::index_t M;
     ck_tile::index_t N;
...

@@ -12,20 +12,20 @@
 namespace ck_tile {

-template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
+template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, typename Layouts_>
 struct GemmKernel {
-    using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
-    using GemmPipeline = ck_tile::remove_cvref_t<GemmPipeline_>;
-    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
-    static constexpr ck_tile::index_t kBlockSize = GemmPipeline::kBlockSize;
-    using ADataType = ck_tile::remove_cvref_t<typename GemmPipeline::ADataType>;
-    using BDataType = ck_tile::remove_cvref_t<typename GemmPipeline::BDataType>;
-    using CAccDataType = ck_tile::remove_cvref_t<typename GemmPipeline::CDataType>;
-    using CODataType = ck_tile::remove_cvref_t<typename EpiloguePipeline::ODataType>;
-    __host__ static constexpr auto GridSize(ck_tile::index_t M_size, ck_tile::index_t N_size,
-                                            ck_tile::index_t Batch_size) {
+    using TilePartitioner = remove_cvref_t<TilePartitioner_>;
+    using GemmPipeline = remove_cvref_t<GemmPipeline_>;
+    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
+    using Layouts = remove_cvref_t<Layouts_>;
+    static constexpr index_t kBlockSize = GemmPipeline::kBlockSize;
+    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
+    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
+    using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+    using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
+    __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) {
         auto x = TilePartitioner::GridSize(M_size, N_size, Batch_size);
         printf("GridDimX: %d, GridDimY: %d, GridDimZ: %d\n", x.x, x.y, x.z);
         return TilePartitioner::GridSize(M_size, N_size, Batch_size);
@@ -48,8 +48,6 @@ struct GemmKernel {
         ck_tile::index_t stride_A;
         ck_tile::index_t stride_B;
         ck_tile::index_t stride_C;
-        MatrixALayout layout_A;
     };

     CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
@@ -62,9 +60,8 @@ struct GemmKernel {
                                                             ck_tile::index_t K,
                                                             ck_tile::index_t stride_A,
                                                             ck_tile::index_t stride_B,
-                                                            ck_tile::index_t stride_C,
-                                                            MatrixALayout layout_A) {
-        return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, batch_size, M, N, K, stride_A, stride_B, stride_C, layout_A};
+                                                            ck_tile::index_t stride_C) {
+        return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, batch_size, M, N, K, stride_A, stride_B, stride_C};
     }

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() {
@@ -79,17 +76,33 @@ struct GemmKernel {
         const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
         const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);

         // Convert pointers to tensor views
-        auto a_tensor_view = (kargs.layout_A == MatrixALayout::MK) ?
-            make_naive_tensor_view<address_space_enum::global>(
-                a_start, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1), number<GemmPipeline::AlignmentA>{}, number<1>{}) :
-            make_naive_tensor_view<address_space_enum::global>(
-                a_start, make_tuple(kargs.K, kargs.M), make_tuple(1, kargs.stride_A), number<GemmPipeline::AlignmentA>{}, number<1>{});
-        auto b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
-            b_start, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1), number<GemmPipeline::AlignmentB>{}, number<1>{});
+        auto a_tensor_view = [&](){
+            if constexpr (Layouts::LayoutA == ck_tile::MatrixALayout::KM) {
+                return make_naive_tensor_view<address_space_enum::global>(
+                    a_start, make_tuple(kargs.M, kargs.K), make_tuple(1, kargs.stride_A),
+                    number<GemmPipeline::AlignmentA>{}, number<1>{});
+            } else {
+                return make_naive_tensor_view<address_space_enum::global>(
+                    a_start, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1),
+                    number<GemmPipeline::AlignmentA>{}, number<1>{});
+            }
+        }();
+        auto b_tensor_view = [&](){
+            if constexpr (Layouts::LayoutB == ck_tile::MatrixBLayout::KN) {
+                return make_naive_tensor_view<address_space_enum::global>(
+                    b_start, make_tuple(kargs.N, kargs.K), make_tuple(1, kargs.stride_B),
+                    number<GemmPipeline::AlignmentB>{}, number<1>{});
+            } else { // Default NK layout
+                return make_naive_tensor_view<address_space_enum::global>(
+                    b_start, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1),
+                    number<GemmPipeline::AlignmentB>{}, number<1>{});
+            }
+        }();

         auto ABlockWindow = make_tile_window(a_tensor_view, make_tuple(number<TilePartitioner::kM>{},
                                              number<TilePartitioner::kK>{}), {i_m, 0});
         auto BBlockWindow = make_tile_window(b_tensor_view, make_tuple(number<TilePartitioner::kN>{},
                                              number<TilePartitioner::kK>{}), {i_n, 0});
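
The only difference between the two branches of each lambda is the stride tuple: the logical dimensions stay the same, while a row-major view uses strides (stride, 1) and a column-major view uses (1, stride). In plain index arithmetic, offset(i, j) = i*s0 + j*s1; a hedged illustration of this in ordinary C++, not the ck_tile view machinery:

    #include <cassert>

    // Flat offset of element (i, j) in a 2-D view with strides (s0, s1).
    long offset(long i, long j, long s0, long s1) { return i * s0 + j * s1; }

    int main() {
        long M = 4, K = 8;
        // MK (row-major) view of A: dims (M, K), strides (K, 1)
        assert(offset(1, 2, K, 1) == 1 * K + 2); // element (1,2) -> 10
        // KM (column-major) view, same logical dims (M, K): strides (1, M)
        assert(offset(1, 2, 1, M) == 1 + 2 * M); // element (1,2) -> 9
        return 0;
    }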
@@ -105,9 +118,17 @@ struct GemmKernel {
         CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);

-        auto c_tensor_view = make_naive_tensor_view<address_space_enum::global>(
-            c_start, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1),
-            number<GemmPipeline::AlignmentC>{}, number<1>{});
+        auto c_tensor_view = [&](){
+            if constexpr (Layouts::LayoutC == ck_tile::MatrixCLayout::NM) {
+                return make_naive_tensor_view<address_space_enum::global>(
+                    c_start, make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_C),
+                    number<GemmPipeline::AlignmentC>{}, number<1>{});
+            } else {
+                return make_naive_tensor_view<address_space_enum::global>(
+                    c_start, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1),
+                    number<GemmPipeline::AlignmentC>{}, number<1>{});
+            }
+        }();

         auto CBlockWindow = make_tile_window(c_tensor_view, make_tuple(number<TilePartitioner::kM>{},
                                              number<TilePartitioner::kN>{}), {i_m, i_n});
...