eliminate_contiguous_test.cpp 3.89 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
4
5
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
6
#include <migraphx/op/sin.hpp>
7
#include <migraphx/op/slice.hpp>
8
9
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
10
11
12
#include <basic_ops.hpp>
#include <test.hpp>

13
// Applies the pass under test (eliminate_contiguous) to `p`, followed by
// dead_code_elimination, so each test can count the instructions left over.
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(p,
                         {
                             migraphx::eliminate_contiguous{},  // pass under test
                             migraphx::dead_code_elimination{}, // presumably prunes unused instructions
                         });
}
17

Paul's avatar
Paul committed
18
// transpose -> contiguous -> pass_standard_op on a parameter input.
// The program must keep every instruction after the pass.
// NOTE(review): pass_standard_op is defined in basic_ops.hpp; presumably it
// demands standard-layout input, which is why the contiguous survives.
TEST_CASE(standard_op)
{
    migraphx::program p;
    auto input      = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    p.add_instruction(pass_standard_op{}, packed);

    const auto before = std::distance(p.begin(), p.end());
    run_pass(p);
    const auto after = std::distance(p.begin(), p.end());

    EXPECT(after == before);
}

Paul's avatar
Paul committed
30
// Same graph as standard_op, but fed from a literal instead of a parameter.
// After the passes, exactly two instructions are expected to remain.
// NOTE(review): presumably the constant transpose/contiguous chain is
// evaluated into a single literal; confirm against eliminate_contiguous.
TEST_CASE(standard_op_const)
{
    migraphx::program p;
    auto lit        = p.add_literal(get_2x2());
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    p.add_instruction(pass_standard_op{}, packed);

    run_pass(p);

    const auto remaining = std::distance(p.begin(), p.end());
    EXPECT(remaining == 2);
}

// transpose -> contiguous -> pass_op on a parameter input.
// The instruction count must be unchanged after the pass.
TEST_CASE(non_standard_op)
{
    migraphx::program p;
    auto input      = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    p.add_instruction(pass_op{}, packed);

    const auto before = std::distance(p.begin(), p.end());
    run_pass(p);
    const auto after = std::distance(p.begin(), p.end());

    EXPECT(after == before);
}

Paul's avatar
Paul committed
53
54
55
56
57
58
59
// Literal-fed variant of non_standard_op: two instructions should remain
// once both passes have run.
TEST_CASE(non_standard_op_const)
{
    migraphx::program p;
    auto lit        = p.add_literal(get_2x2());
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    p.add_instruction(pass_op{}, packed);

    run_pass(p);

    const auto remaining = std::distance(p.begin(), p.end());
    EXPECT(remaining == 2);
}

64
65
66
// transpose -> contiguous -> identity -> dot(identity, literal).
// Exactly one instruction is expected to disappear after the passes.
TEST_CASE(transpose_gemm)
{
    migraphx::program p;
    auto lit        = p.add_literal(get_2x2());
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    auto passthru   = p.add_instruction(migraphx::op::identity{}, packed);
    p.add_instruction(migraphx::op::dot{}, passthru, lit);

    const auto before = std::distance(p.begin(), p.end());
    run_pass(p);
    const auto after = std::distance(p.begin(), p.end());

    EXPECT(after == before - 1);
}

77
78
79
// transpose -> contiguous -> sin -> pass_standard_op on a parameter input.
// No instruction may be removed by the pass.
TEST_CASE(transpose_standard_op)
{
    migraphx::program p;
    auto input      = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    auto sine       = p.add_instruction(migraphx::op::sin{}, packed);
    p.add_instruction(pass_standard_op{}, sine);

    const auto before = std::distance(p.begin(), p.end());
    run_pass(p);
    const auto after = std::distance(p.begin(), p.end());

    EXPECT(after == before);
}

Paul's avatar
Paul committed
90
91
92
93
94
95
96
97
// Literal-fed variant of transpose_standard_op: three instructions should
// remain after both passes.
TEST_CASE(transpose_standard_op_const)
{
    migraphx::program p;
    auto lit        = p.add_literal(get_2x2());
    auto transposed = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed     = p.add_instruction(migraphx::op::contiguous{}, transposed);
    auto sine       = p.add_instruction(migraphx::op::sin{}, packed);
    p.add_instruction(pass_standard_op{}, sine);

    run_pass(p);

    const auto remaining = std::distance(p.begin(), p.end());
    EXPECT(remaining == 3);
}

102
103
104
// slice (axis 1, [1, 2)) -> contiguous -> sin -> pass_standard_op.
// The slice yields a non-packed view; the pass is expected to drop exactly
// one instruction here.
TEST_CASE(no_packed_unary_op)
{
    migraphx::program p;
    auto lit    = p.add_literal(get_2x2());
    auto sliced = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, lit);
    auto packed = p.add_instruction(migraphx::op::contiguous{}, sliced);
    auto sine   = p.add_instruction(migraphx::op::sin{}, packed);
    p.add_instruction(pass_standard_op{}, sine);

    const auto before = std::distance(p.begin(), p.end());
    run_pass(p);
    const auto after = std::distance(p.begin(), p.end());

    EXPECT(after == before - 1);
}

Paul's avatar
Paul committed
115
int main(int argc, const char* argv[])
{
    // Delegate to the runner from test.hpp with the raw command line.
    test::run(argc, argv);
}