Commit 2298a1a4 authored by illsilin's avatar illsilin
Browse files

sync from public

parents 965b7ba4 2f088b87
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_batched_gemm_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileBatchedGemm, KernelTypes);
#include "test_batched_gemm_ut_cases.inc"
#pragma once
TYPED_TEST(TestCkTileBatchedGemm, Basic)
{
constexpr int M = 256;
constexpr int N = 128;
constexpr int K = 128;
this->Run(M, N, K);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template <typename Tuple>
class TestCkTileBatchedGemm : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:
void Run(const int M,
const int N,
const int K,
int StrideA = 128,
int StrideB = 128,
int StrideC = 128,
const int BatchStrideA = 32768,
const int BatchStrideB = 16384,
const int BatchStrideC = 32768,
const int BatchCount = 16)
{
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, 1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(kargs,
ck_tile::stream_config{nullptr, false});
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC
<< " BatchStrideA =" << BatchStrideA << " BatchStrideB =" << BatchStrideB
<< " BatchStrideC =" << BatchStrideC << " BatchCount =" << BatchCount
<< std::endl;
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
c_m_n_host_ref.SetZero();
const auto b_n_k = b_k_n.transpose({0, 2, 1});
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_n_k, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
EXPECT_TRUE(pass);
}
};
......@@ -8,19 +8,26 @@
#include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
std::tuple< Col, Col, Row, F16, F16, F32, F16>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave>
>;
// clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
......@@ -39,3 +42,16 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument)
{
constexpr int M = 512;
constexpr int N = 1025;
constexpr int K = 513;
constexpr bool PadM = false;
constexpr bool PadN = false;
constexpr bool PadK = false;
EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
}
......@@ -15,16 +15,17 @@ template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
// TODO: expose tile size through test t-param ?
struct gemm_basic_args
struct gemm_args
{
const void* p_a;
const void* p_b;
......@@ -38,7 +39,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile::index_t stride_C;
};
void invoke_gemm(const gemm_basic_args& args, const ck_tile::stream_config& s)
template <bool PadM, bool PadN, bool PadK>
void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128;
......@@ -53,9 +55,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool kPadM = true;
constexpr bool kPadN = true;
constexpr bool kPadK = true;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
constexpr int kBlockPerCu = 1;
......@@ -89,7 +91,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
AccDataType,
GemmShape,
Traits,
ck_tile::GemmPipelineScheduler::Intrawave,
Scheduler,
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
......@@ -106,6 +108,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......@@ -211,6 +218,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
void SetUp() override { k_batches_ = {1}; }
template <bool PadM = true, bool PadN = true, bool PadK = true>
void Run(const int M,
const int N,
const int K,
......@@ -220,10 +228,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
{
for(auto kb : k_batches_)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
template <bool PadM, bool PadN, bool PadK>
void RunSingle(const int M,
const int N,
const int K,
......@@ -288,7 +297,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
gemm_basic_args args;
gemm_args args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
......@@ -300,7 +309,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
args.stride_B = stride_B;
args.stride_C = stride_C;
invoke_gemm(args, ck_tile::stream_config{nullptr, false});
invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
......
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc"
#pragma once
TYPED_TEST(TestCkTileGroupedGemm, Basic)
{
const int group_count = 16;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, group_count);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
template <typename Tuple>
class TestCkTileGroupedGemm : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct GroupedGemKernelParam
{
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
static const bool kTilePermute = false;
static const ck_tile::index_t kOutputRank = 2;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 32;
static const ck_tile::index_t M_Warp = 2;
static const ck_tile::index_t N_Warp = 2;
static const ck_tile::index_t K_Warp = 1;
static const ck_tile::index_t M_Warp_Tile = 32;
static const ck_tile::index_t N_Warp_Tile = 32;
static const ck_tile::index_t K_Warp_Tile = 8;
};
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
GroupedGemKernelParam::N_Tile,
GroupedGemKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename CLayout>
using GemmEpilogue =
std::conditional_t<std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kTilePermute,
GroupedGemKernelParam::kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType,
CDataType,
GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN>>>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
CodegenGemmPolicy>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<CLayout>>;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
}
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_)
{
using GroupedGemmKernel = Kernel<ALayout, BLayout, CLayout>;
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(
p_workspace_,
arguments.data(),
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
GroupedGemmKernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
gemm_descs.size()));
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
std::vector<int>& stride_As,
std::vector<int>& stride_Bs,
std::vector<int>& stride_Cs,
const int group_count = 16)
{
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
c_m_n_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = f_get_default_stride(M, N, stride_As[i], ALayout{});
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{});
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
f_host_tensor_descriptor(M, K, stride_As[i], ALayout{})));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{})));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(gemm_descs));
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
}
bool pass{true};
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
}
EXPECT_TRUE(pass);
}
};
......@@ -51,8 +51,11 @@ TEST(Custom_bool, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bool_t>()(Number<i>{}) = custom_bool_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_bool_t, size> left_vec{right_vec};
vector_type<custom_bool_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_bool_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bool_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -129,8 +132,11 @@ TEST(Custom_int8, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_int8_t>()(Number<i>{}) = custom_int8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_int8_t, size> left_vec{right_vec};
vector_type<custom_int8_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_int8_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_int8_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -207,8 +213,11 @@ TEST(Custom_uint8, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_uint8_t>()(Number<i>{}) = custom_uint8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_uint8_t, size> left_vec{right_vec};
vector_type<custom_uint8_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_uint8_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -287,8 +296,11 @@ TEST(Custom_f8, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_f8_t>()(Number<i>{}) = custom_f8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_f8_t, size> left_vec{right_vec};
vector_type<custom_f8_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_f8_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_f8_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -369,8 +381,11 @@ TEST(Custom_bf8, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bf8_t>()(Number<i>{}) = custom_bf8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_bf8_t, size> left_vec{right_vec};
vector_type<custom_bf8_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_bf8_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -450,8 +465,11 @@ TEST(Custom_half, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_half_t>()(Number<i>{}) = custom_half_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_half_t, size> left_vec{right_vec};
vector_type<custom_half_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_half_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_half_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -533,8 +551,11 @@ TEST(Custom_bhalf, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bhalf_t>()(Number<i>{}) = custom_bhalf_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_bhalf_t, size> left_vec{right_vec};
vector_type<custom_bhalf_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_bhalf_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -615,8 +636,11 @@ TEST(Custom_float, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_float_t>()(Number<i>{}) = custom_float_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_float_t, size> left_vec{right_vec};
vector_type<custom_float_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_float_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_float_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -693,8 +717,11 @@ TEST(Custom_double, TestAsType)
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_double_t>()(Number<i>{}) = custom_double_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_double_t, size> left_vec{right_vec};
vector_type<custom_double_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<custom_double_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_double_t>()(Number<i>{}).data, test_vec.at(i));
......@@ -813,8 +840,11 @@ TEST(Complex_half, TestAsType)
right_vec.template AsType<complex_half_t>()(Number<i>{}) =
complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)};
});
// copy the vector
vector_type<complex_half_t, size> left_vec{right_vec};
vector_type<complex_half_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<complex_half_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).real,
......@@ -907,8 +937,11 @@ TEST(FP8OCP, TestAsType)
right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
});
// copy the vector
vector_type<f8_t, size> left_vec{right_vec};
vector_type<f8_t, size> left_vec;
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<f8_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
......@@ -984,8 +1017,11 @@ TEST(BF8OCP, TestAsType)
right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
});
// copy the vector
vector_type<bf8_t, size> left_vec{right_vec};
// check copy assignment op
left_vec = right_vec;
// overwrite right_vec with 0s
right_vec = vector_type<bf8_t, size>{};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
......
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp)
add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_data_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
if(result EQUAL 0)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
template <typename Tuple>
class TestGroupedConvndBwdDataWmma : public ::testing::Test
{
protected:
using DataType = std::tuple_element_t<0, Tuple>;
using OutLayout = std::tuple_element_t<1, Tuple>;
using WeiLayout = std::tuple_element_t<2, Tuple>;
using InLayout = std::tuple_element_t<3, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params;
template <ck::index_t NDimSpatial>
void Run()
{
EXPECT_FALSE(conv_params.empty());
bool pass = true;
for(auto& param : conv_params)
{
pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
OutLayout,
WeiLayout,
InLayout,
DataType,
DataType,
DataType>(
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
param);
}
EXPECT_TRUE(pass);
}
};
using namespace ck::tensor_layout::convolution;
using KernelTypes2d = ::testing::Types<std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
using KernelTypes3d = ::testing::Types<std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
template <typename Tuple>
class TestGroupedConvndBwdDataWmma2d : public TestGroupedConvndBwdDataWmma<Tuple>
{
};
template <typename Tuple>
class TestGroupedConvndBwdDataWmma3d : public TestGroupedConvndBwdDataWmma<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdDataWmma2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndBwdDataWmma3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
......@@ -12,7 +12,7 @@
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
template <typename Tuple>
class TestGroupedConvndBwdData : public ::testing::Test
class TestGroupedConvndBwdDataXdl : public ::testing::Test
{
protected:
using DataType = std::tuple_element_t<0, Tuple>;
......@@ -51,35 +51,31 @@ using namespace ck::tensor_layout::convolution;
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
std::tuple<float, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>,
std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
template <typename Tuple>
class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple>
class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl<Tuple>
{
};
template <typename Tuple>
class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData<Tuple>
class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d);
TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
{
this->conv_params.clear();
......@@ -94,10 +90,13 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
// SplitN case
this->conv_params.push_back(
{2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
......@@ -112,5 +111,17 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// SplitN case
this->conv_params.push_back({3,
1,
128,
4,
192,
{2, 2, 2},
{2, 224, 224},
{1, 224, 224},
{1, 1, 1},
{0, 0, 0},
{0, 0, 0}});
this->template Run<3>();
}
......@@ -6,12 +6,6 @@ if(result EQUAL 0)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif()
add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
......@@ -10,25 +10,35 @@
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN,
RRR_F16_F16_F16_LargeK,
testing::Values(32, 64));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK,
RCR_F16_F16_F16_LargeK,
testing::Values(32, 64));
template <typename Tuple>
class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F16>,
std::tuple< Row, Col, Row, F16, F16, F16>,
std::tuple< Col, Row, Row, F16, F16, F16>,
std::tuple< Col, Col, Row, F16, F16, F16>,
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
std::tuple< Col, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Row, Row, BF16, I8, BF16>,
std::tuple< Row, Col, Row, BF16, I8, BF16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
std::tuple< Row, Row, Row, F8, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc"
#pragma once
TEST_P(RRR_F16_F16_F16, TinyCases)
TYPED_TEST(TestGroupedGemm, TinyCases)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
......@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, SmallCases)
TYPED_TEST(TestGroupedGemm, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
......@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, MidCases)
TYPED_TEST(TestGroupedGemm, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
......@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, Regular)
TYPED_TEST(TestGroupedGemm, Regular)
{
const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768;
......@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, MNKPadded)
TYPED_TEST(TestGroupedGemm, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
......@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RCR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
......@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
{
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->k_batches_ = {32, 64};
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
......@@ -22,7 +22,6 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace ck {
namespace test {
......@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range)
}
template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int>
class TestGroupedGemm : public testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
......@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
std::vector<int> k_batches_;
void SetUp() override {}
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideCs = {})
{
std::vector<int> stride_as = StrideAs;
std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_cs = StrideCs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_cs.empty())
{
SetStrides<ELayout>(stride_cs, Ms, Ns);
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches_);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
......@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass);
}
};
template <typename Tuple>
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {}
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
kbatches,
n_warmup_,
n_iter_);
EXPECT_TRUE(pass);
}
};
......@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
......@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment