// eliminate_concat_test.cpp: unit tests for the migraphx::eliminate_concat pass.
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
9
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
10
    migraphx::op::concat op;
11
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
12
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
13
    {
wsttiger's avatar
wsttiger committed
14
        return op.compute_shape(std::move(inputs));
15
    }
Paul's avatar
Paul committed
16
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
17
18
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
19
20
21
22
23
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
24
struct concat_test_optimization
25
26
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
27
    std::string name() const { return "eliminate_concat::concat"; }
28
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
29
    std::string allocate() const { return "allocate"; }
30
    /// Return the lowered concat operator
Paul's avatar
Paul committed
31
    migraphx::op::concat get_concat(const migraphx::operation& op) const
32
    {
Paul's avatar
Paul committed
33
        return migraphx::any_cast<concat>(op).op;
34
35
36
37
38
39
40
    }
};

struct eliminate_concat_target
{
    std::size_t align = 32;
    std::string name() const { return "eliminate_target"; }
Paul's avatar
Paul committed
41
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
42
    {
Paul's avatar
Paul committed
43
44
        return {migraphx::eliminate_concat{concat_test_optimization{}},
                migraphx::dead_code_elimination{}};
45
    }
Paul's avatar
Paul committed
46
    migraphx::context get_context() const { return {}; }
47
48
49
50
};

struct allocate
{
Paul's avatar
Paul committed
51
    migraphx::shape s{};
52
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
53
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
54
    {
Paul's avatar
Paul committed
55
        migraphx::check_shapes{inputs}.has(0);
56
57
        return s;
    }
Paul's avatar
Paul committed
58
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
59
60
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
61
62
63
64
65
66
67
68
    {
        return {output_shape};
    }
};

struct fred_op
{
    std::string name() const { return "fred_op"; }
Paul's avatar
Paul committed
69
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
70
    {
Paul's avatar
Paul committed
71
        migraphx::check_shapes{inputs}.has(1);
72
73
        return inputs.at(0);
    }
Paul's avatar
Paul committed
74
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
75
76
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
77
78
79
80
81
    {
        return args.at(0);
    }
};

Paul's avatar
Paul committed
82
TEST_CASE(basic)
83
84
{
    auto create_test_program = []() {
Paul's avatar
Paul committed
85
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
86
        auto a1 =
Paul's avatar
Paul committed
87
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
88
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
89
        auto a2 =
Paul's avatar
Paul committed
90
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
91
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
92
        auto a3 =
Paul's avatar
Paul committed
93
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
94
        auto p3          = p.add_instruction(fred_op{}, a3);
95
        std::size_t axis = 1;
Paul's avatar
Paul committed
96
97
        auto a4          = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
98
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
99
        return p;
100
101
    };
    auto create_control_program = []() {
Paul's avatar
Paul committed
102
        migraphx::program p;
Paul's avatar
Paul committed
103
104
        auto a1 = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
105
        auto l1 = p.add_instruction(
Paul's avatar
Paul committed
106
107
            migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0},
            {a1});
108
        auto p1 = p.add_instruction(fred_op{}, l1);
Scott Thornton's avatar
Scott Thornton committed
109
        auto l2 = p.add_instruction(
Paul's avatar
Paul committed
110
111
            migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512},
            {a1});
112
        auto p2 = p.add_instruction(fred_op{}, l2);
Scott Thornton's avatar
Scott Thornton committed
113
        auto l3 = p.add_instruction(
Paul's avatar
Paul committed
114
            migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280},
wsttiger's avatar
wsttiger committed
115
            {a1});
116
        auto p3 = p.add_instruction(fred_op{}, l3);
Paul's avatar
Paul committed
117
        p.add_instruction(migraphx::op::identity{}, {a1, p1, p2, p3});
118
119
120
121
122
123
124
125
126
127
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
128
TEST_CASE(wont_work)
129
130
{
    auto create_test_program = []() {
Paul's avatar
Paul committed
131
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
132
        auto a1 =
Paul's avatar
Paul committed
133
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
134
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
135
        auto a2 =
Paul's avatar
Paul committed
136
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
137
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
138
        auto a3 =
Paul's avatar
Paul committed
139
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
140
        auto p3          = p.add_instruction(fred_op{}, a3);
141
        std::size_t axis = 1;
Paul's avatar
Paul committed
142
143
        auto a4          = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
144
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
145
        return p;
146
147
    };
    auto create_control_program = []() {
Paul's avatar
Paul committed
148
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
149
        auto a1 =
Paul's avatar
Paul committed
150
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
151
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
152
        auto a2 =
Paul's avatar
Paul committed
153
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
154
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
155
        auto a3 =
Paul's avatar
Paul committed
156
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
157
        auto p3          = p.add_instruction(fred_op{}, a3);
158
        std::size_t axis = 1;
Paul's avatar
Paul committed
159
160
        auto a4          = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
161
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
162
        return p;
163
164
165
166
167
168
169
170
171
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
172
int main(int argc, const char* argv[]) { test::run(argc, argv); }