// operators.hpp
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP

#include <rtg/operation.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <ostream>
#include <string>
#include <vector>

namespace rtg {

Paul's avatar
Paul committed
11
12
13
14
struct check_shapes
{
    const std::vector<shape>* shapes;

Paul's avatar
Paul committed
15
    check_shapes(const std::vector<shape>& s) : shapes(&s) {}
Paul's avatar
Paul committed
16
17
18
19
20

    const check_shapes& has(std::size_t n) const
    {
        assert(shapes != nullptr);
        if(shapes->size() != n)
Paul's avatar
Paul committed
21
22
            RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
                      std::to_string(shapes->size()));
Paul's avatar
Paul committed
23
24
25
26
27
28
        return *this;
    }

    const check_shapes& only_dims(std::size_t n) const
    {
        assert(shapes != nullptr);
Paul's avatar
Paul committed
29
30
        if(!shapes->empty())
        {
Paul's avatar
Paul committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
            if(shapes->front().lens().size() != n)
                RTG_THROW("Only " + std::to_string(n) + "d supported");
        }
        return *this;
    }

    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
            RTG_THROW("Shapes do not match");
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
            RTG_THROW("Types do not match");
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
            RTG_THROW("Dimensions do not match");
        return *this;
    }

Paul's avatar
Paul committed
58
    template <class F>
Paul's avatar
Paul committed
59
60
61
62
63
64
    bool same(F f) const
    {
        assert(shapes != nullptr);
        if(shapes->empty())
            return true;
        auto&& key = f(shapes->front());
Paul's avatar
Paul committed
65
        return this->all_of([&](const shape& s) { return f(s) == key; });
Paul's avatar
Paul committed
66
67
    }

Paul's avatar
Paul committed
68
    template <class Predicate>
Paul's avatar
Paul committed
69
70
71
72
73
74
75
    bool all_of(Predicate p) const
    {
        assert(shapes != nullptr);
        return std::all_of(shapes->begin(), shapes->end(), p);
    }
};

Paul's avatar
Paul committed
76
77
struct not_computable
{
Paul's avatar
Paul committed
78
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
79
80
};

Paul's avatar
Paul committed
81
struct convolution
Paul's avatar
Paul committed
82
{
Paul's avatar
Paul committed
83
84
85
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Paul's avatar
Paul committed
86
87
88
89
90
91
92
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };
    padding_mode_t padding_mode = default_;
Paul's avatar
Paul committed
93
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
94
95
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
96
97
        check_shapes{inputs}.has(2).same_type().same_dims().only_dims(4);

Paul's avatar
Paul committed
98
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
99
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
100
        auto t               = input.type();
Paul's avatar
Paul committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
        if (padding_mode == default_) {
            return {t,
                    {
                        input.lens()[0],
                        weights.lens()[0],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
        } else if(padding_mode == same) {
            return {t, {
                input.lens()[0], 
                weights.lens()[0], 
                static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))
            }};
        } else if(padding_mode == valid) {
            return {t, {
                input.lens()[0], 
                weights.lens()[0], 
                static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                static_cast<std::size_t>(std::ceil(static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))
            }};
        } else {
            RTG_THROW("Invalid padding mode");
        }
Paul's avatar
Paul committed
136
    }
Paul's avatar
Paul committed
137

Paul's avatar
Paul committed
138
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
139

Paul's avatar
Paul committed
140
    friend std::ostream& operator<<(std::ostream& os, const convolution& op)
Paul's avatar
Paul committed
141
    {
Paul's avatar
Paul committed
142
143
144
145
146
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "dilation={" << stream_range(op.dilation) << "}";
        os << "]";
Paul's avatar
Paul committed
147
148
        return os;
    }
Paul's avatar
Paul committed
149
150
};

Paul's avatar
Paul committed
151
struct pooling
Paul's avatar
Paul committed
152
153
{
    std::string mode;
Paul's avatar
Paul committed
154
155
156
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Paul's avatar
Paul committed
157
    std::string name() const { return "pooling"; }
Paul's avatar
Paul committed
158
159
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
160
        check_shapes{inputs}.has(1).only_dims(4);
Paul's avatar
Paul committed
161

Paul's avatar
Paul committed
162
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
163
        auto t             = input.type();
Paul's avatar
Paul committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
        return {t,
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
                        std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) /
                                  static_cast<float>(stride[0])) +
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
                        std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) /
                                  static_cast<float>(stride[1])) +
                            1)),
                }};
Paul's avatar
Paul committed
179
    }
Paul's avatar
Paul committed
180

Paul's avatar
Paul committed
181
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
182

Paul's avatar
Paul committed
183
    friend std::ostream& operator<<(std::ostream& os, const pooling& op)
Paul's avatar
Paul committed
184
    {
Paul's avatar
Paul committed
185
186
187
188
189
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "lengths={" << stream_range(op.lengths) << "}";
        os << "]";
Paul's avatar
Paul committed
190
191
        return os;
    }
Paul's avatar
Paul committed
192
193
};

Paul's avatar
Paul committed
194
struct activation
Paul's avatar
Paul committed
195
196
{
    std::string mode;
Paul's avatar
Paul committed
197
    std::string name() const { return "activation"; }
Paul's avatar
Paul committed
198
199
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
200
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
201
202
        return inputs.front();
    }
Paul's avatar
Paul committed
203

Paul's avatar
Paul committed
204
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
205
    friend std::ostream& operator<<(std::ostream& os, const activation& op)
Paul's avatar
Paul committed
206
    {
Paul's avatar
Paul committed
207
        os << op.name() << ":" << op.mode;
Paul's avatar
Paul committed
208
209
        return os;
    }
Paul's avatar
Paul committed
210
211
};

Paul's avatar
Paul committed
212
213
214
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
215
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
216
217
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
218
        if(inputs.empty())
Paul's avatar
Paul committed
219
            RTG_THROW("Wrong number of arguments");
Paul's avatar
Paul committed
220
221
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
Paul's avatar
Paul committed
222
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
223
224
225
226
227
228
229
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
        }
        if(dims.back() == -1)
        {
            rdims.pop_back();
Paul's avatar
Paul committed
230
            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
Paul's avatar
Paul committed
231
232
233
234
        }
        return {inputs.front().type(), rdims};
    }

Paul's avatar
Paul committed
235
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
Paul's avatar
Paul committed
236

Paul's avatar
Paul committed
237
    friend std::ostream& operator<<(std::ostream& os, const reshape& op)
Paul's avatar
Paul committed
238
    {
Paul's avatar
Paul committed
239
240
241
        os << op.name() << "[";
        os << "dims={" << stream_range(op.dims) << "}, ";
        os << "]";
Paul's avatar
Paul committed
242
243
        return os;
    }
Paul's avatar
Paul committed
244
245
};

Paul's avatar
Paul committed
246
247
248
249
250
251
252
253
254
struct outline
{
    shape s;
    std::string name() const { return "outline"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(0);
        return s;
    }
Paul's avatar
Paul committed
255
    argument compute(shape, std::vector<argument>) const { return {s, nullptr}; }
Paul's avatar
Paul committed
256
257
};

Paul's avatar
Paul committed
258
259
260
} // namespace rtg

#endif