Commit 12cd4c72 authored by Adam Osewski's avatar Adam Osewski
Browse files

Add more test-cases for different data layout.

parent a01ca8e6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// #include <algorithm>
// #include <stdexcept>
#include <vector>
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
......@@ -34,17 +32,23 @@ class TestGemmSplitK_MK_KN
{
};
// template <typename Tuple>
// class TestGemmSplitK_MK_NK : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Row, Col>,
// Tuple>::type> {};
template <typename Tuple>
class TestGemmSplitK_MK_NK
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
// template <typename Tuple>
// class TestGemmSplitK_KM_KN : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Col, Row>,
// Tuple>::type> {};
template <typename Tuple>
class TestGemmSplitK_KM_KN
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
// template <typename Tuple>
// class TestGemmSplitK_KM_NK : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Col, Col>,
// Tuple>::type> {};
template <typename Tuple>
class TestGemmSplitK_KM_NK
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
......@@ -55,8 +59,8 @@ using KernelTypes = ::testing::Types<
// clang-format on
TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
#include "test_gemm_splitk_ut_cases.inc"
......@@ -13,3 +13,43 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512;
int K = 320;
int StrideA = K;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512;
int K = 320;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512;
int K = 320;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment