"docs/source/en/vscode:/vscode.git/clone" did not exist on "ad06e5106e704bed612f587e86f4f58d711e181f"
Commit 12cd4c72 authored by Adam Osewski's avatar Adam Osewski
Browse files

Add more test-cases for different data layout.

parent a01ca8e6
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// #include <algorithm> #include <tuple>
// #include <stdexcept>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...@@ -34,17 +32,23 @@ class TestGemmSplitK_MK_KN ...@@ -34,17 +32,23 @@ class TestGemmSplitK_MK_KN
{ {
}; };
// template <typename Tuple> template <typename Tuple>
// class TestGemmSplitK_MK_NK : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Row, Col>, class TestGemmSplitK_MK_NK
// Tuple>::type> {}; : public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
// template <typename Tuple> template <typename Tuple>
// class TestGemmSplitK_KM_KN : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Col, Row>, class TestGemmSplitK_KM_KN
// Tuple>::type> {}; : public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
// template <typename Tuple> template <typename Tuple>
// class TestGemmSplitK_KM_NK : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Col, Col>, class TestGemmSplitK_KM_NK
// Tuple>::type> {}; : public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
...@@ -55,8 +59,8 @@ using KernelTypes = ::testing::Types< ...@@ -55,8 +59,8 @@ using KernelTypes = ::testing::Types<
// clang-format on // clang-format on
TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes); TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes); TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes); TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes); TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
#include "test_gemm_splitk_ut_cases.inc" #include "test_gemm_splitk_ut_cases.inc"
...@@ -13,3 +13,43 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM) ...@@ -13,3 +13,43 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC); this->Run(M, N, K, StrideA, StrideB, StrideC);
} }
TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512;
int K = 320;
int StrideA = K;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512;
int K = 320;
int StrideB = N;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
int N = 512;
int K = 320;
int StrideB = K;
int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment