// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include #include #include #include #include #include #include "profiler/include/profile_conv_bwd_data_impl.hpp" template class TestConvndBwdData : public ::testing::Test { protected: using DataType = std::tuple_element_t<0, Tuple>; std::vector conv_params; template void Run() { for(auto& param : conv_params) { bool pass; EXPECT_FALSE(conv_params.empty()); pass = ck::profiler::profile_conv_bwd_data_impl< NDimSpatial, ck::tuple_element_t>, ck::tuple_element_t>, ck::tuple_element_t>, DataType, DataType, DataType>(true, // do_verification 1, // init_method integer value false, // do_log false, // time_kernel param); EXPECT_TRUE(pass); } } }; using KernelTypes = ::testing::Types, std::tuple, std::tuple, std::tuple>; TYPED_TEST_SUITE(TestConvndBwdData, KernelTypes); // 1d TYPED_TEST(TestConvndBwdData, Conv1dBwdData) { this->conv_params.clear(); this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->template Run<1>(); } // 2d TYPED_TEST(TestConvndBwdData, Conv2dBwdData) { this->conv_params.clear(); this->conv_params.push_back( {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->template Run<2>(); } // 3d TYPED_TEST(TestConvndBwdData, Conv3dBwdData) { this->conv_params.clear(); this->conv_params.push_back( {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->template Run<3>(); }