// eliminate_concat_test.cpp - unit tests for the migraphx eliminate_concat pass.
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>

#include <basic_ops.hpp>
#include <test.hpp>

#include <cstdint>
#include <stdexcept>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
12
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
13
    migraphx::op::concat op;
14
15
16
17
18
19
20

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

21
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
22
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
23
    {
24
        inputs.pop_back();
wsttiger's avatar
wsttiger committed
25
        return op.compute_shape(std::move(inputs));
26
    }
Paul's avatar
Paul committed
27
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
28
29
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
30
31
32
33
34
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
35
struct concat_test_optimization
36
37
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
38
    std::string name() const { return "eliminate_concat::concat"; }
39
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
40
    std::string allocate() const { return "allocate"; }
41
    /// Return the lowered concat operator
Paul's avatar
Paul committed
42
    migraphx::op::concat get_concat(const migraphx::operation& op) const
43
    {
Paul's avatar
Paul committed
44
        return migraphx::any_cast<concat>(op).op;
45
46
47
    }
};

48
void run_pass(migraphx::program& p)
49
{
50
    migraphx::run_passes(*p.get_main_module(),
51
52
53
                         {migraphx::eliminate_concat{concat_test_optimization{}},
                          migraphx::dead_code_elimination{}});
}
54
55
56

struct allocate
{
Paul's avatar
Paul committed
57
    migraphx::shape s{};
58
59
60
61
62
63
64

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.s, "shape"));
    }

65
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
66
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
67
    {
68
        migraphx::check_shapes{inputs, *this}.has(0);
69
70
        return s;
    }
Paul's avatar
Paul committed
71
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
72
73
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
74
75
76
77
78
    {
        return {output_shape};
    }
};

Paul's avatar
Paul committed
79
struct simple_op
80
{
Paul's avatar
Paul committed
81
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
82
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
83
    {
84
        migraphx::check_shapes{inputs, *this}.has(1);
85
86
        return inputs.at(0);
    }
Paul's avatar
Paul committed
87
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
88
89
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
90
91
92
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
93
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
94
95
};

Paul's avatar
Paul committed
96
template <class... Ts>
Paul's avatar
Paul committed
97
98
99
100
101
migraphx::shape create_shape(Ts... xs)
{
    return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}

Paul's avatar
Paul committed
102
using load     = migraphx::op::load;
Paul's avatar
Paul committed
103
104
105
106
107
108
using identity = migraphx::op::identity;

// Two 1-element producers concatenated on axis 0. Expected result: one
// 2-element allocation up front, each producer writing through a load of it
// (offsets 0 and 4), and the concat replaced by an identity.
TEST_CASE(simple)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto& m          = *p.get_main_module();
        auto first_buf   = m.add_instruction(allocate{create_shape(1)});
        auto first       = m.add_instruction(simple_op{}, first_buf);
        auto second_buf  = m.add_instruction(allocate{create_shape(1)});
        auto second      = m.add_instruction(simple_op{}, second_buf);
        std::size_t axis = 0;
        auto output_buf  = m.add_instruction(allocate{create_shape(2)});
        m.add_instruction(concat(axis), first, second, output_buf);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;
        auto& m          = *p.get_main_module();
        auto output_buf  = m.add_instruction(allocate{create_shape(2)});
        auto first_view  = m.add_instruction(load{create_shape(1), 0}, output_buf);
        auto first       = m.add_instruction(simple_op{}, first_view);
        auto second_view = m.add_instruction(load{create_shape(1), 4}, output_buf);
        auto second      = m.add_instruction(simple_op{}, second_view);
        m.add_instruction(identity{}, output_buf, first, second);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

140
141
142
143
TEST_CASE(negative_axis1)
{
    auto create_test_program = [] {
        migraphx::program p;
144
145
146
147
148
149

        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
150
        std::size_t axis = -1;
151
152
        auto a3          = mm->add_instruction(allocate{create_shape(4, 2)});
        mm->add_instruction(concat(axis), p1, p2, a3);
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
        return p;
    };
    auto create_control_program = create_test_program;

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Concatenating two {2, 2} tensors on axis -2 (axis 0 of the 2-D inputs) into
// a {4, 2} buffer. Expected result: the buffer is allocated up front, each
// producer writes through a {2, 2} load of it (byte offsets 0 and 16), and the
// concat becomes an identity.
TEST_CASE(negative_axis2)
{
    auto create_test_program = [] {
        migraphx::program p;

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p1  = mm->add_instruction(simple_op{}, a1);
        auto a2  = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p2  = mm->add_instruction(simple_op{}, a2);
        // Signed so the negative axis is stored as-is rather than wrapping
        // through std::size_t.
        std::int64_t axis = -2;
        auto a3           = mm->add_instruction(allocate{create_shape(4, 2)});
        mm->add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(4, 2)});
        auto l1  = mm->add_instruction(load{create_shape(2, 2), 0}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(2, 2), 16}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        mm->add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Concatenating two {1, 2, 2} tensors on axis -2 (axis 1 of the 3-D inputs)
// into a {1, 4, 2} buffer. The leading dimension is 1. Expected result: the
// buffer is allocated up front, each producer writes through a {1, 2, 2} load
// of it (byte offsets 0 and 16), and the concat becomes an identity.
TEST_CASE(negative_axis3)
{
    auto create_test_program = [] {
        migraphx::program p;

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(1, 2, 2)});
        auto p1  = mm->add_instruction(simple_op{}, a1);
        auto a2  = mm->add_instruction(allocate{create_shape(1, 2, 2)});
        auto p2  = mm->add_instruction(simple_op{}, a2);
        // Signed so the negative axis is stored as-is rather than wrapping
        // through std::size_t.
        std::int64_t axis = -2;
        auto a3           = mm->add_instruction(allocate{create_shape(1, 4, 2)});
        mm->add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(1, 4, 2)});
        auto l1  = mm->add_instruction(load{create_shape(1, 2, 2), 0}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(1, 2, 2), 16}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        mm->add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

234
235
236
237
TEST_CASE(reversed)
{
    auto create_test_program = [] {
        migraphx::program p;
238
239
240
241
242
243

        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(1)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(1)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
244
        std::size_t axis = 0;
245
246
        auto a3          = mm->add_instruction(allocate{create_shape(2)});
        mm->add_instruction(concat(axis), p2, p1, a3);
247
248
249
250
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;
251
252
253
254
255
256
257
258

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(2)});
        auto l1  = mm->add_instruction(load{create_shape(1), 4}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(1), 0}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        mm->add_instruction(identity{}, a1, p2, p1);
259
260
261
262
263
264
265
266
267
268
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
269
270
271
TEST_CASE(nested)
{
    auto concat_test_program = [](auto& p) {
272
273
274
275
276
        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(1)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(1)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
Paul's avatar
Paul committed
277
        std::size_t axis = 0;
278
279
        auto a3          = mm->add_instruction(allocate{create_shape(2)});
        return mm->add_instruction(concat(axis), p1, p2, a3);
Paul's avatar
Paul committed
280
281
282
    };
    auto create_test_program = [&] {
        migraphx::program p;
283
        auto* mm         = p.get_main_module();
Paul's avatar
Paul committed
284
285
        auto concat1     = concat_test_program(p);
        auto concat2     = concat_test_program(p);
Paul's avatar
Paul committed
286
        std::size_t axis = 0;
287
288
        auto a1          = mm->add_instruction(allocate{create_shape(4)});
        mm->add_instruction(concat(axis), concat1, concat2, a1);
Paul's avatar
Paul committed
289
290
291
        return p;
    };
    auto concat_control_program = [](auto& p, auto a1) {
292
293
294
295
296
297
        auto* mm = p.get_main_module();
        auto l1  = mm->add_instruction(load{create_shape(1), 0}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(1), 4}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        return mm->add_instruction(identity{}, a1, p1, p2);
Paul's avatar
Paul committed
298
299
300
    };
    auto create_control_program = [&] {
        migraphx::program p;
301
302
303
        auto* mm     = p.get_main_module();
        auto a1      = mm->add_instruction(allocate{create_shape(4)});
        auto l1      = mm->add_instruction(load{create_shape(2), 0}, a1);
Paul's avatar
Paul committed
304
        auto concat1 = concat_control_program(p, l1);
305
        auto l2      = mm->add_instruction(load{create_shape(2), 8}, a1);
Paul's avatar
Paul committed
306
        auto concat2 = concat_control_program(p, l2);
307
        mm->add_instruction(identity{}, a1, concat1, concat2);
Paul's avatar
Paul committed
308
309
310
311
312
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
313
    run_pass(p1);
Paul's avatar
Paul committed
314
315
316
317

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
318
TEST_CASE(basic)
319
{
Paul's avatar
Paul committed
320
    auto create_test_program = [] {
Paul's avatar
Paul committed
321
        migraphx::program p;
322
323
324
325
326
327
328
329
330
331
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
        auto p1 = mm->add_instruction(simple_op{}, a1);
        auto a2 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
        auto p2 = mm->add_instruction(simple_op{}, a2);
        auto a3 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
        auto p3          = mm->add_instruction(simple_op{}, a3);
332
        std::size_t axis = 1;
333
        auto a4          = mm->add_instruction(
Paul's avatar
Paul committed
334
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
335
        mm->add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
336
        return p;
337
    };
Paul's avatar
Paul committed
338
    auto create_control_program = [] {
Paul's avatar
Paul committed
339
        migraphx::program p;
340
341
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
Paul's avatar
Paul committed
342
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
343
        auto l1 = mm->add_instruction(
Paul's avatar
Paul committed
344
            load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
345
346
        auto p1 = mm->add_instruction(simple_op{}, l1);
        auto l2 = mm->add_instruction(
Paul's avatar
Paul committed
347
            load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
348
349
        auto p2 = mm->add_instruction(simple_op{}, l2);
        auto l3 = mm->add_instruction(
Paul's avatar
Paul committed
350
            load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, {a1});
351
352
        auto p3 = mm->add_instruction(simple_op{}, l3);
        mm->add_instruction(identity{}, {a1, p1, p2, p3});
353
354
355
356
357
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
358
    run_pass(p1);
359
360
361
362

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
363
TEST_CASE(wont_work)
364
{
Paul's avatar
Paul committed
365
    auto create_test_program = [] {
Paul's avatar
Paul committed
366
        migraphx::program p;
367
368
369
370
371
372
373
374
375
376
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
        auto p1 = mm->add_instruction(simple_op{}, a1);
        auto a2 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
        auto p2 = mm->add_instruction(simple_op{}, a2);
        auto a3 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
        auto p3          = mm->add_instruction(simple_op{}, a3);
377
        std::size_t axis = 1;
378
        auto a4          = mm->add_instruction(
Paul's avatar
Paul committed
379
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
380
        mm->add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
381
        return p;
382
    };
Paul's avatar
Paul committed
383
    auto create_control_program = [] {
Paul's avatar
Paul committed
384
        migraphx::program p;
385
386
387
388
389
390
391
392
393
394
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
        auto p1 = mm->add_instruction(simple_op{}, a1);
        auto a2 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
        auto p2 = mm->add_instruction(simple_op{}, a2);
        auto a3 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
        auto p3          = mm->add_instruction(simple_op{}, a3);
395
        std::size_t axis = 1;
396
        auto a4          = mm->add_instruction(
Paul's avatar
Paul committed
397
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
398
        mm->add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
399
        return p;
400
401
402
403
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
404
    run_pass(p1);
405
406
407
408

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
409
int main(int argc, const char* argv[]) { test::run(argc, argv); }