eliminate_contiguous_test.cpp 4.75 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
4
5
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
6
#include <migraphx/op/sin.hpp>
7
#include <migraphx/op/slice.hpp>
8
9
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
10
11
12
#include <basic_ops.hpp>
#include <test.hpp>

13
// Runs eliminate_contiguous followed by dead_code_elimination on the main
// module of the given program. The program is modified in place.
void run_pass(migraphx::program& p)
{
    auto* main_mod = p.get_main_module();
    migraphx::run_passes(
        *main_mod, {migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}});
}
18

Paul's avatar
Paul committed
19
// Graph: parameter "x" (float, 2x2) -> transpose{1, 0} -> contiguous -> pass_standard_op.
// Checks that the length of the program's begin()/end() range is the same
// before and after run_pass.
TEST_CASE(standard_op)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto input      = main_mod->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    main_mod->add_instruction(pass_standard_op{}, packed);

    // Snapshot the size before the pass so the two can be compared.
    const auto count = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == count);
}

Paul's avatar
Paul committed
33
// Same graph as standard_op, but the input is the literal returned by get_2x2()
// instead of a parameter.
// Checks that exactly 2 entries remain in the program's begin()/end() range
// after run_pass.
TEST_CASE(standard_op_const)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto lit        = main_mod->add_literal(get_2x2());
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    main_mod->add_instruction(pass_standard_op{}, packed);

    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == 2);
}

// Graph: parameter "x" (float, 2x2) -> transpose{1, 0} -> contiguous -> pass_op.
// Checks that the length of the program's begin()/end() range is the same
// before and after run_pass.
TEST_CASE(non_standard_op)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto input      = main_mod->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    main_mod->add_instruction(pass_op{}, packed);

    // Snapshot the size before the pass so the two can be compared.
    const auto count = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == count);
}

Paul's avatar
Paul committed
60
61
62
// Same graph as non_standard_op, but the input is the literal returned by
// get_2x2() instead of a parameter.
// Checks that exactly 2 entries remain in the program's begin()/end() range
// after run_pass.
TEST_CASE(non_standard_op_const)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto lit        = main_mod->add_literal(get_2x2());
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    main_mod->add_instruction(pass_op{}, packed);

    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == 2);
}

73
74
75
// Graph: literal get_2x2() -> transpose{1, 0} -> contiguous -> identity, then
// dot(identity result, original literal).
// Checks that run_pass shrinks the program's begin()/end() range by exactly one
// entry.
TEST_CASE(transpose_gemm)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto lit        = main_mod->add_literal(get_2x2());
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    auto forwarded  = main_mod->add_instruction(migraphx::op::identity{}, packed);
    main_mod->add_instruction(migraphx::op::dot{}, forwarded, lit);

    // Snapshot the size before the pass so the two can be compared.
    const auto count = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}

88
89
90
// Graph: parameter "x" (float, 2x2) -> transpose{1, 0} -> contiguous -> sin
// -> pass_standard_op.
// Checks that the length of the program's begin()/end() range is the same
// before and after run_pass.
TEST_CASE(transpose_standard_op)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto input      = main_mod->add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    auto sine       = main_mod->add_instruction(migraphx::op::sin{}, packed);
    main_mod->add_instruction(pass_standard_op{}, sine);

    // Snapshot the size before the pass so the two can be compared.
    const auto count = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == count);
}

Paul's avatar
Paul committed
103
104
105
// Same graph as transpose_standard_op, but the input is the literal returned by
// get_2x2() instead of a parameter.
// Checks that exactly 3 entries remain in the program's begin()/end() range
// after run_pass.
TEST_CASE(transpose_standard_op_const)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto lit        = main_mod->add_literal(get_2x2());
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    auto sine       = main_mod->add_instruction(migraphx::op::sin{}, packed);
    main_mod->add_instruction(pass_standard_op{}, sine);

    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == 3);
}

117
118
119
// Graph: literal get_2x2() -> slice{{1}, {1}, {2}} -> contiguous -> sin
// -> pass_standard_op.
// Checks that run_pass shrinks the program's begin()/end() range by exactly one
// entry.
TEST_CASE(no_packed_unary_op)
{
    migraphx::program p;

    auto* main_mod = p.get_main_module();
    auto lit       = main_mod->add_literal(get_2x2());
    auto sliced    = main_mod->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, lit);
    auto packed    = main_mod->add_instruction(migraphx::op::contiguous{}, sliced);
    auto sine      = main_mod->add_instruction(migraphx::op::sin{}, packed);
    main_mod->add_instruction(pass_standard_op{}, sine);

    // Snapshot the size before the pass so the two can be compared.
    const auto count = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == count - 1);
}

132
133
134
// Graph: literal get_2x2() -> transpose{1, 0} -> contiguous, with the
// contiguous result passed to add_return.
// Checks that the length of the program's begin()/end() range is the same
// before and after run_pass.
TEST_CASE(non_standard_return_input)
{
    migraphx::program p;

    auto* main_mod  = p.get_main_module();
    auto lit        = main_mod->add_literal(get_2x2());
    auto transposed = main_mod->add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = main_mod->add_instruction(migraphx::op::contiguous{}, transposed);
    main_mod->add_return({packed});

    // Snapshot the size before the pass so the two can be compared.
    const auto count = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == count);
}

Paul's avatar
Paul committed
146
// Test-binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}