// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "include/ck/utility/data_type.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" namespace ck { namespace test { template std::string serialize_range(const Range& range) { std::stringstream ss; for(auto& r : range) { ss << r << ", "; } std::string str = ss.str(); return std::string(str.begin(), str.end() - 2); } template class TestGroupedGemm : public testing::TestWithParam { protected: using ALayout = std::tuple_element_t<0, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>; using ELayout = std::tuple_element_t<2, Tuple>; using ADataType = std::tuple_element_t<3, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>; using EDataType = std::tuple_element_t<5, Tuple>; public: bool verify_ = true; int init_method_ = 0; // decimal value initialization bool log_ = false; bool bench_ = false; // measure kernel performance void SetUp() override {} void Run(const std::vector& Ms, const std::vector& Ns, const std::vector& Ks, const std::vector& StrideAs, const std::vector& StrideBs, const std::vector& StrideCs, int kbatch = 1) { bool pass = ck::profiler::profile_grouped_gemm_impl( verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch); EXPECT_TRUE(pass); } }; } // namespace test } // namespace ck