// eliminate_concat_test.cpp: unit tests for the migraphx eliminate_concat pass.
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
14
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
15
    migraphx::op::concat op;
16
17
18
19
20
21
22

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
26
27
28
29
    migraphx::value attributes() const
    {
        migraphx::value normalize;
        normalize["axis"] = migraphx::value::array{migraphx::op::normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

30
    std::string name() const { return "eliminate_concat::concat"; }
Shucai Xiao's avatar
Shucai Xiao committed
31
    migraphx::shape normalize_compute_shape(std::vector<migraphx::shape> inputs) const
32
    {
33
        inputs.pop_back();
Shucai Xiao's avatar
Shucai Xiao committed
34
        return op.normalize_compute_shape(std::move(inputs));
35
    }
Paul's avatar
Paul committed
36
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
37
38
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
39
40
41
42
43
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
44
struct concat_test_optimization
45
46
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
47
    std::string name() const { return "eliminate_concat::concat"; }
48
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
49
    std::string allocate() const { return "allocate"; }
50
    /// Return the lowered concat operator
Paul's avatar
Paul committed
51
    migraphx::op::concat get_concat(const migraphx::operation& op) const
52
    {
Paul's avatar
Paul committed
53
        return migraphx::any_cast<concat>(op).op;
54
55
56
    }
};

57
void run_pass(migraphx::program& p)
58
{
59
    migraphx::run_passes(*p.get_main_module(),
60
61
62
                         {migraphx::eliminate_concat{concat_test_optimization{}},
                          migraphx::dead_code_elimination{}});
}
63
64
65

struct allocate
{
Paul's avatar
Paul committed
66
    migraphx::shape s{};
67
68
69
70
71
72
73

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.s, "shape"));
    }

74
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
75
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
76
    {
77
        migraphx::check_shapes{inputs, *this}.has(0);
78
79
        return s;
    }
Paul's avatar
Paul committed
80
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
81
82
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
83
84
85
86
87
    {
        return {output_shape};
    }
};

Paul's avatar
Paul committed
88
struct simple_op
89
{
Paul's avatar
Paul committed
90
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
91
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
92
    {
93
        migraphx::check_shapes{inputs, *this}.has(1);
94
95
        return inputs.at(0);
    }
Paul's avatar
Paul committed
96
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
97
98
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
99
100
101
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
102
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
103
104
};

Paul's avatar
Paul committed
105
template <class... Ts>
Paul's avatar
Paul committed
106
107
108
109
110
migraphx::shape create_shape(Ts... xs)
{
    return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}

Paul's avatar
Paul committed
111
using load     = migraphx::op::load;
Paul's avatar
Paul committed
112
113
114
115
116
117
using identity = migraphx::op::identity;

// Two 1-element buffers are concatenated along axis 0. The expected result is a
// single 2-element allocation, with each producer writing into it through a
// load at byte offset 0 or 4 (one float apart).
TEST_CASE(simple)
{
    auto create_test_program = [] {
        migraphx::program p;

        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(1)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(1)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
        std::size_t axis = 0;
        auto a3          = mm->add_instruction(allocate{create_shape(2)});
        mm->add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(2)});
        auto l1  = mm->add_instruction(load{create_shape(1), 0}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(1), 4}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        mm->add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

149
150
151
152
TEST_CASE(negative_axis1)
{
    auto create_test_program = [] {
        migraphx::program p;
153
154
155
156
157
158

        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
159
        std::size_t axis = -1;
160
161
        auto a3          = mm->add_instruction(allocate{create_shape(4, 2)});
        mm->add_instruction(concat(axis), p1, p2, a3);
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
        return p;
    };
    auto create_control_program = create_test_program;

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Two 2x2 buffers are concatenated with axis = -2 (the leading axis of a 2-D
// shape). The expected result is one 4x2 allocation, with loads at byte
// offsets 0 and 16 (four floats apart).
TEST_CASE(negative_axis2)
{
    auto create_test_program = [] {
        migraphx::program p;

        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(2, 2)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
        // Intended as axis -2; wraps in std::size_t (see concat's constructor).
        std::size_t axis = -2;
        auto a3          = mm->add_instruction(allocate{create_shape(4, 2)});
        mm->add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(4, 2)});
        auto l1  = mm->add_instruction(load{create_shape(2, 2), 0}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(2, 2), 16}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        mm->add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Two 1x2x2 buffers are concatenated with axis = -2 (the middle axis of a 3-D
// shape). The expected result is one 1x4x2 allocation, with loads at byte
// offsets 0 and 16.
TEST_CASE(negative_axis3)
{
    auto create_test_program = [] {
        migraphx::program p;

        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(1, 2, 2)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(1, 2, 2)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
        // Intended as axis -2; wraps in std::size_t (see concat's constructor).
        std::size_t axis = -2;
        auto a3          = mm->add_instruction(allocate{create_shape(1, 4, 2)});
        mm->add_instruction(concat(axis), p1, p2, a3);
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(1, 4, 2)});
        auto l1  = mm->add_instruction(load{create_shape(1, 2, 2), 0}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(1, 2, 2), 16}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        mm->add_instruction(identity{}, a1, p1, p2);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

243
244
245
246
TEST_CASE(reversed)
{
    auto create_test_program = [] {
        migraphx::program p;
247
248
249
250
251
252

        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(1)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(1)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
253
        std::size_t axis = 0;
254
255
        auto a3          = mm->add_instruction(allocate{create_shape(2)});
        mm->add_instruction(concat(axis), p2, p1, a3);
256
257
258
259
        return p;
    };
    auto create_control_program = [] {
        migraphx::program p;
260
261
262
263
264
265
266
267

        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(allocate{create_shape(2)});
        auto l1  = mm->add_instruction(load{create_shape(1), 4}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(1), 0}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        mm->add_instruction(identity{}, a1, p2, p1);
268
269
270
271
272
273
274
275
276
277
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
278
279
280
TEST_CASE(nested)
{
    auto concat_test_program = [](auto& p) {
281
282
283
284
285
        auto* mm         = p.get_main_module();
        auto a1          = mm->add_instruction(allocate{create_shape(1)});
        auto p1          = mm->add_instruction(simple_op{}, a1);
        auto a2          = mm->add_instruction(allocate{create_shape(1)});
        auto p2          = mm->add_instruction(simple_op{}, a2);
Paul's avatar
Paul committed
286
        std::size_t axis = 0;
287
288
        auto a3          = mm->add_instruction(allocate{create_shape(2)});
        return mm->add_instruction(concat(axis), p1, p2, a3);
Paul's avatar
Paul committed
289
290
291
    };
    auto create_test_program = [&] {
        migraphx::program p;
292
        auto* mm         = p.get_main_module();
Paul's avatar
Paul committed
293
294
        auto concat1     = concat_test_program(p);
        auto concat2     = concat_test_program(p);
Paul's avatar
Paul committed
295
        std::size_t axis = 0;
296
297
        auto a1          = mm->add_instruction(allocate{create_shape(4)});
        mm->add_instruction(concat(axis), concat1, concat2, a1);
Paul's avatar
Paul committed
298
299
300
        return p;
    };
    auto concat_control_program = [](auto& p, auto a1) {
301
302
303
304
305
306
        auto* mm = p.get_main_module();
        auto l1  = mm->add_instruction(load{create_shape(1), 0}, a1);
        auto p1  = mm->add_instruction(simple_op{}, l1);
        auto l2  = mm->add_instruction(load{create_shape(1), 4}, a1);
        auto p2  = mm->add_instruction(simple_op{}, l2);
        return mm->add_instruction(identity{}, a1, p1, p2);
Paul's avatar
Paul committed
307
308
309
    };
    auto create_control_program = [&] {
        migraphx::program p;
310
311
312
        auto* mm     = p.get_main_module();
        auto a1      = mm->add_instruction(allocate{create_shape(4)});
        auto l1      = mm->add_instruction(load{create_shape(2), 0}, a1);
Paul's avatar
Paul committed
313
        auto concat1 = concat_control_program(p, l1);
314
        auto l2      = mm->add_instruction(load{create_shape(2), 8}, a1);
Paul's avatar
Paul committed
315
        auto concat2 = concat_control_program(p, l2);
316
        mm->add_instruction(identity{}, a1, concat1, concat2);
Paul's avatar
Paul committed
317
318
319
320
321
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
322
    run_pass(p1);
Paul's avatar
Paul committed
323
324
325
326

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
327
TEST_CASE(basic)
328
{
Paul's avatar
Paul committed
329
    auto create_test_program = [] {
Paul's avatar
Paul committed
330
        migraphx::program p;
331
332
333
334
335
336
337
338
339
340
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
        auto p1 = mm->add_instruction(simple_op{}, a1);
        auto a2 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
        auto p2 = mm->add_instruction(simple_op{}, a2);
        auto a3 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
        auto p3          = mm->add_instruction(simple_op{}, a3);
341
        std::size_t axis = 1;
342
        auto a4          = mm->add_instruction(
Paul's avatar
Paul committed
343
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
344
        mm->add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
345
        return p;
346
    };
Paul's avatar
Paul committed
347
    auto create_control_program = [] {
Paul's avatar
Paul committed
348
        migraphx::program p;
349
350
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
Paul's avatar
Paul committed
351
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
352
        auto l1 = mm->add_instruction(
Paul's avatar
Paul committed
353
            load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
354
355
        auto p1 = mm->add_instruction(simple_op{}, l1);
        auto l2 = mm->add_instruction(
Paul's avatar
Paul committed
356
            load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
357
358
        auto p2 = mm->add_instruction(simple_op{}, l2);
        auto l3 = mm->add_instruction(
Paul's avatar
Paul committed
359
            load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, {a1});
360
361
        auto p3 = mm->add_instruction(simple_op{}, l3);
        mm->add_instruction(identity{}, {a1, p1, p2, p3});
362
363
364
365
366
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
367
    run_pass(p1);
368
369
370
371

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
372
TEST_CASE(wont_work)
373
{
Paul's avatar
Paul committed
374
    auto create_test_program = [] {
Paul's avatar
Paul committed
375
        migraphx::program p;
376
377
378
379
380
381
382
383
384
385
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
        auto p1 = mm->add_instruction(simple_op{}, a1);
        auto a2 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
        auto p2 = mm->add_instruction(simple_op{}, a2);
        auto a3 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
        auto p3          = mm->add_instruction(simple_op{}, a3);
386
        std::size_t axis = 1;
387
        auto a4          = mm->add_instruction(
Paul's avatar
Paul committed
388
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
389
        mm->add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
390
        return p;
391
    };
Paul's avatar
Paul committed
392
    auto create_control_program = [] {
Paul's avatar
Paul committed
393
        migraphx::program p;
394
395
396
397
398
399
400
401
402
403
        auto* mm = p.get_main_module();
        auto a1  = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
        auto p1 = mm->add_instruction(simple_op{}, a1);
        auto a2 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
        auto p2 = mm->add_instruction(simple_op{}, a2);
        auto a3 = mm->add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
        auto p3          = mm->add_instruction(simple_op{}, a3);
404
        std::size_t axis = 1;
405
        auto a4          = mm->add_instruction(
Paul's avatar
Paul committed
406
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
407
        mm->add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
408
        return p;
409
410
411
412
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
413
    run_pass(p1);
414
415
416
417

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
418
int main(int argc, const char* argv[]) { test::run(argc, argv); }