simplify_reshapes_test.cpp 5.94 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
#include <basic_ops.hpp>
#include <test.hpp>

struct simplify_reshapes_target
{
    std::string name() const { return "simplify_reshapes"; }
Paul's avatar
Paul committed
10
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
11
    {
Paul's avatar
Paul committed
12
        return {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}};
Paul's avatar
Paul committed
13
    }
Paul's avatar
Paul committed
14
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
15
16
};

Paul's avatar
Paul committed
17
TEST_CASE(double_contig)
{
    namespace op = migraphx::op;
    migraphx::program p;

    // literal -> transpose -> contiguous -> contiguous -> pass (5 instructions).
    const auto input    = p.add_literal(get_2x2());
    const auto flipped  = p.add_instruction(op::transpose{{1, 0}}, input);
    const auto packed   = p.add_instruction(op::contiguous{}, flipped);
    const auto repacked = p.add_instruction(op::contiguous{}, packed);
    p.add_instruction(pass_op{}, repacked);

    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // One of the two back-to-back contiguous ops is expected to disappear.
    EXPECT(std::distance(p.begin(), p.end()) == 4);

    // The data is still transposed, so it must differ from the raw literal.
    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
35
TEST_CASE(double_transpose)
{
    namespace op = migraphx::op;
    migraphx::program p;

    // Applying the {1, 0} permutation twice composes to the identity.
    const auto input    = p.add_literal(get_2x2());
    const auto flipped  = p.add_instruction(op::transpose{{1, 0}}, input);
    const auto restored = p.add_instruction(op::transpose{{1, 0}}, flipped);
    p.add_instruction(pass_op{}, restored);

    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // Both transposes are expected to be removed, leaving literal + pass.
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
52
TEST_CASE(double_transpose_contig)
{
    namespace op = migraphx::op;
    migraphx::program p;

    // Same identity round-trip as double_transpose, but with a contiguous
    // after each transpose.
    const auto input       = p.add_literal(get_2x2());
    const auto flipped     = p.add_instruction(op::transpose{{1, 0}}, input);
    const auto flipped_mem = p.add_instruction(op::contiguous{}, flipped);
    const auto restored    = p.add_instruction(op::transpose{{1, 0}}, flipped_mem);
    const auto restored_mem = p.add_instruction(op::contiguous{}, restored);
    p.add_instruction(pass_op{}, restored_mem);

    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // All four intermediate ops are expected to be removed: literal + pass remain.
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
71
TEST_CASE(single_transpose)
{
    namespace op = migraphx::op;
    migraphx::program p;

    // A lone transpose has no partner to cancel against.
    const auto input   = p.add_literal(get_2x2());
    const auto flipped = p.add_instruction(op::transpose{{1, 0}}, input);
    p.add_instruction(pass_op{}, flipped);

    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    // Nothing is expected to be removed: literal, transpose, pass.
    EXPECT(std::distance(p.begin(), p.end()) == 3);

    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
87
TEST_CASE(double_transpose_sin_pass)
{
    namespace op = migraphx::op;
    migraphx::program p;

    // Identity round-trip with no trailing pass_op: the second transpose is
    // the last instruction in the program.
    const auto input   = p.add_literal(get_2x2());
    const auto flipped = p.add_instruction(op::transpose{{1, 0}}, input);
    p.add_instruction(op::transpose{{1, 0}}, flipped);

    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // TODO: Fix this; the instruction-count check below is currently disabled.
    // EXPECT(std::distance(p.begin(), p.end()) == 1);

    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
104
TEST_CASE(single_transpose_sin_pass)
{
    namespace op = migraphx::op;
    migraphx::program p;

    // A single transpose as the last instruction; nothing to cancel against.
    const auto input = p.add_literal(get_2x2());
    p.add_instruction(op::transpose{{1, 0}}, input);

    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    // literal + transpose are expected to survive untouched.
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
119
120
121
TEST_CASE(reshape_transpose)
{
    namespace op = migraphx::op;
    migraphx::program p;
    const migraphx::shape s{migraphx::shape::float_type, {1, 112, 56, 56}};
    const auto input = p.add_parameter("x", s);

    // Split axis 1 (112) into 4 x 28, swap those two axes, then fold back to
    // 112. The result is a channel permutation, not an identity.
    const auto split    = p.add_instruction(op::reshape{{1, 4, 28, 56, 56}}, input);
    const auto shuffled = p.add_instruction(op::transpose{{0, 2, 1, 3, 4}}, split);
    const auto packed   = p.add_instruction(op::contiguous{}, shuffled);
    const auto merged   = p.add_instruction(op::reshape{{1, 112, 56, 56}}, packed);
    p.add_instruction(pass_op{}, merged);

    EXPECT(p.get_shape() == s);
    const auto n = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == s);
    // The pass must not remove anything from this chain.
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

Paul's avatar
Paul committed
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
TEST_CASE(transpose_contiguous)
{
    namespace op = migraphx::op;
    migraphx::program p;
    const migraphx::shape param_shape{migraphx::shape::float_type, {4, 4}};
    const auto input = p.add_parameter("x", param_shape);

    // A single transpose followed by a single contiguous: nothing redundant.
    const auto flipped = p.add_instruction(op::transpose{{1, 0}}, input);
    const auto packed  = p.add_instruction(op::contiguous{}, flipped);
    p.add_instruction(pass_op{}, packed);

    const auto out_shape = p.get_shape();
    const auto n         = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

TEST_CASE(transpose_double_contiguous)
{
    namespace op = migraphx::op;
    migraphx::program p;
    const migraphx::shape param_shape{migraphx::shape::float_type, {4, 4}};
    const auto input = p.add_parameter("x", param_shape);

    // transpose -> contiguous -> contiguous: exactly one op should be dropped,
    // and the transpose itself must survive.
    const auto t        = p.add_instruction(op::transpose{{1, 0}}, input);
    const auto packed   = p.add_instruction(op::contiguous{}, t);
    const auto repacked = p.add_instruction(op::contiguous{}, packed);
    p.add_instruction(pass_op{}, repacked);

    const auto out_shape = p.get_shape();
    const auto n         = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
    EXPECT(p.has_instruction(t));
}

Paul's avatar
Paul committed
168
// Entry point: hands the command line to the test framework's runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}