// parse_onehot.cpp
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <string>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_onehot : op_parser<parse_onehot>
{
    std::vector<op_desc> operators() const { return {{"OneHot"}}; }

    /// Lower ONNX OneHot(indices, depth, values) to gather + transpose + elementwise ops.
    ///
    /// A depth x depth identity matrix is materialized as a literal. Gathering its rows
    /// with `indices` yields a one-hot tensor whose class dimension is innermost. A
    /// transpose then moves that dimension to `axis`. Finally, the {0, 1} entries are
    /// mapped to {off_value, on_value} via `onehot * (on - off) + off`.
    ///
    /// @param opd   operator descriptor; its op_name is forwarded to tune_axis for errors.
    /// @param info  node info: reads the optional "axis" attribute (default -1) and is
    ///              used to add literals/instructions to the program.
    /// @param args  [indices, depth, values]. depth must evaluate to a constant via
    ///              eval(); values is sliced along axis 0 as [off_value, on_value].
    /// @return      instruction producing the one-hot output in values' element type.
    /// @throws      if the input count is not 3, depth is not constant, or depth <= 0.
    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        // args[1] and args[2] are indexed below; reject malformed nodes up front.
        if(args.size() != 3)
        {
            MIGRAPHX_THROW("PARSE_ONEHOT: expects 3 inputs (indices, depth, values), got " +
                           std::to_string(args.size()));
        }

        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");

        // Read depth as signed 64-bit so a non-positive value is rejected instead of
        // wrapping when converted to a size or overflowing depth * depth.
        const int64_t depth_val = depth_arg.at<int64_t>();
        if(depth_val <= 0)
        {
            MIGRAPHX_THROW("PARSE_ONEHOT: depth must be positive, got " +
                           std::to_string(depth_val));
        }
        const auto depth = static_cast<std::size_t>(depth_val);

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
        }

        // Identity matrix: row k is the one-hot encoding of class k.
        std::vector<float> depth_input(depth * depth, 0.0f);
        for(std::size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = 1.0f;
        }

        // Build the literal in the element type of `values` so that the mul/add below
        // combine operands of the same type.
        auto type = args[2]->get_shape().type();
        shape s{type, {depth, depth}};
        auto l_val      = info.add_literal({s, depth_input});
        auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]});

        // Move the innermost (class) dimension to the requested axis: the permutation is
        // 0..n_rank-2 with n_rank-1 inserted at tuned_axis.
        const int n_rank         = static_cast<int>(gather_out->get_shape().lens().size());
        const int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name);

        std::vector<int64_t> perm(n_rank - 1);
        std::iota(perm.begin(), perm.end(), 0);
        perm.insert(perm.begin() + tuned_axis, n_rank - 1);

        auto tr_out =
            info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
        auto lens = tr_out->get_shape().lens();

        // Map {0, 1} to {off_value, on_value}: result = onehot * (on - off) + off,
        // with both scalars broadcast to the output shape.
        auto off_val = info.add_instruction(
            make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto on_val = info.add_instruction(
            make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
        auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
        auto unsq_off_val =
            info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
        auto unsq_diff_val =
            info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
        auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
        return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx