operators.hpp 4.33 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP

Paul's avatar
Paul committed
4
5
#include <rtg/operand.hpp>
#include <rtg/stringutils.hpp>
Paul's avatar
Paul committed
6
#include <cmath>
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
namespace rtg {

Paul's avatar
Paul committed
10
11
12
13
/// Operation placeholder that provides a compute() member but cannot be
/// evaluated.
///
/// compute() has the same signature as the compute() members of the other
/// operations in this header and unconditionally throws.
struct not_computable
{
    /// Ignores its arguments and always throws.
    /// @throws std::runtime_error with the message "not computable".
    argument compute(std::vector<argument> /*args*/) const
    {
        const char* const reason = "not computable";
        throw std::runtime_error{reason};
    }
};

Paul's avatar
Paul committed
18
struct convolution
Paul's avatar
Paul committed
19
20
21
22
23
24
25
26
27
28
29
30
31
{
    std::array<std::size_t, 2> padding = {0, 0};
    std::array<std::size_t, 2> stride = {1, 1};
    std::array<std::size_t, 2> dilation = {1, 1};
    std::string name() const
    {
        return "convolution[padding={" + to_string(padding) + 
            "}, stride={" + to_string(stride) +
            "}, dilation={" + to_string(dilation) +
            "}]";
    }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
32
        if(inputs.size() != 2) throw std::runtime_error("Wrong number of arguments");
Paul's avatar
Paul committed
33
34
        const shape& input = inputs.at(0);
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
35
36
37
        if(input.type() != weights.type()) throw std::runtime_error("Type doesn't match");
        if(input.lens().size() != weights.lens().size()) throw std::runtime_error("Dimensions don't match");
        if(input.lens().size() != 4) throw std::runtime_error("Only 4d convolution supported"); 
Paul's avatar
Paul committed
38
39
40

        auto t = input.type();
        return {t, {
Paul's avatar
Paul committed
41
42
43
44
45
46
            input.lens()[0],
            weights.lens()[0],
            std::size_t(std::max<std::ptrdiff_t>(
                1, (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + 2 * padding[0]) / stride[0] + 1)),
            std::size_t(std::max<std::ptrdiff_t>(
                1, (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + 2 * padding[1]) / stride[1] + 1)),
Paul's avatar
Paul committed
47
48
        }};
    }
Paul's avatar
Paul committed
49
50
51
52
53

    argument compute(std::vector<argument>) const
    {
        throw std::runtime_error("not computable");
    }
Paul's avatar
Paul committed
54
55
};

Paul's avatar
Paul committed
56
struct pooling
Paul's avatar
Paul committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
{
    std::string mode;
    std::array<std::size_t, 2> padding = {0, 0};
    std::array<std::size_t, 2> stride = {1, 1};
    std::array<std::size_t, 2> lengths = {1, 1};
    std::string name() const
    {
        return "pooling:" + mode + "[padding={" + to_string(padding) + 
            "}, stride={" + to_string(stride) +
            "}, lengths={" + to_string(lengths) +
            "}]";
    }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
71
        if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
Paul's avatar
Paul committed
72
        const shape& input = inputs.at(0);    
Paul's avatar
Paul committed
73
        if(input.lens().size() != 4) throw std::runtime_error("Only 4d pooling supported"); 
Paul's avatar
Paul committed
74
75
76

        auto t = input.type();
        return {t, {
Paul's avatar
Paul committed
77
78
79
80
81
82
            input.lens()[0],
            input.lens()[1],
            std::size_t(std::max<std::ptrdiff_t>(
                1, std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) / static_cast<float>(stride[0])) + 1)),
            std::size_t(std::max<std::ptrdiff_t>(
                1, std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) / static_cast<float>(stride[1])) + 1)),
Paul's avatar
Paul committed
83
84
        }};
    }
Paul's avatar
Paul committed
85
86
87
88
89

    argument compute(std::vector<argument>) const
    {
        throw std::runtime_error("not computable");
    }
Paul's avatar
Paul committed
90
91
92
};


Paul's avatar
Paul committed
93
struct activation
Paul's avatar
Paul committed
94
95
96
97
98
99
100
101
{
    std::string mode;
    std::string name() const
    {
        return "activation:" + mode;
    }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
102
        if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
Paul's avatar
Paul committed
103
104
        return inputs.front();
    }
Paul's avatar
Paul committed
105
106
107
108
109

    argument compute(std::vector<argument>) const
    {
        throw std::runtime_error("not computable");
    }
Paul's avatar
Paul committed
110
111
};

Paul's avatar
Paul committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/// Reshape to the dimensions listed in `dims`.
///
/// Encoding of `dims`, as implemented in compute_shape():
///  - 0 at position i copies input dimension i;
///  - a trailing -1 is replaced by all input dimensions from its position to
///    the end. NOTE(review): this differs from ONNX Reshape, where -1 means
///    "infer this single dimension from the element count". Confirm the
///    intended semantics;
///  - any positive value is used as-is;
///  - any other negative value is rejected.
/// compute() always throws.
struct reshape
{
    std::vector<int64_t> dims;

    std::string name() const
    {
        return "reshape[dims={" + to_string(dims) +
            "}]";
    }

    /// @param inputs at least one shape; only the first is used.
    /// @return a shape with the first input's type and the resolved dims.
    /// @throws std::runtime_error on empty inputs, on a negative entry other than
    ///         a trailing -1, or when a 0 or trailing -1 refers past the input rank.
    shape compute_shape(std::vector<shape> inputs) const
    {
        if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
        const std::size_t n = dims.size();
        for(std::size_t i = 0; i < n; i++)
        {
            const bool trailing_wildcard = (i + 1 == n) && dims[i] == -1;
            // Any other negative value would wrap to an enormous size_t dim.
            if(dims[i] < 0 && !trailing_wildcard)
                throw std::runtime_error("Invalid reshape dimension");
            if(dims[i] == 0)
            {
                // 0 copies the matching input dim, so that dim must exist.
                if(i >= idims.size())
                    throw std::runtime_error("Reshape dimension out of range");
                rdims[i] = idims[i];
            }
        }
        // Guard against empty dims: back() on an empty vector is UB.
        if(!dims.empty() && dims.back() == -1)
        {
            rdims.pop_back();
            // The copy below starts at idims.begin() + rdims.size(), which must
            // not be past idims.end().
            if(rdims.size() > idims.size())
                throw std::runtime_error("Reshape dimension out of range");
            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
        }
        return {inputs.front().type(), rdims};
    }

    argument compute(std::vector<argument>) const
    {
        throw std::runtime_error("not computable");
    }
};

Paul's avatar
Paul committed
144

Paul's avatar
Paul committed
145
146
147
} // namespace rtg

#endif