"vscode:/vscode.git/clone" did not exist on "6db9f0282e2ab12795628de6200670892a8ad6ba"
conv1d_fwd.cpp 3.43 KB
Newer Older
1
2
3
4
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <vector>
5
#include "gtest/gtest.h"
6
7
8

#include "data_type.hpp"
#include "element_wise_operation.hpp"
Adam Osewski's avatar
Adam Osewski committed
9
#include "library/include/ck/library/utility/conv_util.hpp"
10
#include "conv_util.hpp"
11
12
13

namespace {

14
15
template <typename T>
bool test_conv1d_nwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
16
{
17
18
19
20
    using namespace std::placeholders;
    using namespace ck::utils;
    namespace ctl = ck::tensor_layout::convolution;

21
    ck::utils::conv::ConvParams params;
Adam Osewski's avatar
Adam Osewski committed
22
23
24
25
26
27
28
    params.num_dim_spatial_        = 1;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{71};
    params.conv_filter_strides_    = std::vector<ck::index_t>{2};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1};
    params.input_left_pads_        = std::vector<ck::index_t>{1};
    params.input_right_pads_       = std::vector<ck::index_t>{1};
29

30
    conv::ConvFwdOpInstance<T, T, T, ctl::NWC, ctl::KCX, ctl::NWK> conv_instance(params);
31

32
33
34
    auto reference_conv_fwd_fun =
        std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3);
    OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
35
    return run_engine.Test(conv_ptrs);
36
37
}

38
39
40
} // anonymous namespace

// fp32 check of the registered 1D forward test instances on a small problem:
// N=2, K=16, C=4, input width 16, filter width 3, stride 1, dilation 1, one
// element of padding on each side (NWC / KCX / NWK layouts). Results are
// compared against the host reference with atol 1e-5 and rtol 1e-4.
TEST(Conv1DFwdNWC, TestConv1D)
{
    using namespace std::placeholders;
    using namespace ck::utils;
    namespace ctl = ck::tensor_layout::convolution;

    using F32Conv1dInstance =
        conv::ConvFwdOpInstance<float, float, float, ctl::NWC, ctl::KCX, ctl::NWK>;

    conv::ConvParams problem;
    problem.num_dim_spatial_        = 1;
    problem.N_                      = 2;
    problem.K_                      = 16;
    problem.C_                      = 4;
    problem.input_spatial_lengths_  = std::vector<ck::index_t>{16};
    problem.filter_spatial_lengths_ = std::vector<ck::index_t>{3};
    problem.conv_filter_strides_    = std::vector<ck::index_t>{1};
    problem.conv_filter_dilations_  = std::vector<ck::index_t>{1};
    problem.input_left_pads_        = std::vector<ck::index_t>{1};
    problem.input_right_pads_       = std::vector<ck::index_t>{1};

    // Collect the device instances registered for 1 spatial dimension.
    std::vector<test::conv::DeviceConvFwdNoOpPtr> device_instances;
    test::conv::get_test_convolution_fwd_instance<1>(device_instances);

    F32Conv1dInstance conv_instance(problem);

    auto host_reference = std::bind(
        conv::run_reference_convolution_forward<1, float, float, float>, problem, _1, _2, _3);
    OpInstanceRunEngine<float, float, float> engine(conv_instance, host_reference);
    engine.SetAtol(1e-5);
    engine.SetRtol(1e-4);

    const bool all_passed = engine.Test(device_instances);
    EXPECT_TRUE(all_passed);
}

71
// bf16: every ConvolutionFwdInstances entry for 1 spatial dimension must match
// the host reference on the problem used by test_conv1d_nwc_instances.
// NOTE(review): "Iinstances" is a typo; it is kept so existing
// --gtest_filter patterns keep matching.
TEST(Conv1DFwdNWC, Bf16Iinstances)
{
    using ElemT = ck::bhalf_t;
    const auto& instances = ck::utils::conv::ConvolutionFwdInstances<ElemT, ElemT, ElemT>::Get<1>();
    EXPECT_TRUE(test_conv1d_nwc_instances<ElemT>(instances));
}

77
// fp16: every ConvolutionFwdInstances entry for 1 spatial dimension must match
// the host reference on the problem used by test_conv1d_nwc_instances.
TEST(Conv1DFwdNWC, F16Instances)
{
    using ElemT = ck::half_t;
    const auto& instances = ck::utils::conv::ConvolutionFwdInstances<ElemT, ElemT, ElemT>::Get<1>();
    EXPECT_TRUE(test_conv1d_nwc_instances<ElemT>(instances));
}

83
// fp32: every ConvolutionFwdInstances entry for 1 spatial dimension must match
// the host reference on the problem used by test_conv1d_nwc_instances.
TEST(Conv1DFwdNWC, F32Instances)
{
    using ElemT = float;
    const auto& instances = ck::utils::conv::ConvolutionFwdInstances<ElemT, ElemT, ElemT>::Get<1>();
    EXPECT_TRUE(test_conv1d_nwc_instances<ElemT>(instances));
}

89
// int8: every ConvolutionFwdInstances entry for 1 spatial dimension must match
// the host reference on the problem used by test_conv1d_nwc_instances.
TEST(Conv1DFwdNWC, Int8Instances)
{
    using ElemT = int8_t;
    const auto& instances = ck::utils::conv::ConvolutionFwdInstances<ElemT, ElemT, ElemT>::Get<1>();
    EXPECT_TRUE(test_conv1d_nwc_instances<ElemT>(instances));
}