cpu_target.cpp 3.55 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127

#include <rtg/cpu/cpu_target.hpp>

#include <algorithm>
#include <cstddef>

#include <rtg/instruction.hpp>
#include <rtg/dfor.hpp>
#include <rtg/operators.hpp>

namespace rtg { namespace cpu {

/// Reference (naive, direct) CPU implementation of the `convolution` operator.
///
/// The indexing below treats the input as (n, c, h, w) and the weights as
/// (out_c, in_c, kh, kw), i.e. a plain non-grouped 2-D convolution using
/// `op.stride` and `op.padding`. NOTE(review): grouping/dilation are not
/// handled here — confirm that `convolution` has no such attributes.
struct cpu_convolution
{
    convolution op;

    std::string name() const
    {
        return "cpu::convolution";
    }
    /// Delegates shape inference to the wrapped reference operator.
    shape compute_shape(std::vector<shape> inputs) const
    {
        return op.compute_shape(inputs);
    }
    /// @param args args[0] is the input tensor, args[1] the weights.
    /// @return a freshly allocated tensor holding the convolution result.
    argument compute(std::vector<argument> args) const
    {
        argument result{compute_shape({args[0].get_shape(), args[1].get_shape()})};
        result.visit([&](auto output) {
            args[0].visit([&](auto input) {
                args[1].visit([&](auto weights) {
                    // Iterate over the *output* extents: every output element is
                    // written exactly once, and output channels / spatial sizes
                    // may differ from the input's.
                    auto out_n = output.get_shape().lens()[0];
                    auto out_c = output.get_shape().lens()[1];
                    auto out_h = output.get_shape().lens()[2];
                    auto out_w = output.get_shape().lens()[3];

                    // Signed bounds so padded (negative) coordinates compare correctly.
                    const auto in_h = static_cast<std::ptrdiff_t>(input.get_shape().lens()[2]);
                    const auto in_w = static_cast<std::ptrdiff_t>(input.get_shape().lens()[3]);

                    auto wei_c = weights.get_shape().lens()[1];
                    auto wei_h = weights.get_shape().lens()[2];
                    auto wei_w = weights.get_shape().lens()[3];

                    // Convert once to signed so `i * stride - pad` cannot wrap
                    // around in unsigned arithmetic.
                    const auto stride_h = static_cast<std::ptrdiff_t>(op.stride[0]);
                    const auto stride_w = static_cast<std::ptrdiff_t>(op.stride[1]);
                    const auto pad_h    = static_cast<std::ptrdiff_t>(op.padding[0]);
                    const auto pad_w    = static_cast<std::ptrdiff_t>(op.padding[1]);

                    dfor(out_n, out_c, out_h, out_w)(
                        [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
                            // Top-left corner of the receptive field in input coordinates
                            // (may be negative because of padding).
                            const std::ptrdiff_t start_x =
                                static_cast<std::ptrdiff_t>(i) * stride_h - pad_h;
                            const std::ptrdiff_t start_y =
                                static_cast<std::ptrdiff_t>(j) * stride_w - pad_w;

                            double acc = 0;
                            dfor(wei_c, wei_h, wei_w)(
                                [&](std::size_t k, std::size_t x, std::size_t y) {
                                    const std::ptrdiff_t in_x =
                                        start_x + static_cast<std::ptrdiff_t>(x);
                                    const std::ptrdiff_t in_y =
                                        start_y + static_cast<std::ptrdiff_t>(y);
                                    // Positions in the padding region contribute zero.
                                    if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
                                    {
                                        acc += input(o,
                                                     k,
                                                     static_cast<std::size_t>(in_x),
                                                     static_cast<std::size_t>(in_y)) *
                                               weights(w, k, x, y);
                                    }
                                });
                            output(o, w, i, j) = acc;
                        });
                });
            });
        });
        return result;
    }
};

/// CPU kernel for the ReLU activation: y = max(x, 0) applied elementwise.
struct relu
{
    std::string name() const { return "cpu::relu"; }

    /// Elementwise op: the result has the same shape as the first input.
    shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }

    /// @param args args[0] is the tensor to rectify.
    /// @return a new tensor where every non-positive element is replaced by 0.
    argument compute(std::vector<argument> args) const
    {
        auto& source = args.front();
        argument result{source.get_shape()};

        auto rectify = [](auto value) { return value > 0 ? value : 0; };

        result.visit([&](auto dst) {
            source.visit(
                [&](auto src) { std::transform(src.begin(), src.end(), dst.begin(), rectify); });
        });
        return result;
    }
};

/// Lowering pass that rewrites supported reference operators in `prog`
/// into their CPU-kernel counterparts, in place.
struct cpu_apply
{
    program * prog;

    /// Visit every instruction and dispatch on the operator's name.
    void apply()
    {
        for(auto ins = prog->begin(); ins != prog->end(); ++ins)
        {
            if(ins->op.name() == "convolution")
                apply_convolution(ins);
            else if(ins->op.name() == "activation")
                apply_activation(ins);
        }
    }

    /// Swap a reference `convolution` for `cpu_convolution`, keeping its arguments.
    void apply_convolution(instruction_ref ins)
    {
        const auto& conv = any_cast<convolution>(ins->op);
        prog->replace_instruction(ins, cpu_convolution{conv}, ins->arguments);
    }

    /// Swap an `activation` for the CPU `relu` kernel. Only mode "relu" is
    /// lowered; any other mode is left unchanged.
    void apply_activation(instruction_ref ins)
    {
        const auto& act = any_cast<activation>(ins->op);
        if(act.mode != "relu")
            return;
        prog->replace_instruction(ins, relu{}, ins->arguments);
    }
};

/// Identifier of this compilation target.
std::string cpu_target::name() const
{
    return std::string{"cpu"};
}

/// Lower `p` for CPU execution by running the `cpu_apply` pass over it in place.
void cpu_target::apply(program& p) const
{
    cpu_apply lowering{&p};
    lowering.apply();
}

} // namespace cpu

} // namespace rtg