Commit 8385597f authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK_TILE] Refactor GemmKernel - update tests

parent 75535dd8
......@@ -30,38 +30,6 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(
a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_C(batch_stride_C_),
batch_count(batch_count_)
{
}
ck_tile::index_t batch_stride_A;
ck_tile::index_t batch_stride_B;
ck_tile::index_t batch_stride_C;
ck_tile::index_t batch_count;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
......
......@@ -299,7 +299,8 @@ struct GemmKernel
}
/**
* Create tensor views, pad views, tile windows, run gemm and epilogue pipeline
* Create tensor views, pad views, tile windows.
* Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
......@@ -307,8 +308,6 @@ struct GemmKernel
* @param kargs GEMM kernel arguments
* @param block_idx_m M block index
* @param block_idx_n N block index
*
* @return Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
*/
CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
......
......@@ -53,4 +53,36 @@ struct GemmHostArgs : public Problem
index_t k_batch;
};
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(
a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_C(batch_stride_C_),
batch_count(batch_count_)
{
}
ck_tile::index_t batch_stride_A;
ck_tile::index_t batch_stride_B;
ck_tile::index_t batch_stride_C;
ck_tile::index_t batch_count;
};
} // namespace ck_tile
......@@ -11,6 +11,7 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template <typename Tuple>
class TestCkTileBatchedGemm : public ::testing::Test
......@@ -24,12 +25,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
......@@ -94,9 +92,21 @@ class TestCkTileBatchedGemm : public ::testing::Test
using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
auto kargs = Kernel::MakeKernelArgs(args.a_ptr,
args.b_ptr,
args.c_ptr,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C,
args.batch_stride_A,
args.batch_stride_B,
args.batch_stride_C,
args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
......@@ -185,21 +195,22 @@ class TestCkTileBatchedGemm : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(kargs,
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.M = M;
args.N = N;
args.K = K;
args.stride_A = StrideA;
args.stride_B = StrideB;
args.stride_C = StrideC;
args.batch_stride_A = BatchStrideA;
args.batch_stride_B = BatchStrideB;
args.batch_stride_C = BatchStrideC;
args.batch_count = BatchCount;
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
ck_tile::stream_config{nullptr, false});
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
......
......@@ -95,7 +95,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
auto kargs = Kernel::MakeKernelArgs(args.p_a,
args.p_b,
args.p_c,
args.M,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment