// simplify_reshapes_test.cpp -- unit tests for the migraphx::simplify_reshapes pass.
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>

#include <iterator>
#include <string>
#include <vector>

#include <basic_ops.hpp>
#include <test.hpp>

struct simplify_reshapes_target
{
    std::string name() const { return "simplify_reshapes"; }
Paul's avatar
Paul committed
10
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
11
    {
Paul's avatar
Paul committed
12
        return {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}};
Paul's avatar
Paul committed
13
    }
Paul's avatar
Paul committed
14
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
15
16
};

Paul's avatar
Paul committed
17
TEST_CASE(double_contig)
Paul's avatar
Paul committed
18
{
Paul's avatar
Paul committed
19
    migraphx::program p;
Paul's avatar
Paul committed
20
    auto l  = p.add_literal(get_2x2());
Paul's avatar
Paul committed
21
22
23
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
    auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
Paul's avatar
Paul committed
24
25
26
27
28
29
    p.add_instruction(pass_op{}, c2);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
Paul's avatar
Paul committed
30
    EXPECT(std::distance(p.begin(), p.end()) == 4);
Paul's avatar
Paul committed
31
    auto result = p.eval({});
Paul's avatar
Paul committed
32
    EXPECT(result != get_2x2());
Paul's avatar
Paul committed
33
34
}

Paul's avatar
Paul committed
35
TEST_CASE(double_transpose)
Paul's avatar
Paul committed
36
{
Paul's avatar
Paul committed
37
    migraphx::program p;
Paul's avatar
Paul committed
38
    auto l  = p.add_literal(get_2x2());
Paul's avatar
Paul committed
39
40
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
Paul's avatar
Paul committed
41
42
43
44
45
46
47
48
49
50
51
    p.add_instruction(pass_op{}, t2);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);
    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
52
TEST_CASE(double_transpose_contig)
Paul's avatar
Paul committed
53
{
Paul's avatar
Paul committed
54
    migraphx::program p;
Paul's avatar
Paul committed
55
    auto l  = p.add_literal(get_2x2());
Paul's avatar
Paul committed
56
57
58
59
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
    auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
    auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
Paul's avatar
Paul committed
60
61
62
63
64
65
66
67
68
69
70
    p.add_instruction(pass_op{}, c2);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);
    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
71
TEST_CASE(single_transpose)
Paul's avatar
Paul committed
72
{
Paul's avatar
Paul committed
73
    migraphx::program p;
Paul's avatar
Paul committed
74
    auto l  = p.add_literal(get_2x2());
Paul's avatar
Paul committed
75
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
Paul's avatar
Paul committed
76
77
78
79
80
81
82
83
84
85
86
    p.add_instruction(pass_op{}, t1);
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 3);
    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
87
TEST_CASE(double_transpose_sin_pass)
Paul's avatar
Paul committed
88
{
Paul's avatar
Paul committed
89
    migraphx::program p;
Paul's avatar
Paul committed
90
    auto l  = p.add_literal(get_2x2());
Paul's avatar
Paul committed
91
92
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
Paul's avatar
Paul committed
93
94
95
96
97
98
99
100
101
102
103
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // TODO: Fix this
    // EXPECT(std::distance(p.begin(), p.end()) == 1);
    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
104
TEST_CASE(single_transpose_sin_pass)
Paul's avatar
Paul committed
105
{
Paul's avatar
Paul committed
106
    migraphx::program p;
Paul's avatar
Paul committed
107
    auto l = p.add_literal(get_2x2());
Paul's avatar
Paul committed
108
    p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
Paul's avatar
Paul committed
109
110
111
112
113
114
115
116
117
118
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);
    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
119
120
121
TEST_CASE(reshape_transpose)
{
    migraphx::program p;
Paul's avatar
Paul committed
122
123
    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto x  = p.add_parameter("x", s);
Paul's avatar
Paul committed
124
    auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
Paul's avatar
Paul committed
125
    auto t  = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
Paul's avatar
Paul committed
126
127
128
129
130
131
132
133
134
135
    auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
    auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
    p.add_instruction(pass_op{}, r2);
    EXPECT(p.get_shape() == s);
    auto n = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == s);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

Paul's avatar
Paul committed
136
137
138
139
140
141
142
143
144
TEST_CASE(transpose_contiguous)
{
    migraphx::program p;
    auto s  = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x  = p.add_parameter("x", s);
    auto t  = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
    p.add_instruction(pass_op{}, c1);
    auto out_shape = p.get_shape();
Paul's avatar
Paul committed
145
    auto n         = std::distance(p.begin(), p.end());
Paul's avatar
Paul committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

// transpose -> contiguous -> contiguous on a parameter: exactly one redundant
// `contiguous` is removed (n - 1 instructions remain). The output shape is
// unchanged, and the transpose itself must still be in the program.
TEST_CASE(transpose_double_contiguous)
{
    migraphx::program p;
    auto s  = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x  = p.add_parameter("x", s);
    auto t  = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
    auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
    p.add_instruction(pass_op{}, c2);
    const auto out_shape = p.get_shape();
    const auto n         = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
    EXPECT(p.has_instruction(t));
}

Paul's avatar
Paul committed
168
int main(int argc, const char* argv[]) { test::run(argc, argv); }