// pooling.cpp
#include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/dnnl.hpp>
#include <migraphx/op/pooling.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

/// Reduction policy for max pooling, plugged into cpu_pooling<Op>.
struct max_pool
{
    /// Suffix appended to "cpu::pooling_" to form the registered op name.
    static std::string name() { return "max"; }

    /// Initial accumulator: the lowest finite value representable by T, so no
    /// finite input compares below it.
    template <class T>
    static T start()
    {
        using limits_t = std::numeric_limits<T>;
        return limits_t::lowest();
    }

    /// Returns the larger of the running accumulator and the next value.
    /// Uses the same comparison as std::max: the accumulator is kept unless it
    /// is strictly less than the incoming value.
    static double apply(double acc, double value) { return acc < value ? value : acc; }

    /// Max pooling needs no post-processing; the element count is ignored.
    static double final(double acc, std::size_t /*count*/) { return acc; }
};

/// Reduction policy for average pooling, plugged into cpu_pooling<Op>.
struct avg_pool
{
    /// Suffix appended to "cpu::pooling_" to form the registered op name.
    static std::string name() { return "average"; }

    /// Initial accumulator: an empty sum. Always a double, whatever T is.
    template <class T>
    static double start()
    {
        return double{0};
    }

    /// Adds the next value to the running sum.
    static double apply(double sum, double value) { return sum + value; }

    /// Turns the sum into a mean over `count` elements; an empty window
    /// yields 0 instead of dividing by zero.
    static double final(double sum, std::size_t count)
    {
        if(count == 0)
            return 0.0;
        return sum / count;
    }
};

/// CPU reference implementation of N-d pooling, parameterised by a reduction
/// policy `Op` (max_pool or avg_pool above).
///
/// Inputs: args[0] is the tensor to pool. The last argument is a preallocated
/// output buffer, which this op writes into and returns (see output_alias).
/// Dimensions 0 and 1 are carried through unchanged; pooling runs over
/// dimensions 2 and up.
///
/// Each output element reduces only the part of its window that lies inside
/// the input. Padded positions are skipped, and `Op::final` receives the
/// number of in-bounds elements actually visited.
template <class Op>
struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
{
    cpu_pooling() = default;

    cpu_pooling(op::pooling pop) : op(std::move(pop)) {}

    op::pooling op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::pooling_" + Op::name(); }

    /// Drops the trailing output-buffer shape and defers to op::pooling.
    shape compute_shape(std::vector<shape> inputs) const
    {
        inputs.pop_back();
        return op.normalize_compute_shape(inputs);
    }

    /// The result aliases the last argument (the output buffer).
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        visit_all(args.back(), args[0])([&](auto output, auto input) {
            using type   = typename decltype(output)::value_type;
            auto in_s    = input.get_shape();
            auto in_lens = in_s.lens();

            par_for(output_shape.elements(), [&](auto i) {
                auto idx_o = output_shape.multi(i);
                auto n_dim = idx_o.size();
                std::vector<std::size_t> win_start;
                std::vector<std::size_t> win_size;
                for(std::size_t dim = 2; dim < n_dim; ++dim)
                {
                    auto d_2 = dim - 2;
                    // Use signed arithmetic: with padding the window can begin
                    // before index 0. If padding >= kernel length it can even end
                    // there. Mixing int with size_t here would wrap around.
                    auto lo = static_cast<std::ptrdiff_t>(idx_o[dim] * op.stride[d_2]) -
                              static_cast<std::ptrdiff_t>(op.padding[d_2]);
                    auto hi = std::min(lo + static_cast<std::ptrdiff_t>(op.lengths[d_2]),
                                       static_cast<std::ptrdiff_t>(in_lens[dim]));
                    lo = std::max<std::ptrdiff_t>(lo, 0);
                    // A window lying entirely in padding is empty, never negative.
                    hi = std::max(hi, lo);
                    win_start.push_back(static_cast<std::size_t>(lo));
                    win_size.push_back(static_cast<std::size_t>(hi - lo));
                }

                shape win_shape{output_shape.type(), win_size};
                auto pool_size = win_shape.elements();
                double acc     = Op::template start<type>();
                // The window was clamped to [0, in_lens[dim]) above, so every
                // index visited here is in bounds and needs no further check.
                shape_for_each(win_shape, [&](auto idx_w) {
                    auto idx = idx_o;
                    std::transform(idx_w.begin(),
                                   idx_w.end(),
                                   win_start.begin(),
                                   idx.begin() + 2,
                                   [](auto ii, auto jj) { return ii + jj; });
                    acc = Op::apply(acc, input[in_s.index(idx)]);
                });

                output[i] = type(Op::final(acc, pool_size));
            });
        });

        return args.back();
    }
};

// Explicitly instantiate both reduction variants so that
// auto_register_op registers "cpu::pooling_max" and "cpu::pooling_average".
template struct cpu_pooling<max_pool>;
template struct cpu_pooling<avg_pool>;

/// oneDNN-backed pooling for inference.
///
/// Maps op::pooling onto dnnl::pooling_forward. pooling_mode::max selects
/// dnnl::algorithm::pooling_max; every other mode selects
/// dnnl::algorithm::pooling_avg.
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
{
    /// Only the source tensor is bound as an argument.
    std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }

    /// Builds the oneDNN forward-inference pooling descriptor from the
    /// source/destination memory descriptors in `m` and the op's attributes.
    dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
                                                     : dnnl::algorithm::pooling_avg;
        auto kdims = op.kdims();
        // The first kdims padding values are the leading (left) pads. Any
        // values after them are the trailing (right) pads. If only kdims values
        // are present, padding is symmetric; reuse the leading pads rather than
        // handing oneDNN an empty padding_r.
        std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
        std::vector<size_t> padding_r = padding_l;
        if(op.padding.size() > kdims)
            padding_r.assign(op.padding.begin() + kdims, op.padding.end());

        return {dnnl::prop_kind::forward_inference,
                algo,
                m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
                m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
                to_dnnl_dims(op.stride),
                to_dnnl_dims(op.lengths),
                to_dnnl_dims(padding_l),
                to_dnnl_dims(padding_r)};
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx