#include #include #include #include "gtest/gtest.h" #include "data_type.hpp" #include "element_wise_operation.hpp" #include "library/include/ck/library/utility/conv_util.hpp" #include "conv_util.hpp" namespace { class Conv1dFwdNWCInstances : public ::testing::Test { public: template bool test_conv1d_nwc_instances(const std::vector& conv_ptrs, const ck::utils::conv::ConvParams& params) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; conv::ConvFwdOpInstance, FillUniformDistributionIntegerValue> conv_instance(params, true, FillUniformDistributionIntegerValue{}, FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); run_engine.SetAtol(atol_); run_engine.SetRtol(rtol_); return run_engine.Test(conv_ptrs); } template bool test_default() { return test_conv1d_nwc_instances( ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), params_default_); } template bool test_filter1x1_stride1_pad0() { return test_conv1d_nwc_instances( ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), params_filter1x1_stride1_pad0_); } template bool test_filter1x1_pad0() { return test_conv1d_nwc_instances( ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), params_filter1x1_pad0_); } static inline ck::utils::conv::ConvParams params_default_{ 1, 4, 256, 64, {3}, {71}, {2}, {2}, {2}, {2}}; static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ 1, 4, 256, 64, {1}, {28}, {1}, {1}, {0}, {0}}; static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ 1, 4, 256, 64, {1}, {28}, {2}, {1}, {0}, {0}}; private: double atol_{1e-5}; double rtol_{1e-4}; }; } // anonymous namespace TEST(Conv1DFwdNWC, IntegerValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; using T = 
float; ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; std::vector conv_ptrs; test::conv::get_test_convolution_fwd_instance<1, T, T, T, T>(conv_ptrs); conv::ConvFwdOpInstance, FillUniformDistributionIntegerValue> conv_instance(params, true, FillUniformDistributionIntegerValue{}, FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); run_engine.SetAtol(1e-5); run_engine.SetRtol(1e-4); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } TEST(Conv1DFwdNWC, FloatingPointValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; using T = ck::half_t; ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; std::vector conv_ptrs; test::conv::get_test_convolution_fwd_instance<1, T, T, T, float>(conv_ptrs); conv::ConvFwdOpInstance, FillUniformDistribution> conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); run_engine.SetAtol(0.1); run_engine.SetRtol(1e-2); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } TEST_F(Conv1dFwdNWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_stride1_pad0) { EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); } TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_pad0) { EXPECT_TRUE(this->test_filter1x1_pad0()); } TEST_F(Conv1dFwdNWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_stride1_pad0) { EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); } TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_pad0) { EXPECT_TRUE(this->test_filter1x1_pad0()); } 
// Library-instance coverage for 32-bit float and 8-bit integer element types.
// The fixture helpers are templates whose element type appears in no function
// parameter, so it must be passed explicitly; it is taken from each test
// name's prefix (F32 -> float, I8 -> int8_t).
TEST_F(Conv1dFwdNWCInstances, F32_default)
{
    EXPECT_TRUE(this->test_default<float>());
}
TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>());
}
TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<float>());
}

TEST_F(Conv1dFwdNWCInstances, I8_default)
{
    EXPECT_TRUE(this->test_default<int8_t>());
}
TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>());
}
TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>());
}