#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
11
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
12
    migraphx::op::concat op;
13
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
14
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
15
    {
wsttiger's avatar
wsttiger committed
16
        return op.compute_shape(std::move(inputs));
17
    }
Paul's avatar
Paul committed
18
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
19
20
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
21
22
23
24
25
    {
        return {output_shape};
    }
};

struct concat_test_optimization
27
28
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
29
    std::string name() const { return "eliminate_concat::concat"; }
30
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
31
    std::string allocate() const { return "allocate"; }
32
    /// Return the lowered concat operator
Paul's avatar
Paul committed
33
    migraphx::op::concat get_concat(const migraphx::operation& op) const
34
    {
Paul's avatar
Paul committed
35
        return migraphx::any_cast<concat>(op).op;
36
37
38
39
40
41
42
    }
};

struct eliminate_concat_target
{
    std::size_t align = 32;
    std::string name() const { return "eliminate_target"; }
Paul's avatar
Paul committed
43
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
44
    {
Paul's avatar
Paul committed
45
46
        return {migraphx::eliminate_concat{concat_test_optimization{}},
                migraphx::dead_code_elimination{}};
47
    }
Paul's avatar
Paul committed
48
    migraphx::context get_context() const { return {}; }
49
50
51
52
};

struct allocate
{
Paul's avatar
Paul committed
53
    migraphx::shape s{};
54
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
55
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
56
    {
Paul's avatar
Paul committed
57
        migraphx::check_shapes{inputs}.has(0);
58
59
        return s;
    }
Paul's avatar
Paul committed
60
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
61
62
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
63
64
65
66
67
    {
        return {output_shape};
    }
};

struct simple_op
69
{
Paul's avatar
Paul committed
70
    std::string name() const { return "simple_op"; }
Paul's avatar
Paul committed
71
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
72
    {
Paul's avatar
Paul committed
73
        migraphx::check_shapes{inputs}.has(1);
74
75
        return inputs.at(0);
    }
Paul's avatar
Paul committed
76
    migraphx::argument compute(migraphx::context&,
Paul's avatar
Paul committed
77
78
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>& args) const
79
80
81
    {
        return args.at(0);
    }
Paul's avatar
Paul committed
82
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
83
84
};

template <class... Ts>
Paul's avatar
Paul committed
86
87
88
89
90
migraphx::shape create_shape(Ts... xs)
{
    return migraphx::shape{migraphx::shape::float_type, {std::size_t(xs)...}};
}

// Short names for the operators that appear in the expected (control) programs.
using load = migraphx::op::load;
using identity = migraphx::op::identity;

