operators.hpp 8.14 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP

Paul's avatar
Paul committed
4
#include <rtg/operation.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <vector>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
10
namespace rtg {

Paul's avatar
Paul committed
11
12
13
14
struct check_shapes
{
    const std::vector<shape>* shapes;

Paul's avatar
Paul committed
15
    check_shapes(const std::vector<shape>& s) : shapes(&s) {}
Paul's avatar
Paul committed
16
17
18
19
20

    const check_shapes& has(std::size_t n) const
    {
        assert(shapes != nullptr);
        if(shapes->size() != n)
Paul's avatar
Paul committed
21
22
            RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
                      std::to_string(shapes->size()));
Paul's avatar
Paul committed
23
24
25
26
27
28
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
        assert(shapes != nullptr);
Paul's avatar
Paul committed
29
30
        if(!shapes->empty())
        {
Paul's avatar
Paul committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
            if(shapes->front().lens().size() != n)
                RTG_THROW("Only " + std::to_string(n) + "d supported");
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
            RTG_THROW("Shapes do not match");
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
            RTG_THROW("Types do not match");
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
            RTG_THROW("Dimensions do not match");
        return *this;
    }

Paul's avatar
Paul committed
58
    template <class F>
Paul's avatar
Paul committed
59
60
61
62
63
64
    bool same(F f) const
    {
        assert(shapes != nullptr);
        if(shapes->empty())
            return true;
        auto&& key = f(shapes->front());
Paul's avatar
Paul committed
65
        return this->all_of([&](const shape& s) { return f(s) == key; });
Paul's avatar
Paul committed
66
67
    }

Paul's avatar
Paul committed
68
    template <class Predicate>
Paul's avatar
Paul committed
69
70
71
72
73
74
75
    bool all_of(Predicate p) const
    {
        assert(shapes != nullptr);
        return std::all_of(shapes->begin(), shapes->end(), p);
    }
};

Paul's avatar
Paul committed
76
77
struct not_computable
{
Paul's avatar
Paul committed
78
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
79
80
};

Paul's avatar
Paul committed
81
struct convolution
Paul's avatar
Paul committed
82
{
Paul's avatar
Paul committed
83
84
85
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Paul's avatar
Paul committed
86
87
88
89
90
91
92
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };
    padding_mode_t padding_mode = default_;
Paul's avatar
Paul committed
93
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
94
95
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
96
97
        check_shapes{inputs}.has(2).same_type().same_dims().only_dims(4);

Paul's avatar
Paul committed
98
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
99
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
100
        auto t               = input.type();
Paul's avatar
Paul committed
101
102
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
            return {t,
                    {
                        input.lens()[0],
                        weights.lens()[0],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     weights.lens()[0],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
                 weights.lens()[0],
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
144
145
            RTG_THROW("Invalid padding mode");
        }
Paul's avatar
Paul committed
146
    }
Paul's avatar
Paul committed
147

Paul's avatar
Paul committed
148
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
149

Paul's avatar
Paul committed
150
    friend std::ostream& operator<<(std::ostream& os, const convolution& op)
Paul's avatar
Paul committed
151
    {
Paul's avatar
Paul committed
152
153
154
155
156
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "dilation={" << stream_range(op.dilation) << "}";
        os << "]";
Paul's avatar
Paul committed
157
158
        return os;
    }
Paul's avatar
Paul committed
159
160
};

Paul's avatar
Paul committed
161
struct pooling
Paul's avatar
Paul committed
162
163
{
    std::string mode;
Paul's avatar
Paul committed
164
165
166
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Paul's avatar
Paul committed
167
    std::string name() const { return "pooling"; }
Paul's avatar
Paul committed
168
169
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
170
        check_shapes{inputs}.has(1).only_dims(4);
Paul's avatar
Paul committed
171

Paul's avatar
Paul committed
172
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
173
        auto t             = input.type();
Paul's avatar
Paul committed
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
        return {t,
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
                        std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) /
                                  static_cast<float>(stride[0])) +
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
                        std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) /
                                  static_cast<float>(stride[1])) +
                            1)),
                }};
Paul's avatar
Paul committed
189
    }
Paul's avatar
Paul committed
190

Paul's avatar
Paul committed
191
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
192

Paul's avatar
Paul committed
193
    friend std::ostream& operator<<(std::ostream& os, const pooling& op)
Paul's avatar
Paul committed
194
    {
Paul's avatar
Paul committed
195
196
197
198
199
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "lengths={" << stream_range(op.lengths) << "}";
        os << "]";
Paul's avatar
Paul committed
200
201
        return os;
    }
Paul's avatar
Paul committed
202
203
};

Paul's avatar
Paul committed
204
struct activation
Paul's avatar
Paul committed
205
206
{
    std::string mode;
Paul's avatar
Paul committed
207
    std::string name() const { return "activation"; }
Paul's avatar
Paul committed
208
209
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
210
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
211
212
        return inputs.front();
    }
Paul's avatar
Paul committed
213

Paul's avatar
Paul committed
214
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
215
    friend std::ostream& operator<<(std::ostream& os, const activation& op)
Paul's avatar
Paul committed
216
    {
Paul's avatar
Paul committed
217
        os << op.name() << ":" << op.mode;
Paul's avatar
Paul committed
218
219
        return os;
    }
Paul's avatar
Paul committed
220
221
};

Paul's avatar
Paul committed
222
223
224
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
225
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
226
227
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
228
        if(inputs.empty())
Paul's avatar
Paul committed
229
            RTG_THROW("Wrong number of arguments");
Paul's avatar
Paul committed
230
231
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
Paul's avatar
Paul committed
232
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
233
234
235
236
237
238
239
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
        }
        if(dims.back() == -1)
        {
            rdims.pop_back();
Paul's avatar
Paul committed
240
            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
Paul's avatar
Paul committed
241
242
243
244
        }
        return {inputs.front().type(), rdims};
    }

Paul's avatar
Paul committed
245
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
246

Paul's avatar
Paul committed
247
    friend std::ostream& operator<<(std::ostream& os, const reshape& op)
Paul's avatar
Paul committed
248
    {
Paul's avatar
Paul committed
249
250
251
        os << op.name() << "[";
        os << "dims={" << stream_range(op.dims) << "}, ";
        os << "]";
Paul's avatar
Paul committed
252
253
        return os;
    }
Paul's avatar
Paul committed
254
255
};

Paul's avatar
Paul committed
256
257
258
259
260
261
262
263
264
struct outline
{
    shape s;
    std::string name() const { return "outline"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(0);
        return s;
    }
Paul's avatar
Paul committed
265
    argument compute(shape, std::vector<argument>) const { return {s, nullptr}; }
Paul's avatar
Paul committed
266
267
};

Paul's avatar
Paul committed
268
269
270
} // namespace rtg

#endif