convolution.hpp 7.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
24
25
26
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP

27
#include <migraphx/op/common.hpp>
28
29
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
kahmed10's avatar
kahmed10 committed
30
#include <migraphx/value.hpp>
31
32
33
34
35
36
37
38
39
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct convolution
{
40
41
42
    std::vector<std::size_t> padding  = {0, 0};
    std::vector<std::size_t> stride   = {1, 1};
    std::vector<std::size_t> dilation = {1, 1};
43

44
45
    int group                   = 1;
    padding_mode_t padding_mode = default_;
46
47
48
49
50
51
52

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
53
                    f(self.group, "group"),
54
                    f(self.padding_mode, "padding_mode"));
55
56
57
    }

    std::string name() const { return "convolution"; }
kahmed10's avatar
kahmed10 committed
58
59

    void check_attribute_size() const
60
    {
61
62
        if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
           stride.size() != dilation.size())
63
        {
Shucai Xiao's avatar
Shucai Xiao committed
64
            MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
65
        }
kahmed10's avatar
kahmed10 committed
66
67
    }

kahmed10's avatar
kahmed10 committed
68
69
70
    value attributes() const { return {{"normalize_padding", "padding"}}; }

    shape normalize_compute_shape(std::vector<shape> inputs) const
kahmed10's avatar
kahmed10 committed
71
    {
72
        check_shapes{inputs, *this, true}.has(2).same_type().same_ndims().min_ndims(3);
kahmed10's avatar
kahmed10 committed
73
        check_attribute_size();
74
75
76
        // num of dims of input and attribute should match
        const auto input_size   = inputs[0].max_lens().size();
        const auto padding_size = padding.size();
77
78

        if(input_size != padding_size / 2 + 2 && input_size != padding_size + 2)
Shucai Xiao's avatar
Shucai Xiao committed
79
80
81
        {
            MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
        }
82

83
84
85
86
        const shape& x_shape          = inputs.at(0);
        const shape& w_shape          = inputs.at(1);
        const size_t num_spatial_dims = input_size - 2;
        if(num_spatial_dims != this->kdims())
kahmed10's avatar
kahmed10 committed
87
        {
88
            MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
kahmed10's avatar
kahmed10 committed
89
        }
Khalique's avatar
Khalique committed
90

91
92
93
        if(not x_shape.dynamic() and not w_shape.dynamic() and
           x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
            MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
94

95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
        if(x_shape.dynamic() or w_shape.dynamic())
        {
            return dynamic_compute_shape(x_shape, w_shape);
        }
        else
        {
            return fixed_compute_shape(x_shape, w_shape);
        }
    }

    std::vector<std::size_t> calc_conv_lens(std::vector<std::size_t> x_lens,
                                            std::vector<std::size_t> w_lens) const
    {
        const size_t num_spatial_dims = x_lens.size() - 2;
        std::vector<size_t> ret       = {};
        // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
        for(size_t i = 0; i < num_spatial_dims; i++)
        {
            if(x_lens[i] == 0 or w_lens[i] == 0)
            {
                // for handling when a dimension = 0 (opt of dynamic_dimension)
                ret.push_back(0);
            }
            else
            {
                auto padding_factor = 2 * padding[i];
                if(padding.size() == 2 * num_spatial_dims)
                {
                    // when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
                    padding_factor = padding[i] + padding[i + num_spatial_dims];
                }
                ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
                    1,
                    (x_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) /
                            stride[i] +
                        1)));
            }
        }
        return ret;
    }
135

136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
    shape dynamic_compute_shape(shape x_shape, shape w_shape) const
    {
        std::vector<shape::dynamic_dimension> output_dyn_dims = {};

        auto dynamic_shape_push_back = [&](const shape& input_shape) {
            if(input_shape.dynamic())
            {
                output_dyn_dims.push_back(input_shape.dyn_dims().at(0));
            }
            else
            {
                auto l = input_shape.lens().at(0);
                output_dyn_dims.push_back({l, l, 0});
            }
        };

        dynamic_shape_push_back(x_shape);
        dynamic_shape_push_back(w_shape);

        const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
156
        if(padding_mode != default_)
157
        {
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
            for(std::size_t i = 0; i < num_spatial_dims; ++i)
            {
                auto ceil_div = [](std::size_t x, std::size_t y) { return (x + y - 1) / y; };
                auto s        = stride[i];
                if(x_shape.dynamic())
                {
                    auto x = x_shape.dyn_dims()[i + 2];
                    output_dyn_dims.push_back(shape::dynamic_dimension{
                        ceil_div(x.min, s), ceil_div(x.max, s), ceil_div(x.opt, s)});
                }
                else
                {
                    auto od = ceil_div(x_shape.lens()[i + 2], s);
                    output_dyn_dims.push_back(shape::dynamic_dimension{od, od, 0});
                }
            }
174
        }
175
176
177
178
179
180
181
182
183
184
185
186
187
        else
        {
            auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
            auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
            auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
            for(size_t i = 0; i < num_spatial_dims; ++i)
            {
                output_dyn_dims.push_back(shape::dynamic_dimension{
                    min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
            }
        }
        return shape{x_shape.type(), output_dyn_dims};
    }
188

189
190
191
192
193
194
195
196
    shape fixed_compute_shape(shape x_shape, shape w_shape) const
    {
        std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]};
        auto spatial_lens = calc_conv_lens(x_shape.lens(), w_shape.lens());
        std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
            output_lens.push_back(x);
        });
        return x_shape.with_lens(output_lens);
197
    }
kahmed10's avatar
kahmed10 committed
198
199
200
201

    size_t kdims() const
    {
        check_attribute_size();
kahmed10's avatar
kahmed10 committed
202
        return stride.size();
kahmed10's avatar
kahmed10 committed
203
    }
204
205
206
207
208
209
210
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif