qlinearmul_test.cpp 2.04 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

#include <onnx_test.hpp>


// Checks the ONNX importer's handling of qlinearmul_test.onnx.
//
// The hand-built reference program dequantizes the two uint8 inputs, multiplies
// them element-wise in floating point, and re-quantizes the product with the
// output scale / zero point. The parsed program must equal it after sort().
TEST_CASE(qlinearmul_test)
{
    migraphx::program expected;
    auto* mod = expected.get_main_module();

    // Both operands are 64-element uint8 vectors.
    auto in_a = mod->add_parameter("A", {migraphx::shape::uint8_type, {64}});
    auto in_b = mod->add_parameter("B", {migraphx::shape::uint8_type, {64}});

    // Single-element quantization parameters (scale, zero point) for A, B and the output.
    auto scale_a = mod->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto zero_a  = mod->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}});

    auto scale_b = mod->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto zero_b  = mod->add_literal(migraphx::literal{migraphx::shape::uint8_type, {16}});

    auto scale_c = mod->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto zero_c  = mod->add_literal(migraphx::literal{migraphx::shape::uint8_type, {100}});

    // Expands a single-element literal to the 64-element operand length.
    auto to_len64 = [&](auto scalar) {
        return mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}),
                                    scalar);
    };

    // Emits broadcast(scale), broadcast(zero point), then dequantizelinear, in that order.
    // The broadcasts are named locals so that their insertion order is well defined.
    auto dequantize = [&](auto operand, auto scale, auto zero_point) {
        auto scale_vec = to_len64(scale);
        auto zero_vec  = to_len64(zero_point);
        return mod->add_instruction(
            migraphx::make_op("dequantizelinear"), operand, scale_vec, zero_vec);
    };

    auto fp_a    = dequantize(in_a, scale_a, zero_a);
    auto fp_b    = dequantize(in_b, scale_b, zero_b);
    auto product = mod->add_instruction(migraphx::make_op("mul"), fp_a, fp_b);

    // Re-quantize the floating-point product with the output's parameters.
    auto c_scale_vec = to_len64(scale_c);
    auto c_zero_vec  = to_len64(zero_c);
    auto result      = mod->add_instruction(
        migraphx::make_op("quantizelinear"), product, c_scale_vec, c_zero_vec);

    mod->add_return({result});

    auto parsed = migraphx::parse_onnx("qlinearmul_test.onnx");

    EXPECT(expected.sort() == parsed.sort());
}