eliminate_concat_test.cpp 9.46 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
4
5
6
7
8
#include <basic_ops.hpp>
#include <test.hpp>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
9
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
10
    migraphx::op::concat op;
11
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
12
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
13
    {
wsttiger's avatar
wsttiger committed
14
        return op.compute_shape(std::move(inputs));
15
    }
Paul's avatar
Paul committed
16
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
17
18
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
19
20
21
22
23
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
24
struct concat_test_optimization
25
26
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
27
    std::string name() const { return "eliminate_concat::concat"; }
28
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
29
    std::string allocate() const { return "allocate"; }
30
    /// Return the lowered concat operator
Paul's avatar
Paul committed
31
    migraphx::op::concat get_concat(const migraphx::operation& op) const
32
    {
Paul's avatar
Paul committed
33
        return migraphx::any_cast<concat>(op).op;
34
35
36
37
38
39
40
    }
};

struct eliminate_concat_target
{
    std::size_t align = 32;
    std::string name() const { return "eliminate_target"; }
Paul's avatar
Paul committed
41
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
42
    {
Paul's avatar
Paul committed
43
44
        return {migraphx::eliminate_concat{concat_test_optimization{}},
                migraphx::dead_code_elimination{}};
45
    }
Paul's avatar
Paul committed
46
    migraphx::context get_context() const { return {}; }
47
48
49
50
};

struct allocate
{
Paul's avatar
Paul committed
51
    migraphx::shape s{};
52
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
53
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
54
    {
Paul's avatar
Paul committed
55
        migraphx::check_shapes{inputs}.has(0);
56
57
        return s;
    }
Paul's avatar
Paul committed
58
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
59
60
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
61
62
63
64
65
    {
        return {output_shape};
    }
};

Paul's avatar
Paul committed
66
struct simple_op
67
{
Paul's avatar
Paul committed
68
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
69
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
70
    {
Paul's avatar
Paul committed
71
        migraphx::check_shapes{inputs}.has(1);
72
73
        return inputs.at(0);
    }
Paul's avatar
Paul committed
74
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
75
76
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
77
78
79
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
80
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
81
82
};

Paul's avatar
Paul committed
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
/// Builds a float_type shape whose dimensions are the given values, each
/// converted to std::size_t.
template <class... Ts>
migraphx::shape create_shape(Ts... xs)
{
    const std::vector<std::size_t> lens{static_cast<std::size_t>(xs)...};
    return migraphx::shape{migraphx::shape::float_type, lens};
}

// Short aliases for the MIGraphX operators used when building the expected
// ("control") programs in the test cases.
using load = migraphx::op::load;
using identity = migraphx::op::identity;

// Two one-element producers concatenated on axis 0: after compilation the
// program is expected to consist of loads from a single two-element buffer.
TEST_CASE(simple)
{
    // Program handed to the pass.
    auto make_input = [] {
        migraphx::program prog;
        const std::size_t concat_axis = 0;
        auto buf_a = prog.add_instruction(allocate{create_shape(1)});
        auto out_a = prog.add_instruction(simple_op{}, buf_a);
        auto buf_b = prog.add_instruction(allocate{create_shape(1)});
        auto out_b = prog.add_instruction(simple_op{}, buf_b);
        auto dest  = prog.add_instruction(allocate{create_shape(2)});
        prog.add_instruction(concat(concat_axis), out_a, out_b, dest);
        return prog;
    };

    // Expected result: both producers read from the shared buffer via load
    // (offsets 0 and 4) and an identity ties buffer and producers together.
    auto make_expected = [] {
        migraphx::program prog;
        auto dest   = prog.add_instruction(allocate{create_shape(2)});
        auto view_a = prog.add_instruction(load{create_shape(1), 0}, dest);
        auto out_a  = prog.add_instruction(simple_op{}, view_a);
        auto view_b = prog.add_instruction(load{create_shape(1), 4}, dest);
        auto out_b  = prog.add_instruction(simple_op{}, view_b);
        prog.add_instruction(identity{}, dest, out_a, out_b);
        return prog;
    };

    auto p1 = make_input();
    auto p2 = make_expected();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

// Two inner concats feeding an outer concat, all on axis 0: the expected
// program loads every piece out of one four-element buffer.
TEST_CASE(nested)
{
    // Appends one two-input concat over freshly allocated buffers and returns
    // the concat instruction.
    auto add_inner_concat = [](auto& prog) {
        const std::size_t concat_axis = 0;
        auto buf_a = prog.add_instruction(allocate{create_shape(1)});
        auto out_a = prog.add_instruction(simple_op{}, buf_a);
        auto buf_b = prog.add_instruction(allocate{create_shape(1)});
        auto out_b = prog.add_instruction(simple_op{}, buf_b);
        auto dest  = prog.add_instruction(allocate{create_shape(2)});
        return prog.add_instruction(concat(concat_axis), out_a, out_b, dest);
    };

    // Program handed to the pass.
    auto make_input = [&] {
        migraphx::program prog;
        const std::size_t concat_axis = 0;
        auto first  = add_inner_concat(prog);
        auto second = add_inner_concat(prog);
        auto dest   = prog.add_instruction(allocate{create_shape(4)});
        prog.add_instruction(concat(concat_axis), first, second, dest);
        return prog;
    };

    // Appends the expected replacement of one inner concat: two loads from
    // `buffer` (offsets 0 and 4), their producers, and an identity over them.
    auto add_inner_expected = [](auto& prog, auto buffer) {
        auto view_a = prog.add_instruction(load{create_shape(1), 0}, buffer);
        auto out_a  = prog.add_instruction(simple_op{}, view_a);
        auto view_b = prog.add_instruction(load{create_shape(1), 4}, buffer);
        auto out_b  = prog.add_instruction(simple_op{}, view_b);
        return prog.add_instruction(identity{}, buffer, out_a, out_b);
    };

    // Expected result: the outer buffer is split by two loads (offsets 0 and 8),
    // each of which serves as the buffer of one inner replacement.
    auto make_expected = [&] {
        migraphx::program prog;
        auto dest   = prog.add_instruction(allocate{create_shape(4)});
        auto half_a = prog.add_instruction(load{create_shape(2), 0}, dest);
        auto first  = add_inner_expected(prog, half_a);
        auto half_b = prog.add_instruction(load{create_shape(2), 8}, dest);
        auto second = add_inner_expected(prog, half_b);
        prog.add_instruction(identity{}, dest, first, second);
        return prog;
    };

    auto p1 = make_input();
    auto p2 = make_expected();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
183
// Three producers of shapes {1,2,8,8}, {1,3,8,8} and {1,5,8,8} concatenated on
// axis 1 into {1,10,8,8}: the expected program loads each piece from one buffer.
TEST_CASE(basic)
{
    // Program handed to the pass.
    auto make_input = [] {
        migraphx::program prog;
        const std::size_t concat_axis = 1;
        auto buf_a = prog.add_instruction(allocate{create_shape(1, 2, 8, 8)});
        auto out_a = prog.add_instruction(simple_op{}, buf_a);
        auto buf_b = prog.add_instruction(allocate{create_shape(1, 3, 8, 8)});
        auto out_b = prog.add_instruction(simple_op{}, buf_b);
        auto buf_c = prog.add_instruction(allocate{create_shape(1, 5, 8, 8)});
        auto out_c = prog.add_instruction(simple_op{}, buf_c);
        auto dest  = prog.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        prog.add_instruction(concat(concat_axis), out_a, out_b, out_c, dest);
        return prog;
    };

    // Expected result. The load offsets 0, 512 and 1280 are consistent with
    // byte offsets of 4-byte floats: 2*8*8*4 = 512 and (2+3)*8*8*4 = 1280.
    auto make_expected = [] {
        migraphx::program prog;
        auto dest   = prog.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        auto view_a = prog.add_instruction(load{create_shape(1, 2, 8, 8), 0}, {dest});
        auto out_a  = prog.add_instruction(simple_op{}, view_a);
        auto view_b = prog.add_instruction(load{create_shape(1, 3, 8, 8), 512}, {dest});
        auto out_b  = prog.add_instruction(simple_op{}, view_b);
        auto view_c = prog.add_instruction(load{create_shape(1, 5, 8, 8), 1280}, {dest});
        auto out_c  = prog.add_instruction(simple_op{}, view_c);
        prog.add_instruction(identity{}, {dest, out_a, out_b, out_c});
        return prog;
    };

    auto p1 = make_input();
    auto p2 = make_expected();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
229
// Same layout as the `basic` case but with a leading dimension of 2. The
// expected program is built exactly like the input, so this test asserts that
// compiling leaves the concat in place.
// NOTE(review): presumably the pass skips this case because axis 1 is not the
// outermost non-unit dimension — confirm against the pass implementation.
TEST_CASE(wont_work)
{
    // Builds the program; used for both the input and the expected result.
    auto build_program = [] {
        migraphx::program prog;
        const std::size_t concat_axis = 1;
        auto buf_a = prog.add_instruction(allocate{create_shape(2, 2, 8, 8)});
        auto out_a = prog.add_instruction(simple_op{}, buf_a);
        auto buf_b = prog.add_instruction(allocate{create_shape(2, 3, 8, 8)});
        auto out_b = prog.add_instruction(simple_op{}, buf_b);
        auto buf_c = prog.add_instruction(allocate{create_shape(2, 5, 8, 8)});
        auto out_c = prog.add_instruction(simple_op{}, buf_c);
        auto dest  = prog.add_instruction(allocate{create_shape(2, 10, 8, 8)});
        prog.add_instruction(concat(concat_axis), out_a, out_b, out_c, dest);
        return prog;
    };

    auto p1 = build_program();
    auto p2 = build_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
273
// Entry point: forwards the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}