simplify_reshapes_test.cpp 7.89 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
#include <basic_ops.hpp>
#include <test.hpp>

struct simplify_reshapes_target
{
    std::string name() const { return "simplify_reshapes"; }
Paul's avatar
Paul committed
10
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
11
    {
Paul's avatar
Paul committed
12
        return {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}};
Paul's avatar
Paul committed
13
    }
Paul's avatar
Paul committed
14
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
15
16
};

Paul's avatar
Paul committed
17
// Chain under test: literal -> transpose -> contiguous -> contiguous -> pass_op.
// The program's output shape is expected to be standard and not transposed
// both before and after compiling. Five add_* calls are made here and the
// test expects 4 instructions to remain afterwards. The evaluated result is
// expected to differ from get_2x2().
TEST_CASE(double_contig)
{
    migraphx::program p;
    auto l  = p.add_literal(get_2x2());
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
    auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
    p.add_instruction(pass_op{}, c2);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 4);
    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
35
// Chain under test: literal -> transpose{1, 0} -> transpose{1, 0} -> pass_op.
// The output shape is expected to be standard and not transposed before and
// after compiling. Four add_* calls are made here and the test expects 2
// instructions to remain afterwards. The evaluated result is expected to
// equal get_2x2().
TEST_CASE(double_transpose)
{
    migraphx::program p;
    auto l  = p.add_literal(get_2x2());
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
    p.add_instruction(pass_op{}, t2);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);
    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
52
// Chain under test:
// literal -> transpose -> contiguous -> transpose -> contiguous -> pass_op.
// The output shape is expected to be standard and not transposed before and
// after compiling. Six add_* calls are made here and the test expects 2
// instructions to remain afterwards. The evaluated result is expected to
// equal get_2x2().
TEST_CASE(double_transpose_contig)
{
    migraphx::program p;
    auto l  = p.add_literal(get_2x2());
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
    auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
    auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
    p.add_instruction(pass_op{}, c2);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);
    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
71
// Chain under test: literal -> transpose{1, 0} -> pass_op.
// The output shape is expected to be transposed and not standard before and
// after compiling. Three add_* calls are made here and the test expects all
// 3 instructions to remain afterwards. The evaluated result is expected to
// differ from get_2x2().
TEST_CASE(single_transpose)
{
    migraphx::program p;
    auto l  = p.add_literal(get_2x2());
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    p.add_instruction(pass_op{}, t1);
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 3);
    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
87
// Chain under test: literal -> transpose{1, 0} -> transpose{1, 0}, with no
// trailing pass_op, so the second transpose is the last instruction added.
// The output shape is expected to be standard and not transposed before and
// after compiling, and the evaluated result is expected to equal get_2x2().
// The instruction-count check is currently disabled (see TODO below).
TEST_CASE(double_transpose_sin_pass)
{
    migraphx::program p;
    auto l  = p.add_literal(get_2x2());
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // TODO: Fix this
    // EXPECT(std::distance(p.begin(), p.end()) == 1);
    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
104
// Chain under test: literal -> transpose{1, 0}, with no trailing pass_op.
// The output shape is expected to be transposed and not standard before and
// after compiling. Two add_* calls are made here and the test expects both
// instructions to remain afterwards. The evaluated result is expected to
// differ from get_2x2().
TEST_CASE(single_transpose_sin_pass)
{
    migraphx::program p;
    auto l = p.add_literal(get_2x2());
    p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);
    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
119
120
121
// Chain under test, starting from a float parameter "x" of lengths
// {1, 112, 56, 56}:
//   reshape{1, 4, 28, 56, 56} -> transpose{0, 2, 1, 3, 4} -> contiguous
//   -> reshape{1, 112, 56, 56} -> pass_op.
// The test expects the program's output shape to equal the input shape both
// before and after compiling, and the instruction count to be unchanged.
TEST_CASE(reshape_transpose)
{
    migraphx::program p;
    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto x  = p.add_parameter("x", s);
    auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
    auto t  = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
    auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
    auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
    p.add_instruction(pass_op{}, r2);
    EXPECT(p.get_shape() == s);
    // Instruction count before compiling, compared against the count afterwards.
    auto n = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == s);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

Paul's avatar
Paul committed
136
137
138
139
140
141
142
143
144
// Chain under test, starting from a float parameter "x" of lengths {4, 4}:
//   transpose{1, 0} -> contiguous -> pass_op.
// The test expects both the output shape and the instruction count to be
// the same after compiling as they were before.
TEST_CASE(transpose_contiguous)
{
    migraphx::program p;
    auto s  = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x  = p.add_parameter("x", s);
    auto t  = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
    p.add_instruction(pass_op{}, c1);
    // Snapshot of the output shape and instruction count before compiling.
    auto out_shape = p.get_shape();
    auto n         = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

// Chain under test, starting from a float parameter "x" of lengths {4, 4}:
//   transpose{1, 0} -> contiguous -> contiguous -> pass_op.
// The test expects the output shape to be unchanged by compiling, exactly
// one instruction fewer than before, and the transpose instruction to still
// be part of the program.
TEST_CASE(transpose_double_contiguous)
{
    migraphx::program p;
    auto s  = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x  = p.add_parameter("x", s);
    auto t  = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
    auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
    p.add_instruction(pass_op{}, c2);
    // Snapshot of the output shape and instruction count before compiling.
    auto out_shape = p.get_shape();
    auto n         = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
    EXPECT(p.has_instruction(t));
}

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
// Two chained transposes ({1, 0, 2} then {1, 2, 0}) on a float parameter of
// lengths {1, 2, 3}, followed by pass_op. The test expects the output shape
// to survive compiling unchanged and the program to shrink by one instruction.
TEST_CASE(transpose_partial1)
{
    migraphx::program p;
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input  = p.add_parameter("x", input_shape);
    auto first  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto second = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, first);
    p.add_instruction(pass_op{}, second);

    // Record the pre-compile state to compare against.
    const auto out_shape = p.get_shape();
    const auto n         = std::distance(p.begin(), p.end());

    p.compile(simplify_reshapes_target{});

    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// Three chained transposes ({1, 0, 2}, {1, 2, 0}, {1, 0, 2}) on a float
// parameter of lengths {1, 2, 3}, followed by pass_op. The test expects the
// output shape to survive compiling unchanged and the program to shrink by
// two instructions.
TEST_CASE(transpose_partial2)
{
    migraphx::program p;
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input  = p.add_parameter("x", input_shape);
    auto first  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto second = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, first);
    auto third  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, second);
    p.add_instruction(pass_op{}, third);

    // Record the pre-compile state to compare against.
    const auto out_shape = p.get_shape();
    const auto n         = std::distance(p.begin(), p.end());

    p.compile(simplify_reshapes_target{});

    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}

// Four chained transposes ({1, 0, 2}, {1, 2, 0}, {1, 0, 2}, {1, 0, 2}) on a
// float parameter of lengths {1, 2, 3}, followed by pass_op. The test
// expects the output shape to survive compiling unchanged and the program to
// shrink by three instructions.
TEST_CASE(transpose_partial3)
{
    migraphx::program p;
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input  = p.add_parameter("x", input_shape);
    auto first  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto second = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, first);
    auto third  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, second);
    auto fourth = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, third);
    p.add_instruction(pass_op{}, fourth);

    // Record the pre-compile state to compare against.
    const auto out_shape = p.get_shape();
    const auto n         = std::distance(p.begin(), p.end());

    p.compile(simplify_reshapes_target{});

    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}

Paul's avatar
Paul committed
216
// Program entry point: forwards the command-line arguments to test::run.
// Presumably test::run executes the TEST_CASEs declared above -- confirm in test.hpp.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}