Commit b4f65acf authored by Aleksander Dudek

[CK TILE] Refactor GemmKernel - naming changes, add problem

parent f79f727c
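For orientation before reading the diff, here is the rename map this commit applies, collected from the hunks below (all names are taken verbatim from the diff):

    // Host-side argument structs:
    //   gemm_basic_args, ck_tile::GemmHargs           -> ck_tile::GemmHostArgs
    //   batched_gemm_kargs, ck_tile::BatchedGemmHargs -> BatchedGemmHostArgs
    // Kernel-side argument structs:
    //   GemmKargs, BatchedGemmKargs                   -> GemmKernelArgs, BatchedGemmKernelArgs
    // Factories and grid helpers:
    //   Kernel::MakeKargs(hargs)                      -> Kernel::MakeKernelArgs(explicit scalars)
    //   Kernel::GridSize(hargs)                       -> Kernel::GridSize(M, N, batch_count)
    // New: ck_tile::Problem and ck_tile::GemmHostArgs, added in gemm_problem.hpp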
@@ -15,7 +15,7 @@
 #include "gemm_basic.hpp"

 template <typename ALayout, typename BLayout, typename CLayout>
-float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
     // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
     constexpr bool kPadM = false;
@@ -79,17 +79,17 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

-    auto kargs = Kernel::MakeKargs(args.p_a,
-                                   args.p_b,
-                                   args.p_c,
-                                   args.M,
-                                   args.N,
-                                   args.K,
-                                   args.stride_A,
-                                   args.stride_B,
-                                   args.stride_C);
+    auto kargs = Kernel::MakeKernelArgs(args.a_ptr,
+                                        args.b_ptr,
+                                        args.c_ptr,
+                                        args.M,
+                                        args.N,
+                                        args.K,
+                                        args.stride_A,
+                                        args.stride_B,
+                                        args.stride_C);

-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
     constexpr dim3 blocks = Kernel::BlockSize();

     if(!Kernel::IsSupportedArgument(kargs))
...
@@ -8,6 +8,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"

 template <typename DataType>
 struct GemmBasicTypeConfig;
@@ -51,20 +52,6 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType = Types::CDataType;

-struct gemm_basic_args
-{
-    const void* p_a;
-    const void* p_b;
-    void* p_c;
-    ck_tile::index_t kbatch;
-    ck_tile::index_t M;
-    ck_tile::index_t N;
-    ck_tile::index_t K;
-    ck_tile::index_t stride_A;
-    ck_tile::index_t stride_B;
-    ck_tile::index_t stride_C;
-};

 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
@@ -89,4 +76,4 @@ auto create_args(int argc, char* argv[])
 }

 // host API
-float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
+float gemm_calc(ck_tile::GemmHostArgs args, const ck_tile::stream_config& s);
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                   int n_warmup,
                   int n_repeat)
 {
-    gemm_basic_args args;
-    args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
-    args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
-    args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
-    args.kbatch = kbatch;
+    ck_tile::GemmHostArgs args;
+    args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
+    args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
+    args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+    args.k_batch = kbatch;
     args.M = M;
     args.N = N;
     args.K = K;
...
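Taken together, the hunks above change the example's host-side calling convention. Below is a minimal sketch of the new pattern, using only names from this diff; M, N, K, the stride variables, the device buffers, and the stream config s are assumed to be in scope as in the surrounding example:

    ck_tile::GemmHostArgs args;
    args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); // device pointers, as above
    args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
    args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
    args.k_batch  = 1; // split-K factor; 1 assumed to mean no split
    args.M        = M;
    args.N        = N;
    args.K        = K;
    args.stride_A = stride_A; // leading-dimension strides, as in the example
    args.stride_B = stride_B;
    args.stride_C = stride_C;
    float ave_time = gemm_calc(args, s); // host API declared in the header hunk above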
@@ -16,7 +16,7 @@
 #include "batched_gemm.hpp"

 template <typename ALayout, typename BLayout, typename CLayout>
-float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
+float batched_gemm(const BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
 {
     // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
     constexpr bool kPadM = false;
@@ -79,9 +79,21 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config&
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

-    auto kargs = Kernel::MakeKargs(args);
-
-    const dim3 grids = Kernel::GridSize(args);
+    auto kargs = Kernel::MakeKernelArgs(args.a_ptr,
+                                        args.b_ptr,
+                                        args.c_ptr,
+                                        args.M,
+                                        args.N,
+                                        args.K,
+                                        args.stride_A,
+                                        args.stride_B,
+                                        args.stride_C,
+                                        args.batch_stride_A,
+                                        args.batch_stride_B,
+                                        args.batch_stride_C,
+                                        args.batch_count);
+
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
     constexpr dim3 blocks = Kernel::BlockSize();

     if(s.log_level_ > 0)
...
@@ -8,6 +8,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
 #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
+#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"

 template <typename DataType>
 struct BatchedGemmTypeConfig;
@@ -29,8 +30,36 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType = Types::CDataType;

-struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
+struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
 {
+    CK_TILE_HOST BatchedGemmHostArgs() = default;
+    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
+                                     const void* b_ptr_,
+                                     void* c_ptr_,
+                                     ck_tile::index_t k_batch_,
+                                     ck_tile::index_t M_,
+                                     ck_tile::index_t N_,
+                                     ck_tile::index_t K_,
+                                     ck_tile::index_t stride_A_,
+                                     ck_tile::index_t stride_B_,
+                                     ck_tile::index_t stride_C_,
+                                     ck_tile::index_t batch_stride_A_,
+                                     ck_tile::index_t batch_stride_B_,
+                                     ck_tile::index_t batch_stride_C_,
+                                     ck_tile::index_t batch_count_)
+        : GemmHostArgs(
+              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
+          batch_stride_A(batch_stride_A_),
+          batch_stride_B(batch_stride_B_),
+          batch_stride_C(batch_stride_C_),
+          batch_count(batch_count_)
+    {
+    }
+
+    ck_tile::index_t batch_stride_A;
+    ck_tile::index_t batch_stride_B;
+    ck_tile::index_t batch_stride_C;
+    ck_tile::index_t batch_count;
 };

 auto create_args(int argc, char* argv[])
@@ -60,4 +89,4 @@ auto create_args(int argc, char* argv[])
 }

 // host API
-float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
+float batched_gemm(BatchedGemmHostArgs args, const ck_tile::stream_config& s);
@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           int n_warmup,
                           int n_repeat)
 {
-    batched_gemm_kargs args;
+    BatchedGemmHostArgs args;
     args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
     args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
     args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
...
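The new BatchedGemmHostArgs constructor (declared in the header hunk above) also allows building the argument struct in one expression instead of member-by-member assignment. A sketch under the same assumptions as the surrounding example (sizes, strides, and batch variables in scope; k_batch of 1 assumed to mean no split-K):

    BatchedGemmHostArgs args(a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
                             b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
                             c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
                             /*k_batch=*/1,
                             M, N, K,
                             stride_A, stride_B, stride_C,
                             batch_stride_A, batch_stride_B, batch_stride_C,
                             batch_count);
    float ave_time = batched_gemm(args, s); // host API declared above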
@@ -7,20 +7,12 @@
 namespace ck_tile {

-struct BatchedGemmHargs : GemmHargs
-{
-    index_t batch_stride_A;
-    index_t batch_stride_B;
-    index_t batch_stride_C;
-    index_t batch_count;
-};

 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
 {
     using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

-    using GemmKargs = typename Base::GemmKargs;
+    using GemmKernelArgs = typename Base::GemmKernelArgs;

     using ADataType = typename Base::ADataType;
     using BDataType = typename Base::BDataType;
@@ -33,7 +25,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     using BLayout = typename Base::BLayout;
     using CLayout = typename Base::CLayout;

-    struct BatchedGemmKargs : GemmKargs
+    struct BatchedGemmKernelArgs : GemmKernelArgs
     {
         index_t batch_stride_A;
         index_t batch_stride_B;
@@ -41,33 +33,34 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
         index_t batch_count;
     };

-    using Kargs = BatchedGemmKargs;
-    using Hargs = BatchedGemmHargs;
+    using KernelArgs = BatchedGemmKernelArgs;

-    __host__ static constexpr auto GridSize(const Hargs& k)
+    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
     {
-        return TilePartitioner::GridSize(k.M, k.N, k.batch_count);
+        return TilePartitioner::GridSize(M, N, batch_count);
     }

     __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }

-    CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h)
+    CK_TILE_HOST static constexpr BatchedGemmKernelArgs MakeKernelArgs(const void* a_ptr,
+                                                                       const void* b_ptr,
+                                                                       void* c_ptr,
+                                                                       index_t M,
+                                                                       index_t N,
+                                                                       index_t K,
+                                                                       index_t stride_A,
+                                                                       index_t stride_B,
+                                                                       index_t stride_C,
+                                                                       index_t batch_stride_A,
+                                                                       index_t batch_stride_B,
+                                                                       index_t batch_stride_C,
+                                                                       index_t batch_count)
     {
-        Kargs k;
-        k.a_ptr = h.a_ptr;
-        k.b_ptr = h.b_ptr;
-        k.c_ptr = h.c_ptr;
-        k.M = h.M;
-        k.N = h.N;
-        k.K = h.K;
-        k.stride_A = h.stride_A;
-        k.stride_B = h.stride_B;
-        k.stride_C = h.stride_C;
-        k.batch_stride_A = h.batch_stride_A;
-        k.batch_stride_B = h.batch_stride_B;
-        k.batch_stride_C = h.batch_stride_C;
-        k.batch_count = h.batch_count;
-        return k;
+        return BatchedGemmKernelArgs{{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C},
+                                     batch_stride_A,
+                                     batch_stride_B,
+                                     batch_stride_C,
+                                     batch_count};
     }

     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -75,7 +68,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

-    CK_TILE_DEVICE void operator()(Kargs kargs) const
+    CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
         const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
@@ -83,17 +76,17 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
         // options
         const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
         const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;

         const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
         const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;

         const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
         const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
+        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;

-        this->RunGemm(a_start, b_start, c_start, kargs, i_m, i_n);
+        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
     }
 };
...
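The new MakeKernelArgs body relies on C++17 aggregate initialization of a derived struct: BatchedGemmKernelArgs declares no constructors, so the inner brace list initializes the GemmKernelArgs base subobject and the trailing values fill the derived members. A minimal standalone illustration of that mechanism, independent of ck_tile:

    // C++17 aggregate with a base class; no user-declared constructors needed.
    struct Base    { const void* a; int M; };
    struct Derived : Base { int batch_count; };

    // Inner braces initialize the Base subobject, trailing value the derived member:
    Derived d{{nullptr, 128}, 4};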
@@ -12,19 +12,6 @@
 namespace ck_tile {

-struct GemmHargs
-{
-    const void* a_ptr;
-    const void* b_ptr;
-    void* c_ptr;
-    index_t M;
-    index_t N;
-    index_t K;
-    index_t stride_A;
-    index_t stride_B;
-    index_t stride_C;
-};

 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct GemmKernel
 {
@@ -51,7 +38,7 @@ struct GemmKernel
     __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

-    struct GemmKargs
+    struct GemmKernelArgs
     {
         const void* a_ptr;
         const void* b_ptr;
@@ -64,17 +51,17 @@ struct GemmKernel
         index_t stride_C;
     };

-    CK_TILE_HOST static constexpr GemmKargs MakeKargs(const void* a_ptr,
-                                                      const void* b_ptr,
-                                                      void* c_ptr,
-                                                      index_t M,
-                                                      index_t N,
-                                                      index_t K,
-                                                      index_t stride_A,
-                                                      index_t stride_B,
-                                                      index_t stride_C)
+    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
+                                                                const void* b_ptr,
+                                                                void* c_ptr,
+                                                                index_t M,
+                                                                index_t N,
+                                                                index_t K,
+                                                                index_t stride_A,
+                                                                index_t stride_B,
+                                                                index_t stride_C)
     {
-        return GemmKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
+        return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
     }

     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -82,7 +69,7 @@ struct GemmKernel
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

-    CK_TILE_HOST static bool IsSupportedArgument(const GemmKargs& kargs)
+    CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
     {
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
@@ -158,7 +145,7 @@ struct GemmKernel
     CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
                                             const BDataType* b_ptr,
                                             CDataType* c_ptr,
-                                            const GemmKargs& kargs) const
+                                            const GemmKernelArgs& kargs) const
     {
         auto&& a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -311,14 +298,26 @@ struct GemmKernel
         return make_tuple(a_block_window, b_block_window, c_block_window);
     }

+    /**
+     * Create tensor views, pad views, tile windows, run gemm and epilogue pipeline
+     *
+     * @param a_ptr       input A pointer
+     * @param b_ptr       input B pointer
+     * @param c_ptr       output C pointer
+     * @param kargs       GEMM kernel arguments
+     * @param block_idx_m M block index
+     * @param block_idx_n N block index
+     *
+     * @return Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
+     */
     CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
                                 const BDataType* b_ptr,
                                 CDataType* c_ptr,
-                                const GemmKargs& kargs,
+                                const GemmKernelArgs& kargs,
                                 const index_t block_idx_m,
                                 const index_t block_idx_n) const
     {
-        // Convert pointers to tensor views
+        // Create Gemm tensor views, pad views and tile windows
         auto&& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
         auto&& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto&& gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -328,19 +327,18 @@ struct GemmKernel
         const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);

-        auto a_block_window = gemm_tile_windows.at(I0);
-        auto b_block_window = gemm_tile_windows.at(I1);
-
         // Run GEMM cooperatively by whole workgroup.
-        auto c_block_tile =
+        auto&& a_block_window = gemm_tile_windows.at(I0);
+        auto&& b_block_window = gemm_tile_windows.at(I1);
+        auto&& c_block_tile =
             GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);

-        auto c_block_window = gemm_tile_windows.at(I2);
+        // Run CShuffle or Default 2D Epilogue
+        auto&& c_block_window = gemm_tile_windows.at(I2);
         EpiloguePipeline{}(c_block_window, c_block_tile);
     }

-    CK_TILE_DEVICE void operator()(GemmKargs kargs) const
+    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
         // options
...
New file (per the includes added above, presumably ck_tile/ops/gemm/problem/gemm_problem.hpp):

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "ck_tile/core.hpp"

namespace ck_tile {

struct Problem
{
    CK_TILE_HOST Problem() = default;
    CK_TILE_HOST Problem(
        index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

struct GemmHostArgs : public Problem
{
    CK_TILE_HOST GemmHostArgs() = default;
    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
                              const void* b_ptr_,
                              void* c_ptr_,
                              index_t k_batch_,
                              index_t M_,
                              index_t N_,
                              index_t K_,
                              index_t stride_A_,
                              index_t stride_B_,
                              index_t stride_C_)
        : Problem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t k_batch;
};
} // namespace ck_tile
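A quick sketch of constructing the new host-args type defined above; the pointer names and sizes are placeholders, and the strides assume row-major A and C layouts:

    // p_a, p_b, p_c are assumed device pointers obtained elsewhere.
    ck_tile::GemmHostArgs host_args(p_a, p_b, p_c,
                                    /*k_batch=*/1,
                                    /*M=*/3840, /*N=*/4096, /*K=*/2048,
                                    /*stride_A=*/2048, /*stride_B=*/4096, /*stride_C=*/4096);

    // The Problem base carries the sizes and strides:
    ck_tile::index_t m = host_args.M; // inherited from ck_tile::Problem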
@@ -24,7 +24,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
     using AccDataType = std::tuple_element_t<5, Tuple>;
     using CDataType = std::tuple_element_t<6, Tuple>;

-    struct batched_gemm_kargs : public ck_tile::BatchedGemmHargs
+    struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
     {
     };
...