"...composable_kernel_rocm.git" did not exist on "aa46039f2d25874c799d85f5cd28bc636892deef"
Commit 5bedd21a authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK_TILE] Refactor GemmKernel - update tests

parent 13fe6e95
...@@ -95,7 +95,7 @@ struct GemmKernel ...@@ -95,7 +95,7 @@ struct GemmKernel
index_t stride_C; index_t stride_C;
}; };
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(GemmHostArgs& hostArgs) CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
{ {
return GemmKernelArgs{hostArgs.a_ptr, return GemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr, hostArgs.b_ptr,
......
...@@ -91,19 +91,7 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -91,19 +91,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
using Kernel = using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args.a_ptr, auto kargs = Kernel::MakeKernelArgs(args);
args.b_ptr,
args.c_ptr,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C,
args.batch_stride_A,
args.batch_stride_B,
args.batch_stride_C,
args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
......
...@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
// TODO: expose tile size through test t-param ? // TODO: expose tile size through test t-param ?
struct gemm_args
{
const void* p_a;
const void* p_b;
void* p_c;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
template <bool PadM, bool PadN, bool PadK> template <bool PadM, bool PadN, bool PadK>
void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s) void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// TODO: This should be parameterized in tests // TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
...@@ -117,17 +103,9 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -117,17 +103,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>>; tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args.p_a, auto kargs = Kernel::MakeKernelArgs(args);
args.p_b,
args.p_c, const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs)) if(!Kernel::IsSupportedArgument(kargs))
...@@ -319,11 +297,11 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -319,11 +297,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
gemm_args args; ck_tile::GemmHostArgs args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.kbatch = kbatch; args.k_batch = kbatch;
args.M = M; args.M = M;
args.N = N; args.N = N;
args.K = K; args.K = K;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment