Commit b4f65acf authored by Aleksander Dudek

[CK TILE] Refactor GemmKernel - naming changes, add Problem struct

parent f79f727c
@@ -15,7 +15,7 @@
#include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
@@ -79,17 +79,17 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
auto kargs = Kernel::MakeKernelArgs(args.a_ptr,
args.b_ptr,
args.c_ptr,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
......
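Since the old and new lines are interleaved above without +/- markers, here is the post-refactor body of gemm_calc consolidated into one sketch. The error-handling branch and the launch call are assumptions: they sit in the truncated part of this hunk, and the launch-helper signature follows the pattern used in other ck_tile examples.

    using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

    auto kargs = Kernel::MakeKernelArgs(args.a_ptr, args.b_ptr, args.c_ptr,
                                        args.M, args.N, args.K,
                                        args.stride_A, args.stride_B, args.stride_C);

    const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
    constexpr dim3 blocks = Kernel::BlockSize();

    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("Unsupported GEMM arguments!"); // assumed error handling
    }

    // Assumed launch pattern (matches other ck_tile examples; not shown in this hunk).
    float ave_time = ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
    return ave_time;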
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template <typename DataType>
struct GemmBasicTypeConfig;
@@ -51,20 +52,6 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
struct gemm_basic_args
{
const void* p_a;
const void* p_b;
void* p_c;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -89,4 +76,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
float gemm_calc(ck_tile::GemmHostArgs args, const ck_tile::stream_config& s);
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
gemm_basic_args args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
args.kbatch = kbatch;
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
......
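As a sketch, the same setup can use the new GemmHostArgs constructor from gemm_problem.hpp (shown later in this commit) instead of member-by-member assignment; the stride variable names below are assumptions, since that part of invoke_gemm is truncated here.

    ck_tile::GemmHostArgs args(a_m_k_dev_buf.GetDeviceBuffer(), // const void* a_ptr_
                               b_k_n_dev_buf.GetDeviceBuffer(), // const void* b_ptr_
                               c_m_n_dev_buf.GetDeviceBuffer(), // void* c_ptr_
                               kbatch,                          // index_t k_batch_
                               M, N, K,
                               stride_A, stride_B, stride_C);   // assumed variable names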
@@ -16,7 +16,7 @@
#include "batched_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
float batched_gemm(const BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
@@ -79,9 +79,21 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config&
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
auto kargs = Kernel::MakeKernelArgs(args.a_ptr,
args.b_ptr,
args.c_ptr,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C,
args.batch_stride_A,
args.batch_stride_B,
args.batch_stride_C,
args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
......
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template <typename DataType>
struct BatchedGemmTypeConfig;
@@ -29,8 +30,36 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(
a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_C(batch_stride_C_),
batch_count(batch_count_)
{
}
ck_tile::index_t batch_stride_A;
ck_tile::index_t batch_stride_B;
ck_tile::index_t batch_stride_C;
ck_tile::index_t batch_count;
};
auto create_args(int argc, char* argv[])
@@ -60,4 +89,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
float batched_gemm(BatchedGemmHostArgs args, const ck_tile::stream_config& s);
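A minimal usage sketch of the new BatchedGemmHostArgs constructor together with the host API declared above; the device-buffer and stride variable names, and the k_batch value of 1, are illustrative assumptions.

    BatchedGemmHostArgs args(a_dev_buf.GetDeviceBuffer(),
                             b_dev_buf.GetDeviceBuffer(),
                             c_dev_buf.GetDeviceBuffer(),
                             /*k_batch=*/1,
                             M, N, K,
                             stride_A, stride_B, stride_C,
                             batch_stride_A, batch_stride_B, batch_stride_C,
                             batch_count);

    float ave_time = batched_gemm<ALayout, BLayout, CLayout>(args, ck_tile::stream_config{});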
@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
batched_gemm_kargs args;
BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
......
@@ -7,20 +7,12 @@
namespace ck_tile {
struct BatchedGemmHargs : GemmHargs
{
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
index_t batch_count;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKargs = typename Base::GemmKargs;
using GemmKernelArgs = typename Base::GemmKernelArgs;
using ADataType = typename Base::ADataType;
using BDataType = typename Base::BDataType;
@@ -33,7 +25,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;
struct BatchedGemmKargs : GemmKargs
struct BatchedGemmKernelArgs : GemmKernelArgs
{
index_t batch_stride_A;
index_t batch_stride_B;
@@ -41,33 +33,34 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
index_t batch_count;
};
using Kargs = BatchedGemmKargs;
using Hargs = BatchedGemmHargs;
using KernelArgs = BatchedGemmKernelArgs;
__host__ static constexpr auto GridSize(const Hargs& k)
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
{
return TilePartitioner::GridSize(k.M, k.N, k.batch_count);
return TilePartitioner::GridSize(M, N, batch_count);
}
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h)
CK_TILE_HOST static constexpr BatchedGemmKernelArgs MakeKernelArgs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C,
index_t batch_stride_A,
index_t batch_stride_B,
index_t batch_stride_C,
index_t batch_count)
{
Kargs k;
k.a_ptr = h.a_ptr;
k.b_ptr = h.b_ptr;
k.c_ptr = h.c_ptr;
k.M = h.M;
k.N = h.N;
k.K = h.K;
k.stride_A = h.stride_A;
k.stride_B = h.stride_B;
k.stride_C = h.stride_C;
k.batch_stride_A = h.batch_stride_A;
k.batch_stride_B = h.batch_stride_B;
k.batch_stride_C = h.batch_stride_C;
k.batch_count = h.batch_count;
return k;
return BatchedGemmKernelArgs{{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C},
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -75,7 +68,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
@@ -83,17 +76,17 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
this->RunGemm(a_start, b_start, c_start, kargs, i_m, i_n);
this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
}
};
......
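For completeness, a consolidated host-side sketch of the refactored BatchedGemmKernel interface. The template-argument names are placeholders, and the assumption that TilePartitioner maps batch_count onto grids.z is inferred from the blockIdx.z read in operator() above.

    using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

    auto kargs = Kernel::MakeKernelArgs(a_ptr, b_ptr, c_ptr,
                                        M, N, K,
                                        stride_A, stride_B, stride_C,
                                        batch_stride_A, batch_stride_B, batch_stride_C,
                                        batch_count);

    const dim3 grids      = Kernel::GridSize(M, N, batch_count); // batch presumably in grids.z
    constexpr dim3 blocks = Kernel::BlockSize();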
@@ -12,19 +12,6 @@
namespace ck_tile {
struct GemmHargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
@@ -51,7 +38,7 @@ struct GemmKernel
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmKargs
struct GemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
@@ -64,17 +51,17 @@ struct GemmKernel
index_t stride_C;
};
CK_TILE_HOST static constexpr GemmKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C)
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C)
{
return GemmKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -82,7 +69,7 @@ struct GemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static bool IsSupportedArgument(const GemmKargs& kargs)
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -158,7 +145,7 @@ struct GemmKernel
CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const GemmKargs& kargs) const
const GemmKernelArgs& kargs) const
{
auto&& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -311,14 +298,26 @@ struct GemmKernel
return make_tuple(a_block_window, b_block_window, c_block_window);
}
/**
* Creates tensor views, pad views and tile windows, then runs the GEMM and epilogue pipelines.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m M block index
* @param block_idx_n N block index
*
* @note Runs the GEMM cooperatively by the whole workgroup, followed by the CShuffle or Default 2D epilogue.
*/
CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const GemmKargs& kargs,
const GemmKernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n) const
{
// Convert pointers to tensor views
// Create Gemm tensor views, pad views and tile windows
auto&& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
auto&& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto&& gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -328,19 +327,18 @@ struct GemmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
auto a_block_window = gemm_tile_windows.at(I0);
auto b_block_window = gemm_tile_windows.at(I1);
// Run GEMM cooperatively by whole workgroup.
auto c_block_tile =
auto&& a_block_window = gemm_tile_windows.at(I0);
auto&& b_block_window = gemm_tile_windows.at(I1);
auto&& c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
auto c_block_window = gemm_tile_windows.at(I2);
// Run CShuffle or Default 2D Epilogue
auto&& c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}(c_block_window, c_block_tile);
}
CK_TILE_DEVICE void operator()(GemmKargs kargs) const
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
// options
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
namespace ck_tile {
struct Problem
{
CK_TILE_HOST Problem() = default;
CK_TILE_HOST Problem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
struct GemmHostArgs : public Problem
{
CK_TILE_HOST GemmHostArgs() = default;
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_)
: Problem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t k_batch;
};
} // namespace ck_tile
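A minimal sketch of how the two new types compose: GemmHostArgs extends Problem with the device pointers and k_batch, so host code can still treat the same object as just a problem description. The pointer and size values here are illustrative only (row-major-style strides assumed).

    const void* p_a = nullptr; // device A buffer (illustrative)
    const void* p_b = nullptr; // device B buffer (illustrative)
    void*       p_c = nullptr; // device C buffer (illustrative)

    ck_tile::index_t M = 3840, N = 4096, K = 2048;
    ck_tile::index_t stride_A = K, stride_B = N, stride_C = N; // assumed strides

    ck_tile::GemmHostArgs host_args(p_a, p_b, p_c, /*k_batch=*/1,
                                    M, N, K, stride_A, stride_B, stride_C);
    const ck_tile::Problem& problem = host_args; // the size/stride description alone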
@@ -24,7 +24,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
......