// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/library/utility/host_common_util.hpp" #include "profiler/profile_reduce_impl.hpp" #include using namespace ck; struct ReduceParam { bool do_verification{true}; bool propagateNan{false}; bool useIndex{false}; bool time_kernel{false}; bool do_dumpout{false}; int init_method{2}; float alpha{1.0f}; float beta{0.0f}; std::vector inLengths{64, 4, 280, 82}; std::vector reduceDims{0, 1, 2, 3}; }; std::vector> SetGenericReduceDim() { return {{0, 1, 2, 3}, {0, 1, 2}, {0, 1, 3}, {0, 2, 3}, {1, 2, 3}, {0}, {1}, {2}, {3}}; } template class ReduceWithIndexTest : public ::testing::Test { protected: using InDataType = std::tuple_element_t<0, T>; using AccDataType = std::tuple_element_t<1, T>; using OutDataType = std::tuple_element_t<2, T>; static std::vector params; static void SetUpTestSuite() { // set testcase variables ReduceParam set; const auto setReduceDim = SetGenericReduceDim(); for(std::size_t i(0); i < setReduceDim.size(); ++i) { set.reduceDims = setReduceDim[i]; params.emplace_back(set); } } template void Run() { for(auto param : this->params) { bool success = ck::profiler::profile_reduce_impl( param.do_verification, param.init_method, param.do_dumpout, param.time_kernel, param.inLengths, param.reduceDims, ReduceOpIdType, param.propagateNan, param.useIndex, param.alpha, param.beta); EXPECT_TRUE(success); } } }; template std::vector ReduceWithIndexTest::params = {}; using Reduce_float_types = ::testing::Types>; using Reduce_double_types = ::testing::Types>; using Reduce_int8t_types = ::testing::Types>; using Reduce_half_types = ::testing::Types>; using Reduce_bhalf_float_Types = ::testing::Types>; template class ReduceWithIndexFloat : public ReduceWithIndexTest { }; template class ReduceWithIndexDouble : public ReduceWithIndexTest { }; template class ReduceWithIndexInt8 : public ReduceWithIndexTest { }; template class ReduceWithIndexHalf : public ReduceWithIndexTest { }; template class ReduceWithIndexBHalfFloat : public ReduceWithIndexTest { }; TYPED_TEST_SUITE(ReduceWithIndexFloat, Reduce_float_types); TYPED_TEST_SUITE(ReduceWithIndexDouble, Reduce_double_types); TYPED_TEST_SUITE(ReduceWithIndexInt8, Reduce_int8t_types); TYPED_TEST_SUITE(ReduceWithIndexHalf, Reduce_half_types); TYPED_TEST_SUITE(ReduceWithIndexBHalfFloat, Reduce_bhalf_float_Types); TYPED_TEST(ReduceWithIndexFloat, ReduceWithIndexTestFloat_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexFloat, ReduceWithIndexTestFloat_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexFloat, ReduceWithIndexTestFloat_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexDouble, ReduceWithIndexTestDouble_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexDouble, ReduceWithIndexTestDouble_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexDouble, ReduceWithIndexTestDouble_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexInt8, ReduceWithIndexTestInt8_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexInt8, ReduceWithIndexTestInt8_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexInt8, ReduceWithIndexTestInt8_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexHalf, ReduceWithIndexTestHalf_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexHalf, ReduceWithIndexTestHalf_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexHalf, ReduceWithIndexTestHalf_MAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexBHalfFloat, ReduceWithIndexTesBtHalfFloat_AMAX) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexBHalfFloat, ReduceWithIndexTestBHalfFloat_MIN) { // trigger Run() -> Generic this->template Run(); } TYPED_TEST(ReduceWithIndexBHalfFloat, ReduceWithIndexTestBHalfFloat_MAX) { // trigger Run() -> Generic this->template Run(); }