TEST_CASE(simple)
{
    // Two one-element buffers, each produced by simple_op, concatenated along
    // axis 0 into a two-element allocation.
    auto create_test_program = [] {
        migraphx::program p;
        auto alloc_a     = p.add_instruction(allocate{create_shape(1)});
        auto op_a        = p.add_instruction(simple_op{}, alloc_a);
        auto alloc_b     = p.add_instruction(allocate{create_shape(1)});
        auto op_b        = p.add_instruction(simple_op{}, alloc_b);
        std::size_t axis = 0;
        auto output      = p.add_instruction(allocate{create_shape(2)});
        p.add_instruction(concat(axis), op_a, op_b, output);
        return p;
    };

    // Expected result: a single two-element buffer. Each simple_op writes into
    // a load at offset 0 or 4 (one float each), and an identity over the
    // buffer and both results replaces the concat.
    auto create_control_program = [] {
        migraphx::program p;
        auto buffer = p.add_instruction(allocate{create_shape(2)});
        auto view_a = p.add_instruction(load{create_shape(1), 0}, buffer);
        auto op_a   = p.add_instruction(simple_op{}, view_a);
        auto view_b = p.add_instruction(load{create_shape(1), 4}, buffer);
        auto op_b   = p.add_instruction(simple_op{}, view_b);
        p.add_instruction(identity{}, buffer, op_a, op_b);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

TEST_CASE(nested)
{
    // Appends one inner concat (two 1-element buffers into a 2-element
    // allocation) to `p` and returns the concat instruction.
    auto add_inner_concat = [](auto& p) {
        auto alloc_a     = p.add_instruction(allocate{create_shape(1)});
        auto op_a        = p.add_instruction(simple_op{}, alloc_a);
        auto alloc_b     = p.add_instruction(allocate{create_shape(1)});
        auto op_b        = p.add_instruction(simple_op{}, alloc_b);
        std::size_t axis = 0;
        auto output      = p.add_instruction(allocate{create_shape(2)});
        return p.add_instruction(concat(axis), op_a, op_b, output);
    };

    // Two inner concats, themselves concatenated into a 4-element allocation.
    auto create_test_program = [&] {
        migraphx::program p;
        auto first       = add_inner_concat(p);
        auto second      = add_inner_concat(p);
        std::size_t axis = 0;
        auto output      = p.add_instruction(allocate{create_shape(4)});
        p.add_instruction(concat(axis), first, second, output);
        return p;
    };

    // Expected form of one inner concat once it writes directly into `dest`.
    auto add_inner_control = [](auto& p, auto dest) {
        auto view_a = p.add_instruction(load{create_shape(1), 0}, dest);
        auto op_a   = p.add_instruction(simple_op{}, view_a);
        auto view_b = p.add_instruction(load{create_shape(1), 4}, dest);
        auto op_b   = p.add_instruction(simple_op{}, view_b);
        return p.add_instruction(identity{}, dest, op_a, op_b);
    };

    // Expected result: one 4-element buffer split into two 2-element loads,
    // at offsets 0 and 8, each hosting one inner result.
    auto create_control_program = [&] {
        migraphx::program p;
        auto buffer = p.add_instruction(allocate{create_shape(4)});
        auto half_a = p.add_instruction(load{create_shape(2), 0}, buffer);
        auto first  = add_inner_control(p, half_a);
        auto half_b = p.add_instruction(load{create_shape(2), 8}, buffer);
        auto second = add_inner_control(p, half_b);
        p.add_instruction(identity{}, buffer, first, second);
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

TEST_CASE(basic)
{
    // Three float tensors with 2, 3 and 5 channels, concatenated along axis 1
    // into a {1, 10, 8, 8} allocation. The leading dimension is 1.
    auto create_test_program = [] {
        migraphx::program p;
        auto alloc_a     = p.add_instruction(allocate{create_shape(1, 2, 8, 8)});
        auto op_a        = p.add_instruction(simple_op{}, alloc_a);
        auto alloc_b     = p.add_instruction(allocate{create_shape(1, 3, 8, 8)});
        auto op_b        = p.add_instruction(simple_op{}, alloc_b);
        auto alloc_c     = p.add_instruction(allocate{create_shape(1, 5, 8, 8)});
        auto op_c        = p.add_instruction(simple_op{}, alloc_c);
        std::size_t axis = 1;
        auto output      = p.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        p.add_instruction(concat(axis), op_a, op_b, op_c, output);
        return p;
    };

    // Expected result: a single {1, 10, 8, 8} buffer carved into three loads.
    // Offsets 0, 512 and 1280 are the running totals of the preceding float
    // tensors' sizes: 2*8*8*4 = 512 and 512 + 3*8*8*4 = 1280.
    auto create_control_program = [] {
        migraphx::program p;
        auto buffer = p.add_instruction(allocate{create_shape(1, 10, 8, 8)});
        auto view_a = p.add_instruction(load{create_shape(1, 2, 8, 8), 0}, {buffer});
        auto op_a   = p.add_instruction(simple_op{}, view_a);
        auto view_b = p.add_instruction(load{create_shape(1, 3, 8, 8), 512}, {buffer});
        auto op_b   = p.add_instruction(simple_op{}, view_b);
        auto view_c = p.add_instruction(load{create_shape(1, 5, 8, 8), 1280}, {buffer});
        auto op_c   = p.add_instruction(simple_op{}, view_c);
        p.add_instruction(identity{}, {buffer, op_a, op_b, op_c});
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

TEST_CASE(wont_work)
{
    // Same channel layout as `basic`, but with a leading dimension of 2. The
    // compiled program is expected to be unchanged, so the test program and
    // the control program come from the same builder.
    auto create_program = [] {
        migraphx::program p;
        auto alloc_a     = p.add_instruction(allocate{create_shape(2, 2, 8, 8)});
        auto op_a        = p.add_instruction(simple_op{}, alloc_a);
        auto alloc_b     = p.add_instruction(allocate{create_shape(2, 3, 8, 8)});
        auto op_b        = p.add_instruction(simple_op{}, alloc_b);
        auto alloc_c     = p.add_instruction(allocate{create_shape(2, 5, 8, 8)});
        auto op_c        = p.add_instruction(simple_op{}, alloc_c);
        std::size_t axis = 1;
        auto output      = p.add_instruction(allocate{create_shape(2, 10, 8, 8)});
        p.add_instruction(concat(axis), op_a, op_b, op_c, output);
        return p;
    };

    auto p1 = create_program();
    auto p2 = create_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

int main(int argc, const char* argv[])
{
    // Delegate to the test framework, forwarding the command-line arguments.
    test::run(argc, argv);
}