eliminate_concat_test.cpp 9.43 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
4
5
6
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
7
8
9
10
11
#include <basic_ops.hpp>
#include <test.hpp>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
12
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
13
    migraphx::op::concat op;
14
15
16
17
18
19
20

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

21
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
22
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
23
    {
wsttiger's avatar
wsttiger committed
24
        return op.compute_shape(std::move(inputs));
25
    }
Paul's avatar
Paul committed
26
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
27
28
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
29
30
31
32
33
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
34
struct concat_test_optimization
35
36
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
37
    std::string name() const { return "eliminate_concat::concat"; }
38
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
39
    std::string allocate() const { return "allocate"; }
40
    /// Return the lowered concat operator
Paul's avatar
Paul committed
41
    migraphx::op::concat get_concat(const migraphx::operation& op) const
42
    {
Paul's avatar
Paul committed
43
        return migraphx::any_cast<concat>(op).op;
44
45
46
    }
};

47
void run_pass(migraphx::program& p)
48
{
49
50
51
52
    migraphx::run_passes(p,
                         {migraphx::eliminate_concat{concat_test_optimization{}},
                          migraphx::dead_code_elimination{}});
}
53
54
55

struct allocate
{
Paul's avatar
Paul committed
56
    migraphx::shape s{};
57
58
59
60
61
62
63

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.s, "shape"));
    }

64
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
65
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
66
    {
Paul's avatar
Paul committed
67
        migraphx::check_shapes{inputs}.has(0);
68
69
        return s;
    }
Paul's avatar
Paul committed
70
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
71
72
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
73
74
75
76
77
    {
        return {output_shape};
    }
};

Paul's avatar
Paul committed
78
struct simple_op
79
{
Paul's avatar
Paul committed
80
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
81
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
82
    {
Paul's avatar
Paul committed
83
        migraphx::check_shapes{inputs}.has(1);
84
85
        return inputs.at(0);
    }
Paul's avatar
Paul committed
86
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
87
88
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
89
90
91
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
92
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
93
94
};

Paul's avatar
Paul committed
95
template <class... Ts>
Paul's avatar
Paul committed
96
97
98
99
100
migraphx::shape create_shape(Ts... xs)
{
    return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}

Paul's avatar
Paul committed
101
using load     = migraphx::op::load;
Paul's avatar
Paul committed
102
103
104
105
106
107
using identity = migraphx::op::identity;

TEST_CASE(simple)
{
    // Input: two one-element producers, each writing into its own allocation,
    // concatenated along axis 0 into a two-element output allocation.
    auto create_test_program = [] {
        migraphx::program prog;
        auto lhs_buf = prog.add_instruction(allocate{create_shape(1)});
        auto lhs     = prog.add_instruction(simple_op{}, lhs_buf);
        auto rhs_buf = prog.add_instruction(allocate{create_shape(1)});
        auto rhs     = prog.add_instruction(simple_op{}, rhs_buf);
        auto out_buf = prog.add_instruction(allocate{create_shape(2)});
        prog.add_instruction(concat(0), lhs, rhs, out_buf);
        return prog;
    };
    // Expected: one two-element allocation; each producer writes into a load()
    // view of it at byte offsets 0 and 4 (one float apart), and identity
    // gathers the buffer and both producers.
    auto create_control_program = [] {
        migraphx::program prog;
        auto out_buf  = prog.add_instruction(allocate{create_shape(2)});
        auto lhs_view = prog.add_instruction(load{create_shape(1), 0}, out_buf);
        auto lhs      = prog.add_instruction(simple_op{}, lhs_view);
        auto rhs_view = prog.add_instruction(load{create_shape(1), 4}, out_buf);
        auto rhs      = prog.add_instruction(simple_op{}, rhs_view);
        prog.add_instruction(identity{}, out_buf, lhs, rhs);
        return prog;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

TEST_CASE(nested)
{
    // Appends two one-element producers concatenated along axis 0 into a
    // two-element allocation, and returns the concat instruction.
    auto add_inner_concat = [](auto& prog) {
        auto lhs_buf = prog.add_instruction(allocate{create_shape(1)});
        auto lhs     = prog.add_instruction(simple_op{}, lhs_buf);
        auto rhs_buf = prog.add_instruction(allocate{create_shape(1)});
        auto rhs     = prog.add_instruction(simple_op{}, rhs_buf);
        auto out_buf = prog.add_instruction(allocate{create_shape(2)});
        return prog.add_instruction(concat(0), lhs, rhs, out_buf);
    };
    // Input: two inner concats, themselves concatenated along axis 0 into a
    // four-element allocation.
    auto create_test_program = [&] {
        migraphx::program prog;
        auto first   = add_inner_concat(prog);
        auto second  = add_inner_concat(prog);
        auto out_buf = prog.add_instruction(allocate{create_shape(4)});
        prog.add_instruction(concat(0), first, second, out_buf);
        return prog;
    };
    // Expected form of one inner concat after the pass, written against `buf`:
    // one-element load() views at byte offsets 0 and 4, then identity.
    auto add_inner_control = [](auto& prog, auto buf) {
        auto lhs_view = prog.add_instruction(load{create_shape(1), 0}, buf);
        auto lhs      = prog.add_instruction(simple_op{}, lhs_view);
        auto rhs_view = prog.add_instruction(load{create_shape(1), 4}, buf);
        auto rhs      = prog.add_instruction(simple_op{}, rhs_view);
        return prog.add_instruction(identity{}, buf, lhs, rhs);
    };
    // Expected: a single four-element allocation split into two-element views
    // at byte offsets 0 and 8, each hosting one rewritten inner concat.
    auto create_control_program = [&] {
        migraphx::program prog;
        auto out_buf     = prog.add_instruction(allocate{create_shape(4)});
        auto first_view  = prog.add_instruction(load{create_shape(2), 0}, out_buf);
        auto first       = add_inner_control(prog, first_view);
        auto second_view = prog.add_instruction(load{create_shape(2), 8}, out_buf);
        auto second      = add_inner_control(prog, second_view);
        prog.add_instruction(identity{}, out_buf, first, second);
        return prog;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    run_pass(p1);

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
180
TEST_CASE(basic)
181
{
Paul's avatar
Paul committed
182
    auto create_test_program = [] {
Paul's avatar
Paul committed
183
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
184
        auto a1 =
Paul's avatar
Paul committed
185
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
Paul's avatar
Paul committed
186
        auto p1 = p.add_instruction(simple_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
187
        auto a2 =
Paul's avatar
Paul committed
188
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
Paul's avatar
Paul committed
189
        auto p2 = p.add_instruction(simple_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
190
        auto a3 =
Paul's avatar
Paul committed
191
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
Paul's avatar
Paul committed
192
        auto p3          = p.add_instruction(simple_op{}, a3);
193
        std::size_t axis = 1;
Paul's avatar
Paul committed
194
195
        auto a4          = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
196
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
197
        return p;
198
    };
Paul's avatar
Paul committed
199
    auto create_control_program = [] {
Paul's avatar
Paul committed
200
        migraphx::program p;
Paul's avatar
Paul committed
201
202
        auto a1 = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
203
        auto l1 = p.add_instruction(
Paul's avatar
Paul committed
204
            load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
Paul's avatar
Paul committed
205
        auto p1 = p.add_instruction(simple_op{}, l1);
Scott Thornton's avatar
Scott Thornton committed
206
        auto l2 = p.add_instruction(
Paul's avatar
Paul committed
207
            load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
Paul's avatar
Paul committed
208
        auto p2 = p.add_instruction(simple_op{}, l2);
Scott Thornton's avatar
Scott Thornton committed
209
        auto l3 = p.add_instruction(
Paul's avatar
Paul committed
210
            load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, {a1});
Paul's avatar
Paul committed
211
212
        auto p3 = p.add_instruction(simple_op{}, l3);
        p.add_instruction(identity{}, {a1, p1, p2, p3});
213
214
215
216
217
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
218
    run_pass(p1);
219
220
221
222

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
223
TEST_CASE(wont_work)
224
{
Paul's avatar
Paul committed
225
    auto create_test_program = [] {
Paul's avatar
Paul committed
226
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
227
        auto a1 =
Paul's avatar
Paul committed
228
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
Paul's avatar
Paul committed
229
        auto p1 = p.add_instruction(simple_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
230
        auto a2 =
Paul's avatar
Paul committed
231
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
Paul's avatar
Paul committed
232
        auto p2 = p.add_instruction(simple_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
233
        auto a3 =
Paul's avatar
Paul committed
234
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
Paul's avatar
Paul committed
235
        auto p3          = p.add_instruction(simple_op{}, a3);
236
        std::size_t axis = 1;
Paul's avatar
Paul committed
237
238
        auto a4          = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
239
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
240
        return p;
241
    };
Paul's avatar
Paul committed
242
    auto create_control_program = [] {
Paul's avatar
Paul committed
243
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
244
        auto a1 =
Paul's avatar
Paul committed
245
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
Paul's avatar
Paul committed
246
        auto p1 = p.add_instruction(simple_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
247
        auto a2 =
Paul's avatar
Paul committed
248
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
Paul's avatar
Paul committed
249
        auto p2 = p.add_instruction(simple_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
250
        auto a3 =
Paul's avatar
Paul committed
251
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
Paul's avatar
Paul committed
252
        auto p3          = p.add_instruction(simple_op{}, a3);
253
        std::size_t axis = 1;
Paul's avatar
Paul committed
254
255
        auto a4          = p.add_instruction(
            allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
256
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
257
        return p;
258
259
260
261
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
262
    run_pass(p1);
263
264
265
266

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
267
int main(int argc, const char* argv[]) { test::run(argc, argv); }