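// CPU implementations of the MIGraphX pooling operator: a reference
// window-walking kernel (cpu_pooling) parameterized on a reduction policy,
// and a DNNL-backed variant (dnnl_pooling).
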
#include <algorithm>
#include <cstddef>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/dnnl.hpp>
#include <migraphx/op/pooling.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

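// Reduction policies for the generic pooling kernel: start() yields the
// identity value, apply() folds one element into the accumulator, and
// final() turns the accumulated value (and window size) into the result.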
struct max_pool
{
    static std::string name() { return "max"; }
    template <class T>
    static T start()
    {
        return std::numeric_limits<T>::lowest();
    }

    static double apply(double x, double y) { return std::max(x, y); }

    static double final(double x, std::size_t) { return x; }
};

struct avg_pool
{
    static std::string name() { return "average"; }

    template <class T>
    static double start()
    {
        return 0.0;
    }

    static double apply(double x, double y) { return x + y; }

    static double final(double x, std::size_t y) { return (y == 0) ? 0.0 : (x / y); }
};

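// Generic pooling operator: for every output element it reconstructs the
// matching input window and reduces it with the Op policy above.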
template <class Op>
struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
{
    cpu_pooling() = default;

    cpu_pooling(op::pooling pop) : op(std::move(pop)) {}

    op::pooling op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::pooling_" + Op::name(); }
    shape compute_shape(std::vector<shape> inputs) const
    {
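        // The trailing input is the preallocated output buffer; drop it so
        // only the real pooling inputs participate in shape computation.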
        inputs.pop_back();
        return op.normalize_compute_shape(inputs);
    }

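    // The kernel writes into its last argument; report that buffer as the
    // output alias so no extra allocation is made.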
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }

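    // For each output element: locate its window in the input, clamp the
    // window to the input bounds (so padding contributes nothing), reduce
    // the window with Op, and finalize.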
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        visit_all(args.back(), args[0])([&](auto output, auto input) {
            using type   = typename decltype(output)::value_type;
            auto in_s    = input.get_shape();
            auto in_lens = in_s.lens();

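            // One independent task per output element.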
            par_for(output_shape.elements(), [&](auto i) {
                auto idx_o = output_shape.multi(i);
                auto n_dim = idx_o.size();
                std::vector<std::size_t> win_start;
                std::vector<std::size_t> win_size;
                for(std::size_t dim = 2; dim < n_dim; ++dim)
                {
                    auto d_2  = dim - 2;
                    // Signed arithmetic: start may be negative under padding,
                    // and end is clamped to the input extent.
                    int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
                                static_cast<int>(op.padding[d_2]);
                    int end   = std::min(start + static_cast<int>(op.lengths[d_2]),
                                         static_cast<int>(in_lens[dim]));
                    start   = std::max(start, 0);
                    win_start.push_back(start);
                    win_size.push_back(end - start);
                }

                shape win_shape{output_shape.type(), win_size};
                auto pool_size = win_shape.elements();
                double acc     = Op::template start<type>();
                shape_for_each(win_shape, [&](auto idx_w) {
                    auto idx = idx_o;
                    std::transform(idx_w.begin(),
                                   idx_w.end(),
                                   win_start.begin(),
                                   idx.begin() + 2,
                                   [](auto ii, auto jj) { return ii + jj; });
                    // Defensive elementwise bounds check; the clamped window
                    // should already keep idx inside the input.
                    if(std::equal(idx.begin() + 2,
                                  idx.end(),
                                  in_lens.begin() + 2,
                                  [](auto ii, auto jj) { return ii < jj; }))
                    {
                        acc = Op::apply(acc, input[in_s.index(idx)]);
                    }
                });

                output[i] = type(Op::final(acc, pool_size));
            });
        });

        return args.back();
    }
};
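
// Worked example of the window arithmetic (illustrative values): with
// lengths = {3}, stride = {2}, padding = {1} and an input extent of 5,
// output index 0 gives start = 0*2 - 1 = -1 and end = min(-1 + 3, 5) = 2;
// start is then clamped to 0, so the window covers input elements [0, 2)
// and avg_pool divides by pool_size = 2, excluding the padding from the
// count.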

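// Explicit instantiation registers both variants through auto_register_op.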
template struct cpu_pooling<avg_pool>;
template struct cpu_pooling<max_pool>;

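// DNNL-backed pooling: translates the MIGraphX pooling attributes into a
// dnnl::pooling_forward descriptor and lets dnnl_extend_op drive execution.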
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
{
    std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC}; }

    dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
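        // MIGraphX packs per-dimension begin pads followed by end pads;
        // DNNL wants them as separate padding_l / padding_r vectors.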
        auto algo  = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg;
        auto kdims = op.kdims();
        std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
        std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
        return {dnnl::prop_kind::forward_inference,
                algo,
                m.at(DNNL_ARG_SRC),
                m.at(DNNL_ARG_DST),
                to_dnnl_dims(op.stride),
                to_dnnl_dims(op.lengths),
                to_dnnl_dims(padding_l),
                to_dnnl_dims(padding_r)};
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx