// eliminate_concat_test.cpp: unit tests for the migraphx::eliminate_concat pass.
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
#include <cstdint>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
12
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
13
    migraphx::op::concat op;
14
15
16
17
18
19
20

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

21
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
22
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
23
    {
24
        inputs.pop_back();
wsttiger's avatar
wsttiger committed
25
        return op.compute_shape(std::move(inputs));
26
    }
Paul's avatar
Paul committed
27
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
28
29
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
30
31
32
33
34
    {
        return {output_shape};
    }
};

struct concat_test_optimization
36
37
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
38
    std::string name() const { return "eliminate_concat::concat"; }
39
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
40
    std::string allocate() const { return "allocate"; }
41
    /// Return the lowered concat operator
Paul's avatar
Paul committed
42
    migraphx::op::concat get_concat(const migraphx::operation& op) const
43
    {
Paul's avatar
Paul committed
44
        return migraphx::any_cast<concat>(op).op;
45
46
47
    }
};

/// Run eliminate_concat (configured with the test operators) followed by
/// dead-code elimination on `p`, modifying it in place.
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(p,
                         {migraphx::eliminate_concat{concat_test_optimization{}},
                          migraphx::dead_code_elimination{}});
}

struct allocate
{
Paul's avatar
Paul committed
57
    migraphx::shape s{};
58
59
60
61
62
63
64

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.s, "shape"));
    }

65
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
66
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
67
    {
Paul's avatar
Paul committed
68
        migraphx::check_shapes{inputs}.has(0);
69
70
        return s;
    }
Paul's avatar
Paul committed
71
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
72
73
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
74
75
76
77
78
    {
        return {output_shape};
    }
};

struct simple_op
80
{
Paul's avatar
Paul committed
81
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
82
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
83
    {
Paul's avatar
Paul committed
84
        migraphx::check_shapes{inputs}.has(1);
85
86
        return inputs.at(0);
    }
Paul's avatar
Paul committed
87
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
88
89
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
90
91
92
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
93
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
94
95
};

template <class... Ts>
Paul's avatar
Paul committed
97
98
99
100
101
migraphx::shape create_shape(Ts... xs)
{
    return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}

// Operators used to spell the expected (control) programs after the pass runs.
using load     = migraphx::op::load;
using identity = migraphx::op::identity;

// Two 1-element buffers concatenated along axis 0. The pass should replace the
// separate allocations with loads at offsets 0 and 4 (bytes, one float apart)
// into a single 2-element allocation, and the concat with an identity.
TEST_CASE(simple)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto a1          = p.add_instruction(allocate{create_shape(1)});
        auto p1          = p.add_instruction(simple_op{}, a1);
        auto a2          = p.add_instruction(allocate{create_shape(1)});
        auto p2          = p.add_instruction(simple_op{}, a2);
        std::size_t axis = 0;
        auto a3          = p.add_instruction(allocate{create_shape(2)});
        p.add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;
        auto a1 = p.add_instruction(allocate{create_shape(2)});
        auto l1 = p.add_instruction(load{create_shape(1), 0}, a1);
        auto p1 = p.add_instruction(simple_op{}, l1);
        auto l2 = p.add_instruction(load{create_shape(1), 4}, a1);
        auto p2 = p.add_instruction(simple_op{}, l2);
        p.add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Concat of two 2x2 inputs along axis -1 (the last axis). The concat is not
// along the outermost axis and the leading dimension is 2, so the pass is
// expected to leave the program unchanged.
TEST_CASE(negative_axis1)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto a1                 = p.add_instruction(allocate{create_shape(2, 2)});
        auto p1                 = p.add_instruction(simple_op{}, a1);
        auto a2                 = p.add_instruction(allocate{create_shape(2, 2)});
        auto p2                 = p.add_instruction(simple_op{}, a2);
        const std::int64_t axis = -1;
        // Two 2x2 tensors concatenated along the last axis give a 2x4 result.
        auto a3 = p.add_instruction(allocate{create_shape(2, 4)});
        p.add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    // The pass should be a no-op, so the expected program is built identically.
    auto create_control_program = create_test_program;

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Concat of two 2x2 inputs along axis -2, which for rank 2 is the outermost
// axis. Each input is a contiguous half of the 4x2 output, so the pass should
// emit loads at byte offsets 0 and 16 (2*2 floats).
TEST_CASE(negative_axis2)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto a1                 = p.add_instruction(allocate{create_shape(2, 2)});
        auto p1                 = p.add_instruction(simple_op{}, a1);
        auto a2                 = p.add_instruction(allocate{create_shape(2, 2)});
        auto p2                 = p.add_instruction(simple_op{}, a2);
        const std::int64_t axis = -2;
        auto a3                 = p.add_instruction(allocate{create_shape(4, 2)});
        p.add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;
        auto a1 = p.add_instruction(allocate{create_shape(4, 2)});
        auto l1 = p.add_instruction(load{create_shape(2, 2), 0}, a1);
        auto p1 = p.add_instruction(simple_op{}, l1);
        auto l2 = p.add_instruction(load{create_shape(2, 2), 16}, a1);
        auto p2 = p.add_instruction(simple_op{}, l2);
        p.add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Concat of two 1x2x2 inputs along axis -2 (axis 1 for rank 3). Every
// dimension left of the axis is 1, so each input is still a contiguous slice
// of the 1x4x2 output and the pass should emit loads at byte offsets 0 and 16.
TEST_CASE(negative_axis3)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto a1                 = p.add_instruction(allocate{create_shape(1, 2, 2)});
        auto p1                 = p.add_instruction(simple_op{}, a1);
        auto a2                 = p.add_instruction(allocate{create_shape(1, 2, 2)});
        auto p2                 = p.add_instruction(simple_op{}, a2);
        const std::int64_t axis = -2;
        auto a3                 = p.add_instruction(allocate{create_shape(1, 4, 2)});
        p.add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;
        auto a1 = p.add_instruction(allocate{create_shape(1, 4, 2)});
        auto l1 = p.add_instruction(load{create_shape(1, 2, 2), 0}, a1);
        auto p1 = p.add_instruction(simple_op{}, l1);
        auto l2 = p.add_instruction(load{create_shape(1, 2, 2), 16}, a1);
        auto p2 = p.add_instruction(simple_op{}, l2);
        p.add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Like `simple`, but the concat lists its inputs in reverse order (second op
// first). The second op's buffer therefore lands at byte offset 0 of the
// combined allocation and the first op's buffer at byte offset 4.
TEST_CASE(reversed)
{
    auto make_input = [] {
        migraphx::program prog;
        auto first_alloc              = prog.add_instruction(allocate{create_shape(1)});
        auto first_op                 = prog.add_instruction(simple_op{}, first_alloc);
        auto second_alloc             = prog.add_instruction(allocate{create_shape(1)});
        auto second_op                = prog.add_instruction(simple_op{}, second_alloc);
        const std::size_t concat_axis = 0;
        auto out_alloc                = prog.add_instruction(allocate{create_shape(2)});
        prog.add_instruction(concat(concat_axis), second_op, first_op, out_alloc);
        return prog;
    };
    auto make_expected = [] {
        migraphx::program prog;
        auto buffer      = prog.add_instruction(allocate{create_shape(2)});
        auto first_view  = prog.add_instruction(load{create_shape(1), 4}, buffer);
        auto first_op    = prog.add_instruction(simple_op{}, first_view);
        auto second_view = prog.add_instruction(load{create_shape(1), 0}, buffer);
        auto second_op   = prog.add_instruction(simple_op{}, second_view);
        prog.add_instruction(identity{}, buffer, second_op, first_op);
        return prog;
    };

    auto actual   = make_input();
    auto expected = make_expected();
    run_pass(actual);

    EXPECT(actual == expected);
}

// Two `simple`-style inner concats (1+1 -> 2 elements each) feed an outer
// concat (2+2 -> 4 elements), all along axis 0. The pass should rewrite all
// three into a single 4-element allocation:
// - 2-element loads at byte offsets 0 and 8 for the inner concats;
// - 1-element loads at offsets 0 and 4 within each of those 2-element loads.
TEST_CASE(nested)
{
    // Appends one inner concat (two 1-element ops -> 2 elements) to `p`.
    auto concat_test_program = [](auto& p) {
        auto a1          = p.add_instruction(allocate{create_shape(1)});
        auto p1          = p.add_instruction(simple_op{}, a1);
        auto a2          = p.add_instruction(allocate{create_shape(1)});
        auto p2          = p.add_instruction(simple_op{}, a2);
        std::size_t axis = 0;
        auto a3          = p.add_instruction(allocate{create_shape(2)});
        return p.add_instruction(concat(axis), p1, p2, a3);
    };
    auto create_test_program = [&] {
        migraphx::program p;
        auto concat1     = concat_test_program(p);
        auto concat2     = concat_test_program(p);
        std::size_t axis = 0;
        auto a1          = p.add_instruction(allocate{create_shape(4)});
        p.add_instruction(concat(axis), concat1, concat2, a1);
        return p;
    };
    // Expected rewrite of one inner concat whose output buffer is `a1`.
    auto concat_control_program = [](auto& p, auto a1) {
        auto l1 = p.add_instruction(load{create_shape(1), 0}, a1);
        auto p1 = p.add_instruction(simple_op{}, l1);
        auto l2 = p.add_instruction(load{create_shape(1), 4}, a1);
        auto p2 = p.add_instruction(simple_op{}, l2);
        return p.add_instruction(identity{}, a1, p1, p2);
    };
    auto create_control_program = [&] {
        migraphx::program p;
        auto a1      = p.add_instruction(allocate{create_shape(4)});
        auto l1      = p.add_instruction(load{create_shape(2), 0}, a1);
        auto concat1 = concat_control_program(p, l1);
        auto l2      = p.add_instruction(load{create_shape(2), 8}, a1);
        auto concat2 = concat_control_program(p, l2);
        p.add_instruction(identity{}, a1, concat1, concat2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Three inputs of shape 1x2x8x8, 1x3x8x8 and 1x5x8x8 concatenated along axis 1
// into 1x10x8x8. The leading dimension is 1, so each input is a contiguous
// slice of the output and is replaced by a load at byte offsets:
// - 0;
// - 2*8*8*4 = 512;
// - (2+3)*8*8*4 = 1280.
TEST_CASE(basic)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto a1          = p.add_instruction(allocate{create_shape(1, 2, 8, 8)});
        auto p1          = p.add_instruction(simple_op{}, a1);
        auto a2          = p.add_instruction(allocate{create_shape(1, 3, 8, 8)});
        auto p2          = p.add_instruction(simple_op{}, a2);
        auto a3          = p.add_instruction(allocate{create_shape(1, 5, 8, 8)});
        auto p3          = p.add_instruction(simple_op{}, a3);
        std::size_t axis = 1;
        auto a4          = p.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        p.add_instruction(concat(axis), p1, p2, p3, a4);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;
        auto a1 = p.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        auto l1 = p.add_instruction(load{create_shape(1, 2, 8, 8), 0}, a1);
        auto p1 = p.add_instruction(simple_op{}, l1);
        auto l2 = p.add_instruction(load{create_shape(1, 3, 8, 8), 512}, a1);
        auto p2 = p.add_instruction(simple_op{}, l2);
        auto l3 = p.add_instruction(load{create_shape(1, 5, 8, 8), 1280}, a1);
        auto p3 = p.add_instruction(simple_op{}, l3);
        p.add_instruction(identity{}, a1, p1, p2, p3);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Same layout as `basic` but with a leading dimension of 2. Along axis 1 each
// input then maps to two separate ranges of the (standard, packed) output
// buffer rather than one contiguous range. The pass is therefore expected to
// leave the program unchanged.
TEST_CASE(wont_work)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto a1          = p.add_instruction(allocate{create_shape(2, 2, 8, 8)});
        auto p1          = p.add_instruction(simple_op{}, a1);
        auto a2          = p.add_instruction(allocate{create_shape(2, 3, 8, 8)});
        auto p2          = p.add_instruction(simple_op{}, a2);
        auto a3          = p.add_instruction(allocate{create_shape(2, 5, 8, 8)});
        auto p3          = p.add_instruction(simple_op{}, a3);
        std::size_t axis = 1;
        auto a4          = p.add_instruction(allocate{create_shape(2, 10, 8, 8)});
        p.add_instruction(concat(axis), p1, p2, p3, a4);
        return p;
    };
    // The pass should be a no-op, so the expected program is built identically.
    auto create_control_program = create_test_program;

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

int main(int argc, const char* argv[])
{
    // Delegate to the test harness from test.hpp, forwarding the command line.
    test::run(argc, argv);
}