Unverified Commit feb9a2bd authored by jakpiase's avatar jakpiase Committed by GitHub
Browse files

Add IsSupportedArgument() to gemm_kernel (#1698)

* add IsSupportedArgument to gemm_kernel

* add ut and do some refactoring

* switched to ck_tile's integral_constant
parent d2d1d177
...@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -119,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -119,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -66,6 +66,79 @@ struct GemmKernel ...@@ -66,6 +66,79 @@ struct GemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
} }
CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs)
{
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
{
return false;
}
if(kargs.K % GemmPipeline::VectorSizeA != 0)
{
return false;
}
}
else
{
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
{
return false;
}
if(kargs.M % GemmPipeline::VectorSizeA != 0)
{
return false;
}
}
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
{
return false;
}
if(kargs.N % GemmPipeline::VectorSizeB != 0)
{
return false;
}
}
else
{
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
{
return false;
}
if(kargs.K % GemmPipeline::VectorSizeB != 0)
{
return false;
}
}
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
{
return false;
}
if(kargs.N % GemmPipeline::VectorSizeC != 0)
{
return false;
}
}
else
{
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
{
return false;
}
if(kargs.M % GemmPipeline::VectorSizeC != 0)
{
return false;
}
}
return true;
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}(); const auto [i_m, i_n] = TilePartitioner{}();
......
...@@ -8,35 +8,29 @@ ...@@ -8,35 +8,29 @@
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp" #include "test_gemm_mem_pipeline_util.hpp"
using F16 = ck_tile::half_t; using F16 = ck_tile::half_t;
using F32 = float; using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
static constexpr auto Intrawave = ck_tile::GemmPipelineScheduler::Intrawave; ck_tile::GemmPipelineScheduler::Intrawave>;
static constexpr auto Interwave = ck_tile::GemmPipelineScheduler::Interwave; using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>;
template <typename Tuple>
class TestCkTileGemmMemPipelineIntrawave : public TestCkTileGemmMemPipeline<Tuple, Intrawave>
{
};
template <typename Tuple>
class TestCkTileGemmMemPipelineInterwave : public TestCkTileGemmMemPipeline<Tuple, Interwave>
{
};
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std::tuple< Row, Col, Row, F16, F16, F32, F16>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Row, Row, Row, F16, F16, F32, F16>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16> std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave>
>; >;
// clang-format on // clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipelineIntrawave, KernelTypes); TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes);
TYPED_TEST_SUITE(TestCkTileGemmMemPipelineInterwave, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc" #include "test_gemm_mem_pipeline_ut_cases.inc"
...@@ -3,11 +3,7 @@ ...@@ -3,11 +3,7 @@
#pragma once #pragma once
//------------------------------------------------------------------------------------------------ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
// INTERWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM)
{ {
std::vector<int> Ms{1, 2, 3, 4, 5, 6}; std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -17,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM) ...@@ -17,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM) TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -27,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM) ...@@ -27,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK) TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -37,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK) ...@@ -37,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular) TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -47,46 +43,15 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular) ...@@ -47,46 +43,15 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular)
this->Run(M, N, K); this->Run(M, N, K);
} }
//------------------------------------------------------------------------------------------------ TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument)
// INTRAWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, SmallM)
{ {
std::vector<int> Ms{1, 2, 3, 4, 5, 6}; constexpr int M = 512;
constexpr int N = 1024; constexpr int N = 1025;
constexpr int K = 320; constexpr int K = 513;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, MidLargeM) constexpr bool PadM = false;
{ constexpr bool PadN = false;
std::vector<int> Ms{127, 255, 312, 799, 1573}; constexpr bool PadK = false;
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, PaddK) EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
{
std::vector<int> Ms{127};
constexpr int N = 1024;
constexpr int K = 432;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
constexpr int K = 512;
for(int M : Ms)
this->Run(M, N, K);
} }
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
template <typename Tuple, ck_tile::GemmPipelineScheduler Scheduler_> template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test class TestCkTileGemmMemPipeline : public ::testing::Test
{ {
protected: protected:
...@@ -22,7 +22,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -22,7 +22,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using BDataType = std::tuple_element_t<4, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = Scheduler_; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
// TODO: expose tile size through test t-param ? // TODO: expose tile size through test t-param ?
struct gemm_args struct gemm_args
...@@ -39,6 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -39,6 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile::index_t stride_C; ck_tile::index_t stride_C;
}; };
template <bool PadM, bool PadN, bool PadK>
void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s) void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
{ {
// TODO: This should be parameterized in tests // TODO: This should be parameterized in tests
...@@ -54,9 +55,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -54,9 +55,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool kPadM = true; constexpr bool kPadM = PadM;
constexpr bool kPadN = true; constexpr bool kPadN = PadN;
constexpr bool kPadK = true; constexpr bool kPadK = PadK;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -107,6 +108,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -107,6 +108,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
...@@ -212,6 +218,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -212,6 +218,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
void SetUp() override { k_batches_ = {1}; } void SetUp() override { k_batches_ = {1}; }
template <bool PadM = true, bool PadN = true, bool PadK = true>
void Run(const int M, void Run(const int M,
const int N, const int N,
const int K, const int K,
...@@ -221,10 +228,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -221,10 +228,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
{ {
for(auto kb : k_batches_) for(auto kb : k_batches_)
{ {
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, kb);
} }
} }
template <bool PadM, bool PadN, bool PadK>
void RunSingle(const int M, void RunSingle(const int M,
const int N, const int N,
const int K, const int K,
...@@ -301,7 +309,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -301,7 +309,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
args.stride_B = stride_B; args.stride_B = stride_B;
args.stride_C = stride_C; args.stride_C = stride_C;
invoke_gemm(args, ck_tile::stream_config{nullptr, false}); invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true; bool pass = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment