// argmax.hpp
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP

#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// Operator returning, for every position outside `axis`, the index of the
/// largest element along `axis`. The output keeps the input's dimensions
/// with `lens[axis]` collapsed to 1, and its element type is int64.
struct argmax
{
    /// Dimension reduced over; validated in compute_shape to satisfy
    /// 0 <= axis < rank of the input.
    int axis = 0;

    /// Exposes `axis` as the operator's only reflected attribute.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    std::string name() const { return "argmax"; }

    /// Checks for exactly one standard-layout input and an in-range,
    /// non-empty reduction axis, then returns an int64 shape equal to the
    /// input's lens with lens[axis] set to 1.
    /// Throws if `axis` is out of range or the input has length 0 along it.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens = inputs[0].lens();
        int n_dim = static_cast<int>(lens.size());
        if(axis >= n_dim || axis < 0)
        {
            MIGRAPHX_THROW("ARGMAX: axis is out of range.");
        }

        // argmax over an empty axis is undefined; without this check
        // calc_argmax would read element 0 of a zero-length dimension.
        if(lens[axis] == 0)
        {
            MIGRAPHX_THROW("ARGMAX: cannot compute argmax over an empty axis.");
        }

        lens[axis] = 1;

        return {shape::int64_type, lens};
    }

    /// Scans `item_num` elements of `input` along `axis` and returns the
    /// position of the first maximum (strict `<`, so ties keep the earliest).
    ///
    /// The element at `indices` as passed in is treated as position 0, so
    /// `indices[axis]` is assumed to be 0 on entry (compute() guarantees this
    /// by deriving `indices` from the output shape, where lens[axis] == 1).
    /// NOTE: `indices[axis]` is overwritten during the scan.
    /// Requires item_num >= 1 (enforced by compute_shape).
    template <class T>
    int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
    {
        auto max_val      = input(indices.begin(), indices.end());
        int64_t max_index = 0;
        for(std::size_t i = 1; i < item_num; ++i)
        {
            indices[axis] = i;
            // Read each element once instead of re-indexing on update.
            auto cur_val = input(indices.begin(), indices.end());
            if(max_val < cur_val)
            {
                max_val   = cur_val;
                max_index = static_cast<int64_t>(i);
            }
        }

        return max_index;
    }

    /// Fills each output element with the argmax of the corresponding slice
    /// of args[0] along `axis`.
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto batch_item_num = args.front().get_shape().lens()[axis];

        result.visit([&](auto output) {
            args[0].visit([&](auto input) {
                par_for(output_shape.elements(), [&](auto i) {
                    // data_idx is local to each iteration, so calc_argmax's
                    // mutation of it is not shared even if par_for runs
                    // iterations concurrently.
                    auto data_idx = output_shape.multi(i);
                    output[i]     = this->calc_argmax(input, data_idx, batch_item_num);
                });
            });
        });

        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif