// eliminate_concat_test.cpp
//
// Unit tests for the migraph::eliminate_concat pass.
#include <migraph/eliminate_concat.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

struct concat
{
Scott Thornton's avatar
Scott Thornton committed
9
    concat(std::size_t axis) { op.axis = axis; }
10
11
12
13
    migraph::op::concat op;
    std::string name() const { return "eliminate_concat::concat"; }
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
    {
wsttiger's avatar
wsttiger committed
14
        return op.compute_shape(std::move(inputs));
15
    }
16
    migraph::argument compute(migraph::context&,
Scott Thornton's avatar
Scott Thornton committed
17
                              const migraph::shape& output_shape,
18
                              const std::vector<migraph::argument>&) const
19
20
21
22
23
    {
        return {output_shape};
    }
};

Scott Thornton's avatar
Scott Thornton committed
24
struct concat_test_optimization
25
26
{
    /// A unique name used to identify the concat optimization
Scott Thornton's avatar
Scott Thornton committed
27
    std::string name() const { return "eliminate_concat::concat"; }
28
    /// A unique name used to identify the allocate operator
Scott Thornton's avatar
Scott Thornton committed
29
    std::string allocate() const { return "allocate"; }
30
31
32
33
34
35
36
37
38
39
40
41
42
    /// Return the lowered concat operator
    migraph::op::concat get_concat(const migraph::operation& op) const
    {
        return migraph::any_cast<concat>(op).op;
    }
};

struct eliminate_concat_target
{
    std::size_t align = 32;
    std::string name() const { return "eliminate_target"; }
    std::vector<migraph::pass> get_passes(migraph::context&) const
    {
Scott Thornton's avatar
Scott Thornton committed
43
44
        return {migraph::eliminate_concat{concat_test_optimization{}},
                migraph::dead_code_elimination{}};
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    }
    migraph::context get_context() const { return {}; }
};

/// Test allocation operator.
///
/// Takes no inputs. It reports the shape stored in `s` as its output shape.
struct allocate
{
    /// Shape of the buffer this instruction stands for.
    migraph::shape s{};

    std::string name() const { return "allocate"; }

    migraph::shape compute_shape(const std::vector<migraph::shape>& input_shapes) const
    {
        // An allocation must not have any operands.
        migraph::check_shapes{input_shapes}.has(0);
        return this->s;
    }

    /// Returns an argument of the requested output shape; no inputs are read.
    migraph::argument compute(migraph::context&,
                              const migraph::shape& result_shape,
                              const std::vector<migraph::argument>&) const
    {
        return migraph::argument{result_shape};
    }
};

/// Pass-through test operator.
///
/// Takes exactly one input and returns it unchanged. The test programs use it
/// as a consumer of each allocation / load.
struct fred_op
{
    std::string name() const { return "fred_op"; }

    /// Output shape equals the shape of the single input.
    migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
    {
        migraph::check_shapes{inputs}.has(1);
        return inputs.at(0);
    }

    /// Returns the single input argument unchanged.
    migraph::argument compute(migraph::context&,
                              const migraph::shape&,
                              const std::vector<migraph::argument>& args) const
    {
        return args.at(0);
    }
};

Paul's avatar
Paul committed
82
TEST_CASE(basic)
83
84
85
{
    auto create_test_program = []() {
        migraph::program p;
Scott Thornton's avatar
Scott Thornton committed
86
87
        auto a1 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}});
88
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
89
90
        auto a2 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}});
91
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
92
93
94
        auto a3 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}});
        auto p3          = p.add_instruction(fred_op{}, a3);
95
        std::size_t axis = 1;
Scott Thornton's avatar
Scott Thornton committed
96
97
        auto a4 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
98
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
99
        return p;
100
101
102
    };
    auto create_control_program = []() {
        migraph::program p;
Scott Thornton's avatar
Scott Thornton committed
103
104
105
106
        auto a1 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
        auto l1 = p.add_instruction(
            migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
107
        auto p1 = p.add_instruction(fred_op{}, l1);
Scott Thornton's avatar
Scott Thornton committed
108
        auto l2 = p.add_instruction(
wsttiger's avatar
wsttiger committed
109
            migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
110
        auto p2 = p.add_instruction(fred_op{}, l2);
Scott Thornton's avatar
Scott Thornton committed
111
        auto l3 = p.add_instruction(
wsttiger's avatar
wsttiger committed
112
113
            migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}, 1280},
            {a1});
114
        auto p3 = p.add_instruction(fred_op{}, l3);
wsttiger's avatar
wsttiger committed
115
        p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3});
116
117
118
119
120
121
122
123
124
125
        return p;
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
126
TEST_CASE(wont_work)
127
128
129
{
    auto create_test_program = []() {
        migraph::program p;
Scott Thornton's avatar
Scott Thornton committed
130
131
        auto a1 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
132
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
133
134
        auto a2 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
135
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
136
137
138
        auto a3 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
        auto p3          = p.add_instruction(fred_op{}, a3);
139
        std::size_t axis = 1;
Scott Thornton's avatar
Scott Thornton committed
140
141
        auto a4 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
wsttiger's avatar
wsttiger committed
142
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
143
        return p;
144
145
146
    };
    auto create_control_program = []() {
        migraph::program p;
Scott Thornton's avatar
Scott Thornton committed
147
148
        auto a1 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
149
        auto p1 = p.add_instruction(fred_op{}, a1);
Scott Thornton's avatar
Scott Thornton committed
150
151
        auto a2 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
152
        auto p2 = p.add_instruction(fred_op{}, a2);
Scott Thornton's avatar
Scott Thornton committed
153
154
155
        auto a3 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
        auto p3          = p.add_instruction(fred_op{}, a3);
156
        std::size_t axis = 1;
Scott Thornton's avatar
Scott Thornton committed
157
158
        auto a4 =
            p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
159
        p.add_instruction(concat(axis), p1, p2, p3, a4);
Scott Thornton's avatar
Scott Thornton committed
160
        return p;
161
162
163
164
165
166
167
168
169
    };

    auto p1 = create_test_program();
    auto p2 = create_control_program();
    p1.compile(eliminate_concat_target{});

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
170
int main(int argc, const char* argv[]) { test::run(argc, argv); }