// rewrite_pooling_test.cpp: unit tests for the migraphx::rewrite_pooling pass.
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>

#include <numeric>
#include <string>
#include <vector>
#include <migraphx/make_op.hpp>

#include <migraphx/verify.hpp>

bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; }
// Runs the pass under test (rewrite_pooling) on the main module of `prog`,
// followed by dead_code_elimination to remove instructions that are no
// longer used afterwards.
static void opt_pooling(migraphx::program& prog)
{
    migraphx::module& main_mod = *prog.get_main_module();
    migraphx::rewrite_pooling{}.apply(main_mod);
    migraphx::dead_code_elimination{}.apply(main_mod);
}

// A pooling whose window {3, 4, 5} matches the three trailing dims of the
// input, with zero padding and unit stride, must be rewritten by the pass into
// reshape -> reduce -> reshape. Checked for both "average" and "max" modes.
TEST_CASE(rewrite_pooling_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    // x -> pooling(mode, window == trailing dims) -> return
    auto make_pooling = [&](const std::string& mode) {
        migraphx::program prog;
        auto* mod    = prog.get_main_module();
        auto x       = mod->add_parameter("x", s);
        auto pool_op = migraphx::make_op("pooling",
                                         {{"mode", mode},
                                          {"padding", {0, 0, 0}},
                                          {"stride", {1, 1, 1}},
                                          {"lengths", {3, 4, 5}}});
        mod->add_return({mod->add_instruction(pool_op, x)});
        return prog;
    };

    // x -> reshape{4, -1} -> reduce_op -> reshape{2, 2, 1, 1, 1} -> return
    // The first reshape collapses the two leading (non-pooled) dims into 4 rows.
    auto make_expected = [&](const migraphx::operation& reduce_op) {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter("x", s);
        auto flat = mod->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), x);
        auto reduced  = mod->add_instruction(reduce_op, flat);
        auto restored = mod->add_instruction(
            migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), reduced);
        mod->add_return({restored});
        return prog;
    };

    // Structural comparison: the optimized pooling program must equal the
    // hand-built reduce program.
    auto check = [&](const std::string& mode, const migraphx::operation& reduce_op) {
        migraphx::program actual   = make_pooling(mode);
        migraphx::program expected = make_expected(reduce_op);
        opt_pooling(actual);
        EXPECT(actual == expected);
    };

    check("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
    check("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
}

// Negative case: average pooling with non-zero padding ({0, 1, 0}).
// The test expects the pass to leave the program unchanged.
TEST_CASE(rewrite_avepooling_na1_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    migraphx::program prog;
    {
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter("x", s);
        auto pool = mod->add_instruction(migraphx::make_op("pooling",
                                                           {{"mode", "average"},
                                                            {"padding", {0, 1, 0}},
                                                            {"stride", {1, 1, 1}},
                                                            {"lengths", {3, 4, 5}}}),
                                         x);
        mod->add_return({pool});
    }

    // Snapshot before optimizing; the two must still match afterwards.
    migraphx::program untouched = prog;
    opt_pooling(prog);
    EXPECT(prog == untouched);
}

// Negative case: average pooling with a non-unit stride ({1, 2, 1}).
// The test expects the pass to leave the program unchanged.
TEST_CASE(rewrite_avepooling_na2_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    migraphx::program prog;
    {
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter("x", s);
        auto pool = mod->add_instruction(migraphx::make_op("pooling",
                                                           {{"mode", "average"},
                                                            {"padding", {0, 0, 0}},
                                                            {"stride", {1, 2, 1}},
                                                            {"lengths", {3, 4, 5}}}),
                                         x);
        mod->add_return({pool});
    }

    // Snapshot before optimizing; the two must still match afterwards.
    migraphx::program untouched = prog;
    opt_pooling(prog);
    EXPECT(prog == untouched);
}

// Negative case: the pooling window {3, 3, 5} does not match the trailing
// input dims {3, 4, 5}. The test expects the pass to leave the program
// unchanged. Note: despite the "avepooling" test name, this case uses "max"
// mode.
TEST_CASE(rewrite_avepooling_na3_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    migraphx::program prog;
    {
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter("x", s);
        auto pool = mod->add_instruction(migraphx::make_op("pooling",
                                                           {{"mode", "max"},
                                                            {"padding", {0, 0, 0}},
                                                            {"stride", {1, 1, 1}},
                                                            {"lengths", {3, 3, 5}}}),
                                         x);
        mod->add_return({pool});
    }

    // Snapshot before optimizing; the two must still match afterwards.
    migraphx::program untouched = prog;
    opt_pooling(prog);
    EXPECT(prog == untouched);
}

// Numerically validates the pooling -> reshape/reduce/reshape rewrite on a
// literal input of shape {2, 2, 3, 4, 5} filled with 1, 2, 3, ...
// For each mode, three programs are compiled for the ref target and evaluated:
//   * p1: the original pooling program,
//   * p2: a hand-built reshape -> reduce -> reshape program,
//   * p3: the pooling program after opt_pooling (the pass under test).
// p2 and p3 must both match p1 element-wise (within verify_range tolerance).
TEST_CASE(literal_rewrite_pooling_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
    std::vector<float> data(s.elements());
    std::iota(data.begin(), data.end(), 1.0f);

    // literal -> pooling(mode, window {3, 4, 5} == trailing dims) -> return
    auto pooling_program = [&](const std::string& mode) {
        migraphx::program p;
        auto* mm   = p.get_main_module();
        auto input = mm->add_literal(migraphx::literal(s, data));
        auto ret   = mm->add_instruction(migraphx::make_op("pooling",
                                                         {{"mode", mode},
                                                          {"padding", {0, 0, 0}},
                                                          {"stride", {1, 1, 1}},
                                                          {"lengths", {3, 4, 5}}}),
                                       input);
        mm->add_return({ret});
        return p;
    };

    // literal -> reshape{4, -1} -> op -> reshape{2, 2, 1, 1, 1} -> return
    auto opt_program = [&](const migraphx::operation& op) {
        migraphx::program p;
        auto* mm   = p.get_main_module();
        auto input = mm->add_literal(migraphx::literal(s, data));
        auto rsp   = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
        auto rdm   = mm->add_instruction(op, rsp);
        auto ret =
            mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
        mm->add_return({ret});
        return p;
    };

    auto test_rewrite_pooling = [&](const std::string& mode, const migraphx::operation& op) {
        migraphx::program p1 = pooling_program(mode);
        migraphx::program p2 = opt_program(op);
        // p3 runs the pass under test on literal input, so the rewrite itself
        // is checked numerically, not just the hand-built expectation (p2).
        migraphx::program p3 = pooling_program(mode);
        opt_pooling(p3);

        p1.compile(migraphx::ref::target{});
        p2.compile(migraphx::ref::target{});
        p3.compile(migraphx::ref::target{});

        // All three programs stay alive while their results are compared.
        auto result1 = p1.eval({}).back();
        auto result2 = p2.eval({}).back();
        auto result3 = p3.eval({}).back();
        visit_all(result1,
                  result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
        visit_all(result1,
                  result3)([&](auto r1, auto r3) { EXPECT(migraphx::verify_range(r1, r3)); });
    };

    test_rewrite_pooling("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
    test_rewrite_pooling("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
}

int main(int argc, const char* argv[])
{
    // Delegate to the test framework, which runs the registered TEST_CASEs.
    test::run(argc, argv);
}