// eliminate_concat_test.cpp: unit tests for the migraphx::eliminate_concat pass.

#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
9
    concat(std::size_t axis) { op.axis = axis; }
Paul's avatar
Paul committed
10
    migraphx::op::concat op;
11
    std::string name() const { return "eliminate_concat::concat"; }
Paul's avatar
Paul committed
12
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
13
    {
wsttiger's avatar
wsttiger committed
14
        return op.compute_shape(std::move(inputs));
15
    }
Paul's avatar
Paul committed
16
17
18
    migraphx::argument compute(migraphx::context&,
                              const migraphx::shape& output_shape,
                              const std::vector<migraphx::argument>&) const
19
20
21
22
23
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
24
struct concat_test_optimization
25
26
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
27
    std::string name() const { return "eliminate_concat::concat"; }
28
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
29
    std::string allocate() const { return "allocate"; }
30
    /// Return the lowered concat operator
Paul's avatar
Paul committed
31
    migraphx::op::concat get_concat(const migraphx::operation& op) const
32
    {
Paul's avatar
Paul committed
33
        return migraphx::any_cast<concat>(op).op;
34
35
36
37
38
39
40
    }
};

struct eliminate_concat_target
{
    std::size_t align = 32;
    std::string name() const { return "eliminate_target"; }
Paul's avatar
Paul committed
41
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
42
    {
Paul's avatar
Paul committed
43
44
        return {migraphx::eliminate_concat{concat_test_optimization{}},
                migraphx::dead_code_elimination{}};
45
    }
Paul's avatar
Paul committed
46
    migraphx::context get_context() const { return {}; }
47
48
49
50
};

struct allocate
{
Paul's avatar
Paul committed
51
    migraphx::shape s{};
52
    std::string name() const { return "allocate"; }
Paul's avatar
Paul committed
53
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
54
    {
Paul's avatar
Paul committed
55
        migraphx::check_shapes{inputs}.has(0);
56
57
        return s;
    }
Paul's avatar
Paul committed
58
59
60
    migraphx::argument compute(migraphx::context&,
                              const migraphx::shape& output_shape,
                              const std::vector<migraphx::argument>&) const
61
62
63
64
65
66
67
68
    {
        return {output_shape};
    }
};

struct fred_op
{
    std::string name() const { return "fred_op"; }
Paul's avatar
Paul committed
69
    migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
70
    {
Paul's avatar
Paul committed
71
        migraphx::check_shapes{inputs}.has(1);
72
73
        return inputs.at(0);
    }
Paul's avatar
Paul committed
74
75
76
    migraphx::argument compute(migraphx::context&,
                              const migraphx::shape&,
                              const std::vector<migraphx::argument>& args) const
77
78
79
80
81
    {
        return args.at(0);
    }
};

Paul's avatar
Paul committed
82
TEST_CASE(basic)
83
84
{
    auto create_test_program = []() {
Paul's avatar
Paul committed
85
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
86
        auto a1 =
Paul's avatar
Paul committed
87
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
88
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
89
        auto a2 =
Paul's avatar
Paul committed
90
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
91
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
92
        auto a3 =
Paul's avatar
Paul committed
93
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
94
        auto p3          = p.add_instruction(fred_op{}, a3);
95
        std::size_t axis = 1;
Scott Thornton's avatar
Scott Thornton committed
96
        auto a4 =
Paul's avatar
Paul committed
97
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
98
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
99
        return p;
100
101
    };
    auto create_control_program = []() {
Paul's avatar
Paul committed
102
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
103
        auto a1 =
Paul's avatar
Paul committed
104
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
105
        auto l1 = p.add_instruction(
Paul's avatar
Paul committed
106
            migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
107
        auto p1 = p.add_instruction(fred_op{}, l1);
Scott Thornton's avatar
Scott Thornton committed
108
        auto l2 = p.add_instruction(
Paul's avatar
Paul committed
109
            migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
110
        auto p2 = p.add_instruction(fred_op{}, l2);
Scott Thornton's avatar
Scott Thornton committed
111
        auto l3 = p.add_instruction(
Paul's avatar
Paul committed
112
            migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280},
wsttiger's avatar
wsttiger committed
113
            {a1});
114
        auto p3 = p.add_instruction(fred_op{}, l3);
Paul's avatar
Paul committed
115
        p.add_instruction(migraphx::op::identity{}, {a1, p1, p2, p3});
116
117
118
119
120
121
122
123
124
125
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
126
TEST_CASE(wont_work)
127
128
{
    auto create_test_program = []() {
Paul's avatar
Paul committed
129
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
130
        auto a1 =
Paul's avatar
Paul committed
131
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
132
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
133
        auto a2 =
Paul's avatar
Paul committed
134
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
135
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
136
        auto a3 =
Paul's avatar
Paul committed
137
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
138
        auto p3          = p.add_instruction(fred_op{}, a3);
139
        std::size_t axis = 1;
Scott Thornton's avatar
Scott Thornton committed
140
        auto a4 =
Paul's avatar
Paul committed
141
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
142
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
143
        return p;
144
145
    };
    auto create_control_program = []() {
Paul's avatar
Paul committed
146
        migraphx::program p;
Scott Thornton's avatar
Scott Thornton committed
147
        auto a1 =
Paul's avatar
Paul committed
148
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
149
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
150
        auto a2 =
Paul's avatar
Paul committed
151
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
152
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
153
        auto a3 =
Paul's avatar
Paul committed
154
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
Scott Thornton's avatar
Scott Thornton committed
155
        auto p3          = p.add_instruction(fred_op{}, a3);
156
        std::size_t axis = 1;
Scott Thornton's avatar
Scott Thornton committed
157
        auto a4 =
Paul's avatar
Paul committed
158
            p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
159
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
160
        return p;
161
162
163
164
165
166
167
168
169
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
170
int main(int argc, const char* argv[]) { test::run(argc, argv); }