// simplify_reshapes_test.cpp
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

struct simplify_reshapes_target
{
    std::string name() const { return "simplify_reshapes"; }
Paul's avatar
Paul committed
10
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
11
    {
Paul's avatar
Paul committed
12
        return {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}};
Paul's avatar
Paul committed
13
    }
Paul's avatar
Paul committed
14
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
15
16
};

TEST_CASE(double_contig)
{
    // transpose -> contiguous -> contiguous on a 2x2 literal. The transpose
    // permutes the data, so it must survive simplification. Only the second
    // contiguous is redundant, because its input is already contiguous.
    migraphx::program p;
    auto l  = p.add_literal(get_2x2());
    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
    auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
    p.add_instruction(pass_op{}, c2);
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // Expected survivors: literal, transpose, one contiguous and pass_op.
    // Collapsing to 2 instructions would silently discard the transpose.
    EXPECT(std::distance(p.begin(), p.end()) == 4);
    auto result = p.eval({});
    // The output holds transposed values, so it must differ from the input
    // literal. This mirrors single_transpose and assumes get_2x2() is not
    // symmetric.
    EXPECT(result != get_2x2());
}

TEST_CASE(double_transpose)
{
    // Two {1, 0} transposes cancel each other, so after simplification the
    // program is expected to shrink to two instructions and reproduce the input.
    migraphx::program p;
    auto input = p.add_literal(get_2x2());
    auto once  = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto twice = p.add_instruction(migraphx::op::transpose{{1, 0}}, once);
    p.add_instruction(pass_op{}, twice);

    // Output layout before the pass.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());

    p.compile(simplify_reshapes_target{});

    // Output layout and program size after the pass.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    // The data must round-trip unchanged.
    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

TEST_CASE(double_transpose_contig)
{
    // transpose -> contiguous -> transpose -> contiguous. The two {1, 0}
    // permutations cancel, so the data should come back unchanged.
    migraphx::program p;
    auto input    = p.add_literal(get_2x2());
    auto first_t  = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto first_c  = p.add_instruction(migraphx::op::contiguous{}, first_t);
    auto second_t = p.add_instruction(migraphx::op::transpose{{1, 0}}, first_c);
    auto second_c = p.add_instruction(migraphx::op::contiguous{}, second_t);
    p.add_instruction(pass_op{}, second_c);

    // Output layout before the pass.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());

    p.compile(simplify_reshapes_target{});

    // Output layout and program size after the pass.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

TEST_CASE(single_transpose)
{
    // A lone transpose has nothing to cancel against. All three instructions
    // must survive, and the output remains a transposed (non-standard) shape.
    migraphx::program p;
    auto input   = p.add_literal(get_2x2());
    auto flipped = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    p.add_instruction(pass_op{}, flipped);

    // Output layout before the pass.
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());

    p.compile(simplify_reshapes_target{});

    // Nothing changes: same layout, same instruction count.
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 3);

    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

TEST_CASE(double_transpose_sin_pass)
{
    // Same cancelling double transpose as double_transpose, but with no
    // trailing pass_op: the second transpose is itself the program output.
    migraphx::program p;
    auto input = p.add_literal(get_2x2());
    auto once  = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    p.add_instruction(migraphx::op::transpose{{1, 0}}, once);

    // Output layout before the pass.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());

    p.compile(simplify_reshapes_target{});

    // Output layout after the pass.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    // TODO: Fix this. The size check below is disabled; the intended outcome
    // is a program reduced to a single instruction.
    // EXPECT(std::distance(p.begin(), p.end()) == 1);

    auto result = p.eval({});
    EXPECT(result == get_2x2());
}

TEST_CASE(single_transpose_sin_pass)
{
    // A lone transpose as the program output (no pass_op). The pass must keep
    // both instructions and the transposed layout.
    migraphx::program p;
    auto input = p.add_literal(get_2x2());
    p.add_instruction(migraphx::op::transpose{{1, 0}}, input);

    // Output layout before the pass.
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());

    p.compile(simplify_reshapes_target{});

    // Nothing changes: same layout, same instruction count.
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({});
    EXPECT(result != get_2x2());
}

TEST_CASE(reshape_transpose)
{
    // reshape -> transpose -> contiguous -> reshape. The chain ends on the
    // input shape but reorders the 112-wide axis along the way: it is split as
    // 4x28, then axes 1 and 2 are swapped. The pass must therefore not remove
    // any instruction.
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto input   = p.add_parameter("x", s);
    auto split   = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, input);
    auto swapped = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, split);
    auto packed  = p.add_instruction(migraphx::op::contiguous{}, swapped);
    auto merged  = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, packed);
    p.add_instruction(pass_op{}, merged);
    EXPECT(p.get_shape() == s);

    // Snapshot the instruction count so we can confirm nothing was dropped.
    auto n = std::distance(p.begin(), p.end());
    p.compile(simplify_reshapes_target{});
    EXPECT(p.get_shape() == s);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

int main(int argc, const char* argv[])
{
    // Forward the command line unchanged to the project's test harness.
    test::run(argc, argv);
}