conv_util.cpp 6.31 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <string>
#include <vector>
#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

#include "ck/library/utility/check_err.hpp"
13
#include "ck/library/utility/convolution_parameter.hpp"
Chao Liu's avatar
Chao Liu committed
14
15
16
17
18
19

namespace {

// Test fixture holding a single ConvParam instance that each test case
// rebuilds through SetNDParams() before querying it.
class TestConvUtil : public ::testing::Test
{
    public:
    // Re-creates conv_params for a problem with `ndims` spatial dimensions.
    //
    // The four scalar constructor arguments that follow `ndims` are fixed at
    // 2, 128, 192 and 256. Every vector argument holds `ndims` identical
    // entries, passed in this order: 3, 71, s, d, p, p.
    // The tests in this file treat `s` as the stride, `d` as the dilation and
    // `p` as the padding applied on both sides; 3 and 71 are presumably the
    // filter and input spatial lengths.
    // NOTE(review): which of 192 / 256 is the input vs. output channel count
    // depends on ConvParam's constructor parameter order -- confirm against
    // convolution_parameter.hpp.
    void SetNDParams(std::size_t ndims, std::size_t s, std::size_t d, std::size_t p)
    {
        // Yields a temporary vector containing `ndims` copies of `value`.
        const auto repeated = [ndims](std::size_t value) {
            return std::vector<ck::index_t>(ndims, value);
        };

        conv_params = ck::utils::conv::ConvParam(ndims,
                                                 2,
                                                 128,
                                                 192,
                                                 256,
                                                 repeated(3),
                                                 repeated(71),
                                                 repeated(s),
                                                 repeated(d),
                                                 repeated(p),
                                                 repeated(p));
    }

    protected:
    // Parameters under test; default-constructed until SetNDParams() is called.
    ck::utils::conv::ConvParam conv_params;
};

} // namespace

47
// Checks GetOutputSpatialLengths() for 1D problems over several
// stride / dilation / padding combinations.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
{
    // One row per scenario: the SetNDParams() arguments, the expected output
    // lengths, and the message handed to check_err.
    struct Scenario
    {
        std::size_t stride;
        std::size_t dilation;
        std::size_t pad;
        std::vector<ck::index_t> expected;
        const char* message;
    };

    const Scenario scenarios[] = {
        {2, 1, 1, {36}, "Error: ConvParams 1D."},
        {1, 1, 1, {71}, "Error: ConvParams 1D stride {1}."},
        {2, 1, 2, {37}, "Error: ConvParams 1D padding left/right {2}."},
        {2, 2, 2, {36}, "Error: ConvParams 1D dilation {2}."},
        {3, 2, 1, {23}, "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."},
    };

    for(const auto& scenario : scenarios)
    {
        SetNDParams(1, scenario.stride, scenario.dilation, scenario.pad);
        const std::vector<ck::index_t> actual = conv_params.GetOutputSpatialLengths();
        EXPECT_TRUE(ck::utils::check_err(actual, scenario.expected, scenario.message));
    }
}

83
// Checks GetOutputSpatialLengths() for 2D problems over several
// stride / dilation / padding combinations.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
{
    // One row per scenario: the SetNDParams() arguments, the expected output
    // lengths, and the message handed to check_err.
    struct Scenario
    {
        std::size_t stride;
        std::size_t dilation;
        std::size_t pad;
        std::vector<ck::index_t> expected;
        const char* message;
    };

    const Scenario scenarios[] = {
        // The first message used to read "2D default constructor", although the
        // parameters are set explicitly here; it now matches the 1D/3D wording.
        {2, 1, 1, {36, 36}, "Error: ConvParams 2D."},
        {1, 1, 1, {71, 71}, "Error: ConvParams 2D stride {1,1}."},
        {2, 1, 2, {37, 37}, "Error: ConvParams 2D padding left/right {2,2}."},
        {2, 2, 2, {36, 36}, "Error: ConvParams 2D dilation {2,2}."},
        {3,
         2,
         1,
         {23, 23},
         "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."},
    };

    for(const auto& scenario : scenarios)
    {
        SetNDParams(2, scenario.stride, scenario.dilation, scenario.pad);
        const std::vector<ck::index_t> actual = conv_params.GetOutputSpatialLengths();
        EXPECT_TRUE(ck::utils::check_err(actual, scenario.expected, scenario.message));
    }
}

// Checks GetOutputSpatialLengths() for 3D problems over several
// stride / dilation / padding combinations.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
{
    // One row per scenario: the SetNDParams() arguments, the expected output
    // lengths, and the message handed to check_err.
    struct Scenario
    {
        std::size_t stride;
        std::size_t dilation;
        std::size_t pad;
        std::vector<ck::index_t> expected;
        const char* message;
    };

    const Scenario scenarios[] = {
        {2, 1, 1, {36, 36, 36}, "Error: ConvParams 3D."},
        {1, 1, 1, {71, 71, 71}, "Error: ConvParams 3D stride {1, 1, 1}."},
        {2, 1, 2, {37, 37, 37}, "Error: ConvParams 3D padding left/right {2, 2, 2}."},
        {2, 2, 2, {36, 36, 36}, "Error: ConvParams 3D dilation {2, 2, 2}."},
        {3,
         2,
         1,
         {23, 23, 23},
         "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."},
    };

    for(const auto& scenario : scenarios)
    {
        SetNDParams(3, scenario.stride, scenario.dilation, scenario.pad);
        const std::vector<ck::index_t> actual = conv_params.GetOutputSpatialLengths();
        EXPECT_TRUE(ck::utils::check_err(actual, scenario.expected, scenario.message));
    }
}