// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/library/utility/host_common_util.hpp" #include "profiler/profile_reduce_impl.hpp" #include using namespace ck; struct ReduceParam { bool do_verification{true}; bool propagateNan{false}; bool useIndex{false}; bool time_kernel{false}; bool do_dumpout{false}; int init_method{2}; float alpha{1.0f}; float beta{0.0f}; std::vector inLengths{64, 4, 280, 82}; std::vector reduceDims{0, 1, 2, 3}; }; std::vector> SetGenericReduceDim() { return {{0, 1, 2, 3}, {0, 1, 2}, {0, 1, 3}, {0, 2, 3}, {1, 2, 3}, {0}, {1}, {2}, {3}}; } template class ReduceWithIndexTest : public ::testing::Test { protected: using InDataType = std::tuple_element_t<0, T>; using AccDataType = std::tuple_element_t<1, T>; using OutDataType = std::tuple_element_t<2, T>; static std::vector params; static void SetUpTestSuite() { // set testcase variables ReduceParam set; const auto setReduceDim = SetGenericReduceDim(); for(std::size_t i(0); i < setReduceDim.size(); ++i) { set.reduceDims = setReduceDim[i]; params.emplace_back(set); } } template void Run() { for(auto param : this->params) { bool success = ck::profiler::profile_reduce_impl( param.do_verification, param.init_method, param.do_dumpout, param.time_kernel, param.inLengths, param.reduceDims, ReduceOpIdType, param.propagateNan, param.useIndex, param.alpha, param.beta); EXPECT_TRUE(success); } } }; template std::vector ReduceWithIndexTest::params = {}; using Reduce_float_types = ::testing::Types>; using Reduce_double_types = ::testing::Types>; using Reduce_int8t_types = ::testing::Types>; using Reduce_half_types = ::testing::Types>; using Reduce_bhalf_float_Types = ::testing::Types>; template class ReduceWithNoIndexFloat : public ReduceWithIndexTest { }; template class ReduceWithNoIndexDouble : public ReduceWithIndexTest { }; template class ReduceWithNoIndexInt8 : public ReduceWithIndexTest { }; template class ReduceWithNoIndexHalf : public ReduceWithIndexTest { }; template class ReduceWithNoIndexBHalfFloat : public ReduceWithIndexTest { }; TYPED_TEST_SUITE(ReduceWithNoIndexFloat, Reduce_float_types); TYPED_TEST_SUITE(ReduceWithNoIndexDouble, Reduce_double_types); TYPED_TEST_SUITE(ReduceWithNoIndexInt8, Reduce_int8t_types); TYPED_TEST_SUITE(ReduceWithNoIndexHalf, Reduce_half_types); TYPED_TEST_SUITE(ReduceWithNoIndexBHalfFloat, Reduce_bhalf_float_Types); TYPED_TEST(ReduceWithNoIndexFloat, ReduceWithNoIndexTestFloat_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexFloat, ReduceWithNoIndexTestFloat_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexFloat, ReduceWithNoIndexTestFloat_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexDouble, ReduceWithNoIndexTestDouble_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexDouble, ReduceWithNoIndexTestDouble_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexDouble, ReduceWithNoIndexTestDouble_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexInt8, ReduceWithNoIndexTestInt8_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexInt8, ReduceWithNoIndexTestInt8_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexInt8, ReduceWithNoIndexTestInt8_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexHalf, ReduceWithNoIndexTestHalf_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexHalf, ReduceWithNoIndexTestHalf_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexHalf, ReduceWithNoIndexTestHalf_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexBHalfFloat, ReduceWithNoIndexTesBtHalfFloat_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexBHalfFloat, ReduceWithNoIndexTestBHalfFloat_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithNoIndexBHalfFloat, ReduceWithNoIndexTestBHalfFloat_MAX) { // trigger Run() -> Generic this->template Run(); }