// decompose_test.cpp: unit tests for the migraphx::decompose pass.
#include <migraphx/decompose.hpp>
#include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
// Apply only the decompose pass to `m`; the module is modified in place.
void run_pass(migraphx::module& m)
{
    const migraphx::decompose decompose_pass{};
    migraphx::run_passes(m, {decompose_pass});
}

TEST_CASE(quant_dot_add)
{
    // Shapes shared by the input module and the expected module.
    const migraphx::shape int8_2x2{migraphx::shape::int8_type, {2, 2}};
    const migraphx::shape int32_2x2{migraphx::shape::int32_type, {2, 2}};

    // Input: a three-operand quant_dot (x, y, plus z) built without explicit
    // alpha/beta attributes, followed by an identity.
    migraphx::module actual;
    {
        auto lhs     = actual.add_parameter("x", int8_2x2);
        auto rhs     = actual.add_parameter("y", int8_2x2);
        auto addend  = actual.add_parameter("z", int32_2x2);
        auto product = actual.add_instruction(migraphx::make_op("quant_dot"), lhs, rhs, addend);
        actual.add_instruction(migraphx::make_op("identity"), product);
    }
    run_pass(actual);

    // Expected: z is split off into an explicit `add`, leaving a two-operand
    // quant_dot with alpha = 1 and beta = 0.
    migraphx::module expected;
    {
        auto lhs    = expected.add_parameter("x", int8_2x2);
        auto rhs    = expected.add_parameter("y", int8_2x2);
        auto addend = expected.add_parameter("z", int32_2x2);
        auto product = expected.add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), lhs, rhs);
        auto sum = expected.add_instruction(migraphx::make_op("add"), product, addend);
        expected.add_instruction(migraphx::make_op("identity"), sum);
    }

    EXPECT(actual == expected);
}

TEST_CASE(quant_dot_add_beta)
{
    // Shapes shared by the input module and the expected module.
    const migraphx::shape int8_2x2{migraphx::shape::int8_type, {2, 2}};
    const migraphx::shape int32_2x2{migraphx::shape::int32_type, {2, 2}};

    // Input: a three-operand quant_dot with beta = 2 applied to z. alpha is
    // written as the integer 1 so it is spelled the same way as in the expected
    // module below and in quant_dot_add.
    migraphx::module m1;
    {
        auto x     = m1.add_parameter("x", int8_2x2);
        auto y     = m1.add_parameter("y", int8_2x2);
        auto z     = m1.add_parameter("z", int32_2x2);
        auto q_dot = m1.add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), x, y, z);
        m1.add_instruction(migraphx::make_op("identity"), q_dot);
    }
    run_pass(m1);

    // Expected: the beta scaling becomes explicit. A scalar int32 literal 2 is
    // multibroadcast to the {2, 2} output shape and multiplied with z. The
    // product is then added to a two-operand quant_dot with alpha = 1, beta = 0.
    migraphx::module m2;
    {
        auto x = m2.add_parameter("x", int8_2x2);
        auto y = m2.add_parameter("y", int8_2x2);
        auto z = m2.add_parameter("z", int32_2x2);
        auto q_dot =
            m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), x, y);
        auto beta =
            m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
        // out_lens must match the {2, 2} shape of z so the elementwise mul is valid.
        auto beta_broadcast =
            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
        auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
        auto add = m2.add_instruction(migraphx::make_op("add"), q_dot, mul);
        m2.add_instruction(migraphx::make_op("identity"), add);
    }

    EXPECT(m1 == m2);
}

int main(int argc, const char* argv[])
{
    // Forward the command line to the test.hpp harness (test::run).
    test::run(argc, argv);
}