Commit 4525c5d7 authored by coderfeli's avatar coderfeli
Browse files

merge upstream

parents a8d88d8d 44828b7c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template <typename Tuple>
class TestCkTileBatchedGemm : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:
void Run(const int M,
const int N,
const int K,
int StrideA = 128,
int StrideB = 128,
int StrideC = 128,
const int BatchStrideA = 32768,
const int BatchStrideB = 16384,
const int BatchStrideC = 32768,
const int BatchCount = 16)
{
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, 1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(kargs,
ck_tile::stream_config{nullptr, false});
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC
<< " BatchStrideA =" << BatchStrideA << " BatchStrideB =" << BatchStrideB
<< " BatchStrideC =" << BatchStrideC << " BatchCount =" << BatchCount
<< std::endl;
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
c_m_n_host_ref.SetZero();
const auto b_n_k = b_k_n.transpose({0, 2, 1});
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_n_k, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
EXPECT_TRUE(pass);
}
};
...@@ -11,8 +11,20 @@ ...@@ -11,8 +11,20 @@
using F16 = ck_tile::half_t; using F16 = ck_tile::half_t;
using F32 = float; using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
static constexpr auto Intrawave = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = ck_tile::GemmPipelineScheduler::Interwave;
template <typename Tuple>
class TestCkTileGemmMemPipelineIntrawave : public TestCkTileGemmMemPipeline<Tuple, Intrawave>
{
};
template <typename Tuple>
class TestCkTileGemmMemPipelineInterwave : public TestCkTileGemmMemPipeline<Tuple, Interwave>
{
};
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
...@@ -24,6 +36,7 @@ using KernelTypes = ::testing::Types< ...@@ -24,6 +36,7 @@ using KernelTypes = ::testing::Types<
>; >;
// clang-format on // clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); TYPED_TEST_SUITE(TestCkTileGemmMemPipelineIntrawave, KernelTypes);
TYPED_TEST_SUITE(TestCkTileGemmMemPipelineInterwave, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc" #include "test_gemm_mem_pipeline_ut_cases.inc"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) //------------------------------------------------------------------------------------------------
// INTERWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 1024;
constexpr int K = 432;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
constexpr int K = 512;
for(int M : Ms)
this->Run(M, N, K);
}
//------------------------------------------------------------------------------------------------
// INTRAWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, SmallM)
{ {
std::vector<int> Ms{1, 2, 3, 4, 5, 6}; std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -10,7 +61,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) ...@@ -10,7 +61,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -20,7 +71,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) ...@@ -20,7 +71,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -30,7 +81,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) ...@@ -30,7 +81,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, Regular) TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
constexpr int N = 1024; constexpr int N = 1024;
......
...@@ -11,20 +11,21 @@ ...@@ -11,20 +11,21 @@
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
template <typename Tuple> template <typename Tuple, ck_tile::GemmPipelineScheduler Scheduler_>
class TestCkTileGemmMemPipeline : public ::testing::Test class TestCkTileGemmMemPipeline : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>; using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = Scheduler_;
// TODO: expose tile size through test t-param ? // TODO: expose tile size through test t-param ?
struct gemm_basic_args struct gemm_args
{ {
const void* p_a; const void* p_a;
const void* p_b; const void* p_b;
...@@ -38,7 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -38,7 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile::index_t stride_C; ck_tile::index_t stride_C;
}; };
void invoke_gemm(const gemm_basic_args& args, const ck_tile::stream_config& s) void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
{ {
// TODO: This should be parameterized in tests // TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
...@@ -89,7 +90,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -89,7 +90,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
AccDataType, AccDataType,
GemmShape, GemmShape,
Traits, Traits,
ck_tile::GemmPipelineScheduler::Intrawave, Scheduler,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
...@@ -288,7 +289,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -288,7 +289,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
gemm_basic_args args; gemm_args args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
......
...@@ -6,12 +6,6 @@ if(result EQUAL 0) ...@@ -6,12 +6,6 @@ if(result EQUAL 0)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif() endif()
add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -10,25 +10,35 @@ ...@@ -10,25 +10,35 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp" #include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; template <typename Tuple>
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>; class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
{
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; };
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// clang-format off
const std::vector<int> KBATCH{1, 2, 3, 5, 8}; using KernelTypes = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH)); std::tuple< Row, Col, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH)); std::tuple< Col, Row, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN, std::tuple< Col, Col, Row, F16, F16, F16>,
RRR_F16_F16_F16_LargeK, std::tuple< Row, Row, Row, BF16, BF16, BF16>,
testing::Values(32, 64)); std::tuple< Row, Col, Row, BF16, BF16, BF16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK, std::tuple< Col, Row, Row, BF16, BF16, BF16>,
RCR_F16_F16_F16_LargeK, std::tuple< Row, Row, Row, BF16, I8, BF16>,
testing::Values(32, 64)); std::tuple< Row, Col, Row, BF16, I8, BF16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
std::tuple< Row, Row, Row, F8, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc" #include "test_grouped_gemm_ut_cases.inc"
#pragma once #pragma once
TEST_P(RRR_F16_F16_F16, TinyCases) TYPED_TEST(TestGroupedGemm, TinyCases)
{ {
const std::vector<int> Ms{0, 1}; const std::vector<int> Ms{0, 1};
constexpr int N = 768; constexpr int N = 768;
...@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases) ...@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, SmallCases) TYPED_TEST(TestGroupedGemm, SmallCases)
{ {
const std::vector<int> Ms{2, 1, 3, 4, 5, 0}; const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768; constexpr int N = 768;
...@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases) ...@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, MidCases) TYPED_TEST(TestGroupedGemm, MidCases)
{ {
const std::vector<int> Ms{167, 183, 177, 153, 139, 204}; const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768; constexpr int N = 768;
...@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases) ...@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, Regular) TYPED_TEST(TestGroupedGemm, Regular)
{ {
const std::vector<int> Ms{64, 128, 256}; const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768; constexpr int N = 768;
...@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular) ...@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, MNKPadded) TYPED_TEST(TestGroupedGemm, MNKPadded)
{ {
const std::vector<int> Ms{127, 150, 188, 210}; const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136; constexpr int N = 136;
...@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded) ...@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RCR_F16_F16_F16, TinyCases) TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
{ {
const std::vector<int> Ms{188, 210}; const std::vector<int> Ms{188, 210};
constexpr int N = 768; constexpr int N = 768;
...@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) ...@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch) this->k_batches_ = {32, 64};
{
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ck/utility/tuple.hpp" #include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp" #include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace ck { namespace ck {
namespace test { namespace test {
...@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range) ...@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range)
} }
template <typename Tuple> template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int> class TestGroupedGemm : public testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
...@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using BDataType = std::tuple_element_t<4, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>; using EDataType = std::tuple_element_t<5, Tuple>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
public: public:
static constexpr bool verify_ = true; static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false; static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
std::vector<int> k_batches_;
void SetUp() override {} void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms, void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns, const std::vector<int>& Ns,
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs = {})
int kbatch = 1, {
int n_warmup = 1, std::vector<int> stride_as = StrideAs;
int n_iter = 10) std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_cs = StrideCs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_cs.empty())
{
SetStrides<ELayout>(stride_cs, Ms, Ns);
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches_);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{ {
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType, bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType, BDataType,
...@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup_,
n_iter); n_iter_);
EXPECT_TRUE(pass);
}
};
template <typename Tuple>
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {}
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
}; };
...@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1) if(kbatch > 1)
{ {
ggemm_instance.SetKBatchSize(argument, kbatch); ggemm_instance.SetKBatchSize(&argument, kbatch);
} }
return ggemm_instance.IsSupportedArgument(argument); return ggemm_instance.IsSupportedArgument(argument);
...@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1) if(kbatch > 1)
{ {
ggemm_instance.SetKBatchSize(argument, kbatch); ggemm_instance.SetKBatchSize(&argument, kbatch);
} }
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker(); auto invoker = ggemm_instance.MakeInvoker();
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument)); DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false}); return invoker.Run(argument, StreamConfig{nullptr, false});
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment