#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
12
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
13
    migraphx::op::concat op;
14
15
16
17
18
19
20

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

21
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
22
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
23
    {
wsttiger's avatar
wsttiger committed
24
        return op.compute_shape(std::move(inputs));
25
    }
Paul's avatar
Paul committed
26
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
27
28
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
29
30
31
32
33
    {
        return {output_shape};
    }
};

struct concat_test_optimization
35
36
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
37
    std::string name() const { return "eliminate_concat::concat"; }
38
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
39
    std::string allocate() const { return "allocate"; }
40
    /// Return the lowered concat operator
Paul's avatar
Paul committed
41
    migraphx::op::concat get_concat(const migraphx::operation& op) const
42
    {
Paul's avatar
Paul committed
43
        return migraphx::any_cast<concat>(op).op;
44
45
46
    }
};

/// Run eliminate_concat (driven by concat_test_optimization) on `p` in place,
/// followed by dead-code elimination to clean up what the rewrite leaves behind.
void run_pass(migraphx::program& p)
{
    const migraphx::eliminate_concat concat_pass{concat_test_optimization{}};
    migraphx::run_passes(p, {concat_pass, migraphx::dead_code_elimination{}});
}

struct allocate
{
Paul's avatar
Paul committed
56
    migraphx::shape s{};
57
58
59
60
61
62
63

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.s, "shape"));
    }

64
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
65
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
66
    {
Paul's avatar
Paul committed
67
        migraphx::check_shapes{inputs}.has(0);
68
69
        return s;
    }
Paul's avatar
Paul committed
70
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
71
72
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
73
74
75
76
77
    {
        return {output_shape};
    }
};

struct simple_op
79
{
Paul's avatar
Paul committed
80
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
81
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
82
    {
Paul's avatar
Paul committed
83
        migraphx::check_shapes{inputs}.has(1);
84
85
        return inputs.at(0);
    }
Paul's avatar
Paul committed
86
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
87
88
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
89
90
91
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
92
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
93
94
};

template <class... Ts>
Paul's avatar
Paul committed
96
97
98
99
100
migraphx::shape create_shape(Ts... xs)
{
    return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}

// Short names for the reference operators used in the expected (control) programs.
using load = migraphx::op::load;
using identity = migraphx::op::identity;

// Two one-element producers are concatenated along axis 0 into a two-element
// buffer. After the pass, each producer should operate on a load of that
// buffer at its byte offset, and the concat should become an identity.
TEST_CASE(simple)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto first_alloc       = p.add_instruction(allocate{create_shape(1)});
        auto first             = p.add_instruction(simple_op{}, first_alloc);
        auto second_alloc      = p.add_instruction(allocate{create_shape(1)});
        auto second            = p.add_instruction(simple_op{}, second_alloc);
        const std::size_t axis = 0;
        auto output            = p.add_instruction(allocate{create_shape(2)});
        p.add_instruction(concat(axis), first, second, output);
        return p;
    };
    auto create_control_program = [] {
        // The second slice starts one element (element.bytes() == 4) into the buffer.
        const auto element = create_shape(1);
        migraphx::program p;
        auto buffer      = p.add_instruction(allocate{create_shape(2)});
        auto first_view  = p.add_instruction(load{element, 0}, buffer);
        auto first       = p.add_instruction(simple_op{}, first_view);
        auto second_view = p.add_instruction(load{element, element.bytes()}, buffer);
        auto second      = p.add_instruction(simple_op{}, second_view);
        p.add_instruction(identity{}, buffer, first, second);
        return p;
    };

    auto actual   = create_test_program();
    auto expected = create_control_program();
    run_pass(actual);

    EXPECT(actual == expected);
}

// Like `simple`, but the concat lists its inputs in the opposite order of
// creation. The producer created first therefore lands in the second slice.
TEST_CASE(reversed)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto first_alloc       = p.add_instruction(allocate{create_shape(1)});
        auto first             = p.add_instruction(simple_op{}, first_alloc);
        auto second_alloc      = p.add_instruction(allocate{create_shape(1)});
        auto second            = p.add_instruction(simple_op{}, second_alloc);
        const std::size_t axis = 0;
        auto output            = p.add_instruction(allocate{create_shape(2)});
        p.add_instruction(concat(axis), second, first, output);
        return p;
    };
    auto create_control_program = [] {
        const auto element = create_shape(1);
        migraphx::program p;
        auto buffer      = p.add_instruction(allocate{create_shape(2)});
        auto first_view  = p.add_instruction(load{element, element.bytes()}, buffer);
        auto first       = p.add_instruction(simple_op{}, first_view);
        auto second_view = p.add_instruction(load{element, 0}, buffer);
        auto second      = p.add_instruction(simple_op{}, second_view);
        p.add_instruction(identity{}, buffer, second, first);
        return p;
    };

    auto actual   = create_test_program();
    auto expected = create_control_program();
    run_pass(actual);

    EXPECT(actual == expected);
}

// Two `simple`-style concats whose results are themselves concatenated.
// Both levels should be rewritten, so the inner buffers become loads of the
// outer four-element buffer.
TEST_CASE(nested)
{
    // Append two producers concatenated into a fresh two-element allocation; return the concat.
    auto add_inner_concat = [](auto& p) {
        auto first_alloc       = p.add_instruction(allocate{create_shape(1)});
        auto first             = p.add_instruction(simple_op{}, first_alloc);
        auto second_alloc      = p.add_instruction(allocate{create_shape(1)});
        auto second            = p.add_instruction(simple_op{}, second_alloc);
        const std::size_t axis = 0;
        auto output            = p.add_instruction(allocate{create_shape(2)});
        return p.add_instruction(concat(axis), first, second, output);
    };
    auto create_test_program = [&] {
        migraphx::program p;
        auto left              = add_inner_concat(p);
        auto right             = add_inner_concat(p);
        const std::size_t axis = 0;
        auto output            = p.add_instruction(allocate{create_shape(4)});
        p.add_instruction(concat(axis), left, right, output);
        return p;
    };
    // Expected rewrite of one inner concat whose storage is `buffer`.
    auto add_inner_control = [](auto& p, auto buffer) {
        const auto element = create_shape(1);
        auto first_view    = p.add_instruction(load{element, 0}, buffer);
        auto first         = p.add_instruction(simple_op{}, first_view);
        auto second_view   = p.add_instruction(load{element, element.bytes()}, buffer);
        auto second        = p.add_instruction(simple_op{}, second_view);
        return p.add_instruction(identity{}, buffer, first, second);
    };
    auto create_control_program = [&] {
        // Each inner concat owns one half of the outer buffer.
        const auto half = create_shape(2);
        migraphx::program p;
        auto buffer     = p.add_instruction(allocate{create_shape(4)});
        auto left_view  = p.add_instruction(load{half, 0}, buffer);
        auto left       = add_inner_control(p, left_view);
        auto right_view = p.add_instruction(load{half, half.bytes()}, buffer);
        auto right      = add_inner_control(p, right_view);
        p.add_instruction(identity{}, buffer, left, right);
        return p;
    };

    auto actual   = create_test_program();
    auto expected = create_control_program();
    run_pass(actual);

    EXPECT(actual == expected);
}

// Three producers with 2, 3 and 5 channels are concatenated along axis 1 into
// a 10-channel buffer. Every dimension left of the axis is 1, so each input
// is a contiguous slice of the output and can be placed directly inside it.
TEST_CASE(basic)
{
    auto create_test_program = [] {
        migraphx::program p;
        auto small_alloc       = p.add_instruction(allocate{create_shape(1, 2, 8, 8)});
        auto small             = p.add_instruction(simple_op{}, small_alloc);
        auto medium_alloc      = p.add_instruction(allocate{create_shape(1, 3, 8, 8)});
        auto medium            = p.add_instruction(simple_op{}, medium_alloc);
        auto large_alloc       = p.add_instruction(allocate{create_shape(1, 5, 8, 8)});
        auto large             = p.add_instruction(simple_op{}, large_alloc);
        const std::size_t axis = 1;
        auto output            = p.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        p.add_instruction(concat(axis), small, medium, large, output);
        return p;
    };
    auto create_control_program = [] {
        const auto small_shape  = create_shape(1, 2, 8, 8);
        const auto medium_shape = create_shape(1, 3, 8, 8);
        const auto large_shape  = create_shape(1, 5, 8, 8);
        migraphx::program p;
        auto buffer = p.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        // Slices are packed back to back: each byte offset is the size of all
        // preceding slices (0, 512, 1280).
        auto small_view = p.add_instruction(load{small_shape, 0}, buffer);
        auto small      = p.add_instruction(simple_op{}, small_view);
        auto medium_view =
            p.add_instruction(load{medium_shape, small_shape.bytes()}, buffer);
        auto medium = p.add_instruction(simple_op{}, medium_view);
        auto large_view = p.add_instruction(
            load{large_shape, small_shape.bytes() + medium_shape.bytes()}, buffer);
        auto large = p.add_instruction(simple_op{}, large_view);
        p.add_instruction(identity{}, buffer, small, medium, large);
        return p;
    };

    auto actual   = create_test_program();
    auto expected = create_control_program();
    run_pass(actual);

    EXPECT(actual == expected);
}

// Same layout as `basic` but with a batch of 2. The dimensions left of the
// concat axis are not all 1, so the inputs are not contiguous slices of the
// output, and the pass must leave the program unchanged.
TEST_CASE(wont_work)
{
    // The expected program is identical to the input, so one builder serves both.
    auto create_program = [] {
        migraphx::program p;
        auto small_alloc       = p.add_instruction(allocate{create_shape(2, 2, 8, 8)});
        auto small             = p.add_instruction(simple_op{}, small_alloc);
        auto medium_alloc      = p.add_instruction(allocate{create_shape(2, 3, 8, 8)});
        auto medium            = p.add_instruction(simple_op{}, medium_alloc);
        auto large_alloc       = p.add_instruction(allocate{create_shape(2, 5, 8, 8)});
        auto large             = p.add_instruction(simple_op{}, large_alloc);
        const std::size_t axis = 1;
        auto output            = p.add_instruction(allocate{create_shape(2, 10, 8, 8)});
        p.add_instruction(concat(axis), small, medium, large, output);
        return p;
    };

    auto actual   = create_program();
    auto expected = create_program();
    run_pass(actual);

    EXPECT(actual == expected);
}

int main(int argc, const char* argv[])
{
    // Hand the command line to the test driver (presumably runs the registered TEST_CASEs).
    test::run(argc, argv);
}