"src/targets/vscode:/vscode.git/clone" did not exist on "060ffb390119766b2f872b5ac80e2a24fb556b1e"
onehot_test.cpp 1.64 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

#include <onnx_test.hpp>

// Verifies that parsing "onehot_test.onnx" produces the same program as the
// reference program assembled by hand below. Instructions are added in a fixed
// order; that order is preserved on purpose since whole programs are compared.
TEST_CASE(onehot_test)
{
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();

    // Shapes used by the reference program.
    migraphx::shape indices_shape{migraphx::shape::int32_type, {5, 2}};
    migraphx::shape values_shape{migraphx::shape::half_type, {2}};
    migraphx::shape depth_shape{migraphx::shape::half_type, {3, 3}};

    // The handle of this literal is not used afterwards; the literal itself
    // still becomes part of the program under comparison.
    // NOTE(review): presumably mirrors the depth constant emitted by the parser - confirm.
    main_mod->add_literal(3);

    auto indices = main_mod->add_parameter("indices", indices_shape);
    auto values  = main_mod->add_parameter("values", values_shape);

    // 3x3 table with ones on the diagonal; rows are picked via the indices.
    std::vector<float> diagonal_data{1, 0, 0, 0, 1, 0, 0, 0, 1};
    auto depth_table = main_mod->add_literal(migraphx::literal(depth_shape, diagonal_data));

    auto gathered =
        main_mod->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), depth_table, indices);
    auto transposed = main_mod->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), gathered);

    // Element 0 of "values" is used as the off value, element 1 as the on value.
    auto off_value = main_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), values);
    auto on_value = main_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), values);
    auto on_minus_off = main_mod->add_instruction(migraphx::make_op("sub"), on_value, off_value);

    // Both the off value and the difference are broadcast to the same lens.
    auto broadcast_op      = migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}});
    auto off_broadcast     = main_mod->add_instruction(broadcast_op, off_value);
    auto on_minus_off_bcst = main_mod->add_instruction(broadcast_op, on_minus_off);

    // result = transposed * (on - off) + off
    auto scaled = main_mod->add_instruction(migraphx::make_op("mul"), transposed, on_minus_off_bcst);
    auto result = main_mod->add_instruction(migraphx::make_op("add"), scaled, off_broadcast);
    main_mod->add_return({result});

    auto parsed = migraphx::parse_onnx("onehot_test.onnx");

    EXPECT(expected == parsed);
}