// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <vector>

#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"

#include "test_grouped_gemm_util.hpp"

class TestGGemmInterface_MKNKMN : public ::testing::Test
{
    protected:
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    using ALayout = Row;
    using BLayout = Col;
    using ELayout = Row;

    static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

    // NOTE: the wrapper's template parameter list was lost in extraction; it is
    // assumed here to be the three layouts plus the GEMM specialization, as
    // declared in test_grouped_gemm_util.hpp.
    template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
    using GGemmInstance =
        ck::test::DeviceGroupedGemmInstanceWrapper<ALayout, BLayout, ELayout, GemmSpec>;

    using DefaultGGemmInstance = GGemmInstance<GemmDefault>;
};

TEST_F(TestGGemmInterface_MKNKMN, TileSize)
{
    std::vector<int> Ms{128, 256, 188, 512};
    constexpr int N = 256;
    constexpr int K = 128;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    // M % MPerBlock != 0
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ms = std::vector<int>{256, 128, 128, 512};
    Ns = std::vector<int>{256, 177, 128, 512};
    // N % NPerBlock != 0
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}

TEST_F(TestGGemmInterface_MKNKMN, VectorLoadWidth)
{
    static constexpr auto GemmMNKPadding =
        ck::tensor_operation::device::GemmSpecialization::MNKPadding;
    using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding>;

    std::vector<int> Ms{128, 256, 256, 512};
    constexpr int N = 256;
    constexpr int K = 512;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    // K % ABlockTransferSrcScalarPerVector != 0
    Ks = std::vector<int>{256, 177, 128, 512};
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ks = std::vector<int>{256, 164, 128, 512};
    // K % BBlockTransferSrcScalarPerVector != 0
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ks = std::vector<int>(4, 128);
    Ns = std::vector<int>{256, 127, 128, 512};
    // N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}

class TestGGemmInterface_KMKNNM : public ::testing::Test
{
    protected:
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    using ALayout = Col;
    using BLayout = Row;
    using ELayout = Col;

    static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

    template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
    using GGemmInstance =
        ck::test::DeviceGroupedGemmInstanceWrapper<ALayout, BLayout, ELayout, GemmSpec>;

    using DefaultGGemmInstance = GGemmInstance<GemmDefault>;
};

TEST_F(TestGGemmInterface_KMKNNM, TileSize)
{
    std::vector<int> Ms{128, 256, 188, 512};
    constexpr int N = 256;
    constexpr int K = 128;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    // M % MPerBlock != 0
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ms = std::vector<int>{128, 256, 256, 512};
    Ns = std::vector<int>{256, 177, 128, 512};
    // N % NPerBlock != 0
    EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}

TEST_F(TestGGemmInterface_KMKNNM, VectorLoadWidth)
{
    static constexpr auto GemmMNKPadding =
        ck::tensor_operation::device::GemmSpecialization::MNKPadding;
    using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding>;

    std::vector<int> Ms{128, 256, 256, 512};
    constexpr int N = 256;
    constexpr int K = 512;

    std::vector<int> Ns(Ms.size(), N);
    std::vector<int> Ks(Ms.size(), K);
    std::vector<int> StrideAs(Ms.size(), K);
    std::vector<int> StrideBs(Ms.size(), K);
    std::vector<int> StrideCs(Ms.size(), N);

    // M % ABlockTransferSrcScalarPerVector != 0 (A is column-major, so M is the
    // contiguous dimension)
    Ms = std::vector<int>{256, 177, 128, 512};
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ms = std::vector<int>{128, 256, 256, 512};
    Ns = std::vector<int>{256, 164, 128, 512};
    // N % BBlockTransferSrcScalarPerVector != 0
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));

    Ns = std::vector<int>{128, 256, 256, 512};
    Ms = std::vector<int>{256, 130, 128, 512};
    // M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0
    EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}
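
// Positive-path sketch (an addition, not part of the original file): the tests
// above only exercise shapes that should be rejected. For contrast, this case
// feeds sizes and strides that are multiples of 256, which should satisfy the
// tile-size and vector-width divisibility checks probed above. Whether
// IsSupported actually returns true still depends on the tile/vector
// configuration baked into DeviceGroupedGemmInstanceWrapper, so the chosen
// sizes here are an assumption, not a guarantee from the source.
TEST_F(TestGGemmInterface_MKNKMN, SupportedSizesSketch)
{
    std::vector<int> Ms(4, 256);
    std::vector<int> Ns(4, 256);
    std::vector<int> Ks(4, 256);
    std::vector<int> StrideAs(Ms.size(), 256);
    std::vector<int> StrideBs(Ms.size(), 256);
    std::vector<int> StrideCs(Ms.size(), 256);

    EXPECT_TRUE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
}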