// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <string>
#include <vector>
#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"

namespace {

/// Test fixture owning the ConvParam instance that the ConvParams* tests inspect.
///
/// Until SetNDParams() is called, conv_params holds whatever ConvParam's default
/// constructor produces.
class TestConvUtil : public ::testing::Test
{
    public:
    /// Rebuilds conv_params for an `ndims`-dimensional convolution in which every
    /// spatial dimension uses the same settings.
    ///
    /// @param ndims number of spatial dimensions (the tests use 1, 2 and 3)
    /// @param s     stride applied to every spatial dimension
    /// @param d     dilation applied to every spatial dimension
    /// @param p     padding applied to both the left and the right side of every
    ///              spatial dimension
    void SetNDParams(std::size_t ndims, std::size_t s, std::size_t d, std::size_t p)
    {
        // Builds a vector holding `value` once per spatial dimension. The cast makes the
        // std::size_t -> ck::long_index_t conversion explicit instead of implicit.
        const auto per_dim = [ndims](std::size_t value) {
            return std::vector<ck::long_index_t>(ndims, static_cast<ck::long_index_t>(value));
        };

        conv_params = ck::utils::conv::ConvParam(ndims,
                                                 2,
                                                 128,
                                                 192,
                                                 256,
                                                 per_dim(3),  // filter spatial lengths
                                                 per_dim(71), // input spatial lengths
                                                 per_dim(s),
                                                 per_dim(d),
                                                 per_dim(p),
                                                 per_dim(p));
    }

    protected:
    // Convolution problem description, rebuilt by SetNDParams().
    // For ndims == 2 the original notes describe it as:
    //   input     GNCHW {2, 128, 192, 71, 71}
    //   weights   GKCYX {2, 256, 192, 3, 3}
    //   strides   {s, s}
    //   dilations {d, d}
    //   padding   {{p, p}, {p, p}}
    // NOTE(review): SetNDParams() passes the channel counts in the order (192, 256).
    // Whether that yields C = 192 / K = 256 as listed above depends on the parameter
    // order of the ConvParam constructor, which is not visible here - confirm.
    ck::utils::conv::ConvParam conv_params;
};

} // namespace

47
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
Chao Liu's avatar
Chao Liu committed
48
{
49
50
    // stride 2, dilation 1, pad 1
    SetNDParams(1, 2, 1, 1);
51
    std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
52
    EXPECT_TRUE(ck::utils::check_err(
53
        out_spatial_len, std::vector<ck::long_index_t>{36}, "Error: ConvParams 1D."));
Chao Liu's avatar
Chao Liu committed
54

55
56
57
    // stride 1, dilation 1, pad 1
    SetNDParams(1, 1, 1, 1);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
Chao Liu's avatar
Chao Liu committed
58
    EXPECT_TRUE(ck::utils::check_err(
59
        out_spatial_len, std::vector<ck::long_index_t>{71}, "Error: ConvParams 1D stride {1}."));
Chao Liu's avatar
Chao Liu committed
60

61
62
63
    // stride 2, dilation 1, pad 2
    SetNDParams(1, 2, 1, 2);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
Chao Liu's avatar
Chao Liu committed
64
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
65
                                     std::vector<ck::long_index_t>{37},
66
                                     "Error: ConvParams 1D padding left/right {2}."));
Chao Liu's avatar
Chao Liu committed
67

68
69
70
    // stride 2, dilation 2, pad 2
    SetNDParams(1, 2, 2, 2);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
Chao Liu's avatar
Chao Liu committed
71
    EXPECT_TRUE(ck::utils::check_err(
72
        out_spatial_len, std::vector<ck::long_index_t>{36}, "Error: ConvParams 1D dilation {2}."));
Chao Liu's avatar
Chao Liu committed
73

74
75
76
    // stride 3, dilation 2, pad 1
    SetNDParams(1, 3, 2, 1);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
Chao Liu's avatar
Chao Liu committed
77
78
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len,
79
                             std::vector<ck::long_index_t>{23},
80
                             "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."));
Chao Liu's avatar
Chao Liu committed
81
82
}

83
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
Chao Liu's avatar
Chao Liu committed
84
{
85
86
    // stride 2, dilation 1, pad 1
    SetNDParams(2, 2, 1, 1);
87
    std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
88
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
89
                                     std::vector<ck::long_index_t>{36, 36},
90
                                     "Error: ConvParams 2D default constructor."));
Chao Liu's avatar
Chao Liu committed
91

92
93
94
    // stride 1, dilation 1, pad 1
    SetNDParams(2, 1, 1, 1);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
95
96
97
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     std::vector<ck::long_index_t>{71, 71},
                                     "Error: ConvParams 2D stride {1,1}."));
Chao Liu's avatar
Chao Liu committed
98

99
100
101
    // stride 2, dilation 1, pad 2
    SetNDParams(2, 2, 1, 2);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
Chao Liu's avatar
Chao Liu committed
102
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
103
                                     std::vector<ck::long_index_t>{37, 37},
104
                                     "Error: ConvParams 2D padding left/right {2,2}."));
Chao Liu's avatar
Chao Liu committed
105

106
107
108
    // stride 2, dilation 2, pad 2
    SetNDParams(2, 2, 2, 2);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
109
110
111
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     std::vector<ck::long_index_t>{36, 36},
                                     "Error: ConvParams 2D dilation {2,2}."));
Chao Liu's avatar
Chao Liu committed
112

113
114
115
    // stride 3, dilation 2, pad 1
    SetNDParams(2, 3, 2, 1);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
Chao Liu's avatar
Chao Liu committed
116
117
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len,
118
                             std::vector<ck::long_index_t>{23, 23},
119
                             "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."));
Chao Liu's avatar
Chao Liu committed
120
121
122
123
}

// Checks ConvParam::GetOutputSpatialLengths() for 3D problems over several
// stride / dilation / padding combinations set up through SetNDParams().
// Every spatial dimension is configured identically, so all three output lengths match.
// NOTE(review): with input length 71 and filter length 3, the expected values are
// consistent with floor((71 + 2*pad - dilation*(3 - 1) - 1) / stride) + 1; the
// implementation is not visible here - confirm against ConvParam.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
{
    // stride 2, dilation 1, pad 1
    SetNDParams(3, 2, 1, 1);
    std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::long_index_t>{36, 36, 36}, "Error: ConvParams 3D."));

    // stride 1, dilation 1, pad 1
    SetNDParams(3, 1, 1, 1);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     std::vector<ck::long_index_t>{71, 71, 71},
                                     "Error: ConvParams 3D stride {1, 1, 1}."));

    // stride 2, dilation 1, pad 2
    SetNDParams(3, 2, 1, 2);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     std::vector<ck::long_index_t>{37, 37, 37},
                                     "Error: ConvParams 3D padding left/right {2, 2, 2}."));

    // stride 2, dilation 2, pad 2
    SetNDParams(3, 2, 2, 2);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     std::vector<ck::long_index_t>{36, 36, 36},
                                     "Error: ConvParams 3D dilation {2, 2, 2}."));

    // stride 3, dilation 2, pad 1
    SetNDParams(3, 3, 2, 1);
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len,
        std::vector<ck::long_index_t>{23, 23, 23},
        "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."));
}