simplify_algebra_test.cpp 7.34 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
#include <basic_ops.hpp>
#include <test.hpp>

struct simplify_algebra_target
{
    std::string name() const { return "simplify_algebra"; }
Paul's avatar
Paul committed
10
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
11
    {
Paul's avatar
Paul committed
12
        return {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}};
Paul's avatar
Paul committed
13
    }
Paul's avatar
Paul committed
14
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
15
16
};

Paul's avatar
Paul committed
17
TEST_CASE(simplify_add1)
{
    // Both parameters are single-element int32 tensors.
    migraphx::shape scalar_shape{migraphx::shape::int32_type, {1}};

    // Input graph: (x + 1) + (y + 2), then run through simplify_algebra.
    migraphx::program p1;
    {
        auto px       = p1.add_parameter("x", scalar_shape);
        auto py       = p1.add_parameter("y", scalar_shape);
        auto lit_one  = p1.add_literal(1);
        auto lit_two  = p1.add_literal(2);
        auto x_plus_1 = p1.add_instruction(migraphx::op::add{}, px, lit_one);
        auto y_plus_2 = p1.add_instruction(migraphx::op::add{}, py, lit_two);
        auto total    = p1.add_instruction(migraphx::op::add{}, x_plus_1, y_plus_2);
        p1.add_instruction(pass_op{}, total);
    }
    p1.compile(simplify_algebra_target{});

    // Expected graph: constants grouped together, i.e. (x + y) + (1 + 2).
    // Instruction order matters for the equality check below.
    migraphx::program p2;
    {
        auto px        = p2.add_parameter("x", scalar_shape);
        auto py        = p2.add_parameter("y", scalar_shape);
        auto lit_one   = p2.add_literal(1);
        auto lit_two   = p2.add_literal(2);
        auto const_sum = p2.add_instruction(migraphx::op::add{}, lit_one, lit_two);
        auto param_sum = p2.add_instruction(migraphx::op::add{}, px, py);
        auto total     = p2.add_instruction(migraphx::op::add{}, param_sum, const_sum);
        p2.add_instruction(pass_op{}, total);
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
46
TEST_CASE(simplify_add2)
{
    migraphx::shape s{migraphx::shape::int32_type, {1}};

    // Appends an elementwise add of `lhs` and `rhs` to `p` and returns it.
    auto append_add = [](migraphx::program& p, auto lhs, auto rhs) {
        return p.add_instruction(migraphx::op::add{}, lhs, rhs);
    };

    // Input graph: (1 + x) + (2 + y), each constant on the left-hand side.
    migraphx::program p1;
    {
        auto x     = p1.add_parameter("x", s);
        auto y     = p1.add_parameter("y", s);
        auto one   = p1.add_literal(1);
        auto two   = p1.add_literal(2);
        auto lhs   = append_add(p1, one, x);
        auto rhs   = append_add(p1, two, y);
        auto total = append_add(p1, lhs, rhs);
        p1.add_instruction(pass_op{}, total);
    }
    p1.compile(simplify_algebra_target{});

    // Expected graph: (x + y) + (1 + 2).
    migraphx::program p2;
    {
        auto x      = p2.add_parameter("x", s);
        auto y      = p2.add_parameter("y", s);
        auto one    = p2.add_literal(1);
        auto two    = p2.add_literal(2);
        auto consts = append_add(p2, one, two);
        auto params = append_add(p2, x, y);
        auto total  = append_add(p2, params, consts);
        p2.add_instruction(pass_op{}, total);
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
75
TEST_CASE(simplify_add3)
{
    // Graph: (1 + x) + (1 + 2), with the literal 1 shared by both inner adds.
    // The test expects simplify_algebra to leave this graph unchanged, so the
    // same builder produces both the input and the expected program.
    auto build = [] {
        migraphx::program p;
        auto x   = p.add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto one = p.add_literal(1);
        auto two = p.add_literal(2);
        auto lhs = p.add_instruction(migraphx::op::add{}, one, x);
        auto rhs = p.add_instruction(migraphx::op::add{}, one, two);
        auto sum = p.add_instruction(migraphx::op::add{}, lhs, rhs);
        p.add_instruction(pass_op{}, sum);
        return p;
    };

    migraphx::program p1 = build();
    p1.compile(simplify_algebra_target{});

    migraphx::program p2 = build();
    EXPECT(p1 == p2);
}

102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
TEST_CASE(simplify_add_broadcast1)
{
    // Constants are 2-element `inner` vectors that `b` expands to the
    // `outer` shape of the parameters.
    migraphx::shape inner{migraphx::shape::int32_type, {2}};
    migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
    migraphx::op::broadcast b{1, {1, 2, 3, 3}};

    // Input graph: (x + broadcast(1)) + (y + broadcast(2)).
    migraphx::program p1;
    {
        auto x      = p1.add_parameter("x", outer);
        auto y      = p1.add_parameter("y", outer);
        auto ones   = p1.add_literal({inner, {1, 1}});
        auto ones_b = p1.add_instruction(b, ones);
        auto twos   = p1.add_literal({inner, {2, 2}});
        auto twos_b = p1.add_instruction(b, twos);
        auto lhs    = p1.add_instruction(migraphx::op::add{}, x, ones_b);
        auto rhs    = p1.add_instruction(migraphx::op::add{}, y, twos_b);
        auto total  = p1.add_instruction(migraphx::op::add{}, lhs, rhs);
        p1.add_instruction(pass_op{}, total);
    }
    p1.compile(simplify_algebra_target{});

    // Expected graph: the constants are summed first and broadcast once,
    // giving (x + y) + broadcast(1 + 2).
    migraphx::program p2;
    {
        auto x        = p2.add_parameter("x", outer);
        auto y        = p2.add_parameter("y", outer);
        auto ones     = p2.add_literal({inner, {1, 1}});
        auto twos     = p2.add_literal({inner, {2, 2}});
        auto consts   = p2.add_instruction(migraphx::op::add{}, ones, twos);
        auto consts_b = p2.add_instruction(b, consts);
        auto params   = p2.add_instruction(migraphx::op::add{}, x, y);
        auto total    = p2.add_instruction(migraphx::op::add{}, params, consts_b);
        p2.add_instruction(pass_op{}, total);
    }
    EXPECT(p1 == p2);
}

TEST_CASE(simplify_add_broadcast2)
{
    migraphx::shape inner{migraphx::shape::int32_type, {2}};
    migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
    migraphx::op::broadcast b{1, {1, 2, 3, 3}};

    // Graph: (x + broadcast(1)) + (y + 2), where the 2 is a full outer-shaped
    // literal rather than a broadcast. The test expects simplify_algebra to
    // leave this graph unchanged, so one builder yields both programs.
    auto make_graph = [&inner, &outer, &b] {
        migraphx::program prog;
        auto x      = prog.add_parameter("x", outer);
        auto y      = prog.add_parameter("y", outer);
        auto ones   = prog.add_literal({inner, {1, 1}});
        auto ones_b = prog.add_instruction(b, ones);
        // outer holds 1 * 2 * 3 * 3 = 18 elements.
        auto twos = prog.add_literal({outer,
                                      {2, 2, 2, 2, 2, 2,
                                       2, 2, 2, 2, 2, 2,
                                       2, 2, 2, 2, 2, 2}});
        auto lhs   = prog.add_instruction(migraphx::op::add{}, x, ones_b);
        auto rhs   = prog.add_instruction(migraphx::op::add{}, y, twos);
        auto total = prog.add_instruction(migraphx::op::add{}, lhs, rhs);
        prog.add_instruction(pass_op{}, total);
        return prog;
    };

    migraphx::program p1 = make_graph();
    p1.compile(simplify_algebra_target{});

    migraphx::program p2 = make_graph();
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
162
// TODO: Add test case
163
// TEST_CASE(simplify_add4)
Paul's avatar
Paul committed
164
165
void simplify_add4()
{
Paul's avatar
Paul committed
166
    migraphx::program p1;
Paul's avatar
Paul committed
167
    {
Paul's avatar
Paul committed
168
169
        auto x    = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
Paul's avatar
Paul committed
170
171
        auto one  = p1.add_literal(1);
        auto two  = p1.add_literal(2);
Paul's avatar
Paul committed
172
173
174
        auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x);
        auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, y);
        auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two);
Paul's avatar
Paul committed
175
176
177
178
        p1.add_instruction(pass_op{}, sum3);
    }
    p1.compile(simplify_algebra_target{});

Paul's avatar
Paul committed
179
    migraphx::program p2;
Paul's avatar
Paul committed
180
    {
Paul's avatar
Paul committed
181
182
        auto x    = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
Paul's avatar
Paul committed
183
184
        auto one  = p2.add_literal(1);
        auto two  = p2.add_literal(2);
Paul's avatar
Paul committed
185
186
187
        auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
        auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
        auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1);
Paul's avatar
Paul committed
188
189
190
191
192
        p2.add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
193
int main(int argc, const char* argv[])
{
    // Forward the command line to the test harness (test.hpp), which presumably
    // runs the TEST_CASEs registered in this file.
    test::run(argc, argv);
}