eliminate_concat_test.cpp 13.8 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
4
5
6
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
7
8
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp>
9
10
11
12
13
#include <basic_ops.hpp>
#include <test.hpp>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
14
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
15
    migraphx::op::concat op;
16
17
18
19
20
21
22

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
26
27
28
29
    migraphx::value attributes() const
    {
        migraphx::value normalize;
        normalize["axis"] = migraphx::value::array{migraphx::op::normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

30
    std::string name() const { return "eliminate_concat::concat"; }
Shucai Xiao's avatar
Shucai Xiao committed
31
    migraphx::shape normalize_compute_shape(std::vector<migraphx::shape> inputs) const
32
    {
33
        inputs.pop_back();
Shucai Xiao's avatar
Shucai Xiao committed
34
        return op.normalize_compute_shape(std::move(inputs));
35
    }
Paul's avatar
Paul committed
36
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
37
38
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
39
40
41
42
43
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
44
struct concat_test_optimization
45
46
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
47
    std::string name() const { return "eliminate_concat::concat"; }
48
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
49
    std::string allocate() const { return "allocate"; }
50
    /// Return the lowered concat operator
Paul's avatar
Paul committed
51
    migraphx::op::concat get_concat(const migraphx::operation& op) const
52
    {
Paul's avatar
Paul committed
53
        return migraphx::any_cast<concat>(op).op;
54
55
56
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
57
/// Runs eliminate_concat (configured with concat_test_optimization) followed
/// by dead_code_elimination on @p m, modifying it in place.
void run_pass(migraphx::module& m)
{
    migraphx::eliminate_concat concat_elimination{concat_test_optimization{}};
    migraphx::run_passes(m, {concat_elimination, migraphx::dead_code_elimination{}});
}
63
64
65

struct allocate
{
Paul's avatar
Paul committed
66
    migraphx::shape s{};
67
68
69
70
71
72
73

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.s, "shape"));
    }

74
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
75
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
76
    {
77
        migraphx::check_shapes{inputs, *this}.has(0);
78
79
        return s;
    }
Paul's avatar
Paul committed
80
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
81
82
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
83
84
85
86
87
    {
        return {output_shape};
    }
};

Paul's avatar
Paul committed
88
struct simple_op
89
{
Paul's avatar
Paul committed
90
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
91
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
92
    {
93
        migraphx::check_shapes{inputs, *this}.has(1);
94
95
        return inputs.at(0);
    }
Paul's avatar
Paul committed
96
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
97
98
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
99
100
101
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
102
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
103
104
};

Paul's avatar
Paul committed
105
template <class... Ts>
Paul's avatar
Paul committed
106
107
108
109
110
migraphx::shape create_shape(Ts... xs)
{
    return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}

Paul's avatar
Paul committed
111
// Short alias for the load operator used when building the expected programs.
using load = migraphx::op::load;
// Short alias for the identity operator used when building the expected programs.
using identity = migraphx::op::identity;

TEST_CASE(simple)
{
    // Program under test: two one-element buffers, each fed through simple_op,
    // concatenated along axis 0 into a separately allocated two-element buffer.
    migraphx::module m1;
    {
        auto first_buf         = m1.add_instruction(allocate{create_shape(1)});
        auto first_op          = m1.add_instruction(simple_op{}, first_buf);
        auto second_buf        = m1.add_instruction(allocate{create_shape(1)});
        auto second_op         = m1.add_instruction(simple_op{}, second_buf);
        const std::size_t axis = 0;
        auto out_buf           = m1.add_instruction(allocate{create_shape(2)});
        m1.add_instruction(concat(axis), first_op, second_op, out_buf);
    }

    // Expected program: one two-element allocation, two loads from it at
    // offsets 0 and 4, and an identity in place of the concat.
    migraphx::module m2;
    {
        auto out_buf     = m2.add_instruction(allocate{create_shape(2)});
        auto first_view  = m2.add_instruction(load{create_shape(1), 0}, out_buf);
        auto first_op    = m2.add_instruction(simple_op{}, first_view);
        auto second_view = m2.add_instruction(load{create_shape(1), 4}, out_buf);
        auto second_op   = m2.add_instruction(simple_op{}, second_view);
        m2.add_instruction(identity{}, out_buf, first_op, second_op);
    }

    run_pass(m1);

    EXPECT(m1 == m2);
}

147
148
149
TEST_CASE(negative_axis1)
{
    // Fills a module with two 2x2 buffers, each fed through simple_op, and a
    // concat whose axis is -1 converted to std::size_t, writing into a 4x2 buffer.
    auto populate = [](migraphx::module& mod) {
        auto first_buf   = mod.add_instruction(allocate{create_shape(2, 2)});
        auto first_op    = mod.add_instruction(simple_op{}, first_buf);
        auto second_buf  = mod.add_instruction(allocate{create_shape(2, 2)});
        auto second_op   = mod.add_instruction(simple_op{}, second_buf);
        const auto axis  = static_cast<std::size_t>(-1);
        auto out_buf     = mod.add_instruction(allocate{create_shape(4, 2)});
        mod.add_instruction(concat(axis), first_op, second_op, out_buf);
    };

    // The expected program is the same as the input: the test expects the
    // passes to leave this program unchanged.
    migraphx::module m1;
    populate(m1);
    migraphx::module m2;
    populate(m2);

    run_pass(m1);

    EXPECT(m1 == m2);
}

TEST_CASE(negative_axis2)
{
    // Program under test: two 2x2 buffers, each fed through simple_op, with a
    // concat whose axis is -2 converted to std::size_t, writing into a 4x2 buffer.
    migraphx::module m1;
    {
        auto first_buf  = m1.add_instruction(allocate{create_shape(2, 2)});
        auto first_op   = m1.add_instruction(simple_op{}, first_buf);
        auto second_buf = m1.add_instruction(allocate{create_shape(2, 2)});
        auto second_op  = m1.add_instruction(simple_op{}, second_buf);
        const auto axis = static_cast<std::size_t>(-2);
        auto out_buf    = m1.add_instruction(allocate{create_shape(4, 2)});
        m1.add_instruction(concat(axis), first_op, second_op, out_buf);
    }

    // Expected program: one 4x2 allocation, two 2x2 loads from it at offsets
    // 0 and 16, and an identity in place of the concat.
    migraphx::module m2;
    {
        auto out_buf     = m2.add_instruction(allocate{create_shape(4, 2)});
        auto first_view  = m2.add_instruction(load{create_shape(2, 2), 0}, out_buf);
        auto first_op    = m2.add_instruction(simple_op{}, first_view);
        auto second_view = m2.add_instruction(load{create_shape(2, 2), 16}, out_buf);
        auto second_op   = m2.add_instruction(simple_op{}, second_view);
        m2.add_instruction(identity{}, out_buf, first_op, second_op);
    }

    run_pass(m1);

    EXPECT(m1 == m2);
}

TEST_CASE(negative_axis3)
{
    // Program under test: two 1x2x2 buffers, each fed through simple_op, with a
    // concat whose axis is -2 converted to std::size_t, writing into a 1x4x2 buffer.
    migraphx::module m1;
    {
        auto first_buf  = m1.add_instruction(allocate{create_shape(1, 2, 2)});
        auto first_op   = m1.add_instruction(simple_op{}, first_buf);
        auto second_buf = m1.add_instruction(allocate{create_shape(1, 2, 2)});
        auto second_op  = m1.add_instruction(simple_op{}, second_buf);
        const auto axis = static_cast<std::size_t>(-2);
        auto out_buf    = m1.add_instruction(allocate{create_shape(1, 4, 2)});
        m1.add_instruction(concat(axis), first_op, second_op, out_buf);
    }

    // Expected program: one 1x4x2 allocation, two 1x2x2 loads from it at
    // offsets 0 and 16, and an identity in place of the concat.
    migraphx::module m2;
    {
        auto out_buf     = m2.add_instruction(allocate{create_shape(1, 4, 2)});
        auto first_view  = m2.add_instruction(load{create_shape(1, 2, 2), 0}, out_buf);
        auto first_op    = m2.add_instruction(simple_op{}, first_view);
        auto second_view = m2.add_instruction(load{create_shape(1, 2, 2), 16}, out_buf);
        auto second_op   = m2.add_instruction(simple_op{}, second_view);
        m2.add_instruction(identity{}, out_buf, first_op, second_op);
    }

    run_pass(m1);

    EXPECT(m1 == m2);
}

236
237
238
TEST_CASE(reversed)
{
    // Program under test: same as the `simple` case, except the concat takes
    // its two data inputs in the opposite order (second producer first).
    migraphx::module m1;
    {
        auto first_buf         = m1.add_instruction(allocate{create_shape(1)});
        auto first_op          = m1.add_instruction(simple_op{}, first_buf);
        auto second_buf        = m1.add_instruction(allocate{create_shape(1)});
        auto second_op         = m1.add_instruction(simple_op{}, second_buf);
        const std::size_t axis = 0;
        auto out_buf           = m1.add_instruction(allocate{create_shape(2)});
        m1.add_instruction(concat(axis), second_op, first_op, out_buf);
    }

    // Expected program: the first producer reads from offset 4 and the second
    // from offset 0; the identity lists them in the concat's input order.
    migraphx::module m2;
    {
        auto out_buf     = m2.add_instruction(allocate{create_shape(2)});
        auto first_view  = m2.add_instruction(load{create_shape(1), 4}, out_buf);
        auto first_op    = m2.add_instruction(simple_op{}, first_view);
        auto second_view = m2.add_instruction(load{create_shape(1), 0}, out_buf);
        auto second_op   = m2.add_instruction(simple_op{}, second_view);
        m2.add_instruction(identity{}, out_buf, second_op, first_op);
    }

    run_pass(m1);

    EXPECT(m1 == m2);
}

Paul's avatar
Paul committed
269
270
TEST_CASE(nested)
{
    // Appends one inner concat (two one-element producers joined along axis 0
    // into a two-element buffer) and returns the concat instruction.
    auto add_inner_concat = [](migraphx::module& mod) {
        auto first_buf         = mod.add_instruction(allocate{create_shape(1)});
        auto first_op          = mod.add_instruction(simple_op{}, first_buf);
        auto second_buf        = mod.add_instruction(allocate{create_shape(1)});
        auto second_op         = mod.add_instruction(simple_op{}, second_buf);
        const std::size_t axis = 0;
        auto out_buf           = mod.add_instruction(allocate{create_shape(2)});
        return mod.add_instruction(concat(axis), first_op, second_op, out_buf);
    };

    // Appends the expected replacement for one inner concat: two one-element
    // loads from `parent` at offsets 0 and 4, and an identity over them.
    auto add_inner_expected = [](migraphx::module& mod, auto parent) {
        auto first_view  = mod.add_instruction(load{create_shape(1), 0}, parent);
        auto first_op    = mod.add_instruction(simple_op{}, first_view);
        auto second_view = mod.add_instruction(load{create_shape(1), 4}, parent);
        auto second_op   = mod.add_instruction(simple_op{}, second_view);
        return mod.add_instruction(identity{}, parent, first_op, second_op);
    };

    // Program under test: two inner concats joined by an outer concat along
    // axis 0 into a four-element buffer.
    migraphx::module m1;
    {
        auto inner_a           = add_inner_concat(m1);
        auto inner_b           = add_inner_concat(m1);
        const std::size_t axis = 0;
        auto out_buf           = m1.add_instruction(allocate{create_shape(4)});
        m1.add_instruction(concat(axis), inner_a, inner_b, out_buf);
    }

    // Expected program: one four-element allocation, split into two
    // two-element loads at offsets 0 and 8, each feeding an inner replacement.
    migraphx::module m2;
    {
        auto out_buf = m2.add_instruction(allocate{create_shape(4)});
        auto half_a  = m2.add_instruction(load{create_shape(2), 0}, out_buf);
        auto inner_a = add_inner_expected(m2, half_a);
        auto half_b  = m2.add_instruction(load{create_shape(2), 8}, out_buf);
        auto inner_b = add_inner_expected(m2, half_b);
        m2.add_instruction(identity{}, out_buf, inner_a, inner_b);
    }

    run_pass(m1);

    EXPECT(m1 == m2);
}

Paul's avatar
Paul committed
314
TEST_CASE(basic)
{
    // Program under test: three float buffers of shapes 1x2x8x8, 1x3x8x8 and
    // 1x5x8x8, each fed through simple_op, concatenated along axis 1 into a
    // 1x10x8x8 buffer.
    migraphx::module m1;
    {
        auto buf_a             = m1.add_instruction(allocate{create_shape(1, 2, 8, 8)});
        auto op_a              = m1.add_instruction(simple_op{}, buf_a);
        auto buf_b             = m1.add_instruction(allocate{create_shape(1, 3, 8, 8)});
        auto op_b              = m1.add_instruction(simple_op{}, buf_b);
        auto buf_c             = m1.add_instruction(allocate{create_shape(1, 5, 8, 8)});
        auto op_c              = m1.add_instruction(simple_op{}, buf_c);
        const std::size_t axis = 1;
        auto out_buf           = m1.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        m1.add_instruction(concat(axis), op_a, op_b, op_c, out_buf);
    }

    // Expected program: one 1x10x8x8 allocation, three loads from it at
    // offsets 0, 512 and 1280, and an identity in place of the concat.
    migraphx::module m2;
    {
        auto out_buf = m2.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        auto view_a  = m2.add_instruction(load{create_shape(1, 2, 8, 8), 0}, {out_buf});
        auto op_a    = m2.add_instruction(simple_op{}, view_a);
        auto view_b  = m2.add_instruction(load{create_shape(1, 3, 8, 8), 512}, {out_buf});
        auto op_b    = m2.add_instruction(simple_op{}, view_b);
        auto view_c  = m2.add_instruction(load{create_shape(1, 5, 8, 8), 1280}, {out_buf});
        auto op_c    = m2.add_instruction(simple_op{}, view_c);
        m2.add_instruction(identity{}, {out_buf, op_a, op_b, op_c});
    }

    run_pass(m1);

    EXPECT(m1 == m2);
}

Paul's avatar
Paul committed
357
TEST_CASE(wont_work)
{
    // Builds three float buffers of shapes 2x2x8x8, 2x3x8x8 and 2x5x8x8, each
    // fed through simple_op, concatenated along axis 1 into a 2x10x8x8 buffer.
    auto build_program = [] {
        migraphx::module mod;
        auto buf_a             = mod.add_instruction(allocate{create_shape(2, 2, 8, 8)});
        auto op_a              = mod.add_instruction(simple_op{}, buf_a);
        auto buf_b             = mod.add_instruction(allocate{create_shape(2, 3, 8, 8)});
        auto op_b              = mod.add_instruction(simple_op{}, buf_b);
        auto buf_c             = mod.add_instruction(allocate{create_shape(2, 5, 8, 8)});
        auto op_c              = mod.add_instruction(simple_op{}, buf_c);
        const std::size_t axis = 1;
        auto out_buf           = mod.add_instruction(allocate{create_shape(2, 10, 8, 8)});
        mod.add_instruction(concat(axis), op_a, op_b, op_c, out_buf);
        return mod;
    };

    // The expected program is built the same way as the input: the test
    // expects the passes to leave this program unchanged.
    // NOTE(review): presumably because the leading dimension is 2 rather than
    // 1 - confirm against the eliminate_concat implementation.
    auto m1 = build_program();
    auto m2 = build_program();

    run_pass(m1);

    EXPECT(m1 == m2);
}

Paul's avatar
Paul committed
401
// Entry point: forwards the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}