// decompose_test.cpp: unit tests for the migraphx::decompose pass.
#include <migraphx/decompose.hpp>
#include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>

#include <test.hpp>

Paul Fultz II's avatar
Paul Fultz II committed
8
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::decompose{}}); }
9

10
// A three-input dot built with the operator's default attributes is expected to become a
// two-input dot (alpha = 1, beta = 0) followed by an add of the third input. No multiply
// is applied to z in the expected module (presumably because the default beta is 1 —
// confirm against the dot operator definition).
TEST_CASE(dot_add)
{
    // Module under test: identity(dot(x, y, z)).
    migraphx::module m1;
    {
        auto x   = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto y   = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto z   = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto dot = m1.add_instruction(migraphx::make_op("dot"), x, y, z);
        m1.add_instruction(migraphx::make_op("identity"), dot);
    }
    run_pass(m1);

    // Expected module after decompose: identity(add(dot(x, y), z)).
    migraphx::module m2;
    {
        auto x   = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto y   = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto z   = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
        auto add = m2.add_instruction(migraphx::make_op("add"), dot, z);
        m2.add_instruction(migraphx::make_op("identity"), add);
    }
    EXPECT(m1 == m2);
}

// dot(x, y, z) with alpha = 1.0 and beta = 0.5 on float_type inputs. After decompose the
// module is expected to hold dot(x, y) with alpha = 1 / beta = 0, plus z multiplied by a
// float_type beta literal (0.5) that is multibroadcast to the {2, 2} output shape.
TEST_CASE(dot_add_beta_float)
{
    // Module under test: identity(dot(x, y, z)) with a non-unit beta.
    migraphx::module m1;
    {
        auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto dot =
            m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
        m1.add_instruction(migraphx::make_op("identity"), dot);
    }
    run_pass(m1);

    // Expected module: identity(add(dot(x, y), mul(z, multibroadcast(beta)))).
    migraphx::module m2;
    {
        auto x   = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto y   = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto z   = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
        // Scalar beta literal, broadcast so it can be multiplied elementwise with z.
        auto beta =
            m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
        auto beta_broadcast =
            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
        auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
        auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
        m2.add_instruction(migraphx::make_op("identity"), add);
    }
    EXPECT(m1 == m2);
}

// dot(x, y, z) with alpha = 1.0 and beta = 0.5 on half_type inputs. After decompose the
// module is expected to hold dot(x, y) with alpha = 1 / beta = 0, plus z scaled by a
// half_type beta literal that is multibroadcast to {2, 2}.
TEST_CASE(dot_add_beta_half)
{
    // Shapes shared by both modules: the {2, 2} operands and the scalar beta literal.
    const migraphx::shape mat_shape{migraphx::shape::half_type, {2, 2}};
    const migraphx::shape scalar_shape{migraphx::shape::half_type};

    // Module under test.
    migraphx::module m1;
    {
        auto lhs  = m1.add_parameter("x", mat_shape);
        auto rhs  = m1.add_parameter("y", mat_shape);
        auto bias = m1.add_parameter("z", mat_shape);
        auto gemm = m1.add_instruction(
            migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), lhs, rhs, bias);
        m1.add_instruction(migraphx::make_op("identity"), gemm);
    }
    run_pass(m1);

    // Expected module: identity(add(dot(x, y), mul(z, multibroadcast(beta)))).
    migraphx::module m2;
    {
        auto lhs  = m2.add_parameter("x", mat_shape);
        auto rhs  = m2.add_parameter("y", mat_shape);
        auto bias = m2.add_parameter("z", mat_shape);
        auto gemm =
            m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), lhs, rhs);
        auto scale    = m2.add_literal(migraphx::literal{scalar_shape, {0.5}});
        auto scale_bc = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), scale);
        auto scaled_bias = m2.add_instruction(migraphx::make_op("mul"), bias, scale_bc);
        auto sum         = m2.add_instruction(migraphx::make_op("add"), gemm, scaled_bias);
        m2.add_instruction(migraphx::make_op("identity"), sum);
    }
    EXPECT(m1 == m2);
}

// dot(x, y, z) with alpha = 1.0 and beta = 0.5 on double_type inputs. After decompose the
// module is expected to hold dot(x, y) with alpha = 1 / beta = 0, plus z scaled by a
// double_type beta literal that is multibroadcast to {2, 2}.
TEST_CASE(dot_add_beta_double)
{
    const migraphx::shape operand_shape{migraphx::shape::double_type, {2, 2}};
    const migraphx::shape beta_shape{migraphx::shape::double_type};

    // Module under test.
    migraphx::module m1;
    {
        auto a = m1.add_parameter("x", operand_shape);
        auto b = m1.add_parameter("y", operand_shape);
        auto c = m1.add_parameter("z", operand_shape);
        auto product =
            m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), a, b, c);
        m1.add_instruction(migraphx::make_op("identity"), product);
    }
    run_pass(m1);

    // Expected module: the dot loses its third input, and c is scaled by beta and added back.
    migraphx::module m2;
    {
        auto a = m2.add_parameter("x", operand_shape);
        auto b = m2.add_parameter("y", operand_shape);
        auto c = m2.add_parameter("z", operand_shape);
        auto product =
            m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
        auto beta_lit = m2.add_literal(migraphx::literal{beta_shape, {0.5}});
        auto beta_full =
            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta_lit);
        auto c_scaled = m2.add_instruction(migraphx::make_op("mul"), c, beta_full);
        auto result   = m2.add_instruction(migraphx::make_op("add"), product, c_scaled);
        m2.add_instruction(migraphx::make_op("identity"), result);
    }
    EXPECT(m1 == m2);
}

// dot(x, y, z) with alpha = 1.0 and beta = 0.5 on int32_type inputs. After decompose the
// module is expected to hold dot(x, y) with alpha = 1 / beta = 0, plus z multiplied by an
// int32_type beta literal that is multibroadcast to {2, 2}.
TEST_CASE(dot_add_beta_int)
{
    // Module under test.
    migraphx::module m1;
    {
        auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto dot =
            m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
        m1.add_instruction(migraphx::make_op("identity"), dot);
    }
    run_pass(m1);

    // Expected module: identity(add(dot(x, y), mul(z, multibroadcast(beta)))).
    migraphx::module m2;
    {
        auto x   = m2.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto y   = m2.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto z   = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
        // NOTE(review): this is an int32 literal initialised from 0.5. If the literal converts
        // the value with ordinary float->integer conversion, it stores 0 and the expected graph
        // scales z by zero. Confirm that this is the intended semantics for integer dot with a
        // fractional beta.
        auto beta =
            m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
        auto beta_broadcast =
            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
        auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
        auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
        m2.add_instruction(migraphx::make_op("identity"), add);
    }
    EXPECT(m1 == m2);
}

// Entry point: forwards the command-line arguments to the test harness, which runs the
// registered TEST_CASEs.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}