// rewrite_pooling_test.cpp
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>

#include <numeric>
#include <string>
#include <vector>

bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; }
Paul Fultz II's avatar
Paul Fultz II committed
14
// Applies the pass under test to `m` in place: rewrite_pooling first, then
// dead_code_elimination to drop instructions the rewrite left unused, so the
// result can be compared structurally against a hand-built expected module.
static void opt_pooling(migraphx::module& m)
{
    migraphx::rewrite_pooling rp;
    migraphx::dead_code_elimination dce;
    rp.apply(m);
    dce.apply(m);
}

// A pooling whose window spans the whole spatial extent of a {2, 2, 3, 4, 5}
// input (lengths {3, 4, 5}, zero padding, unit stride) is expected to be
// rewritten into reshape -> reduce -> reshape, with "average" mapping to
// reduce_mean and "max" mapping to reduce_max.
TEST_CASE(rewrite_pooling_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    // Module holding a single full-window pooling of the requested mode.
    auto pooling_program = [&](const std::string& mode) {
        migraphx::module m;
        auto input = m.add_parameter("x", s);
        auto ret   = m.add_instruction(migraphx::make_op("pooling",
                                                       {{"mode", mode},
                                                        {"padding", {0, 0, 0}},
                                                        {"stride", {1, 1, 1}},
                                                        {"lengths", {3, 4, 5}}}),
                                     input);
        m.add_return({ret});
        return m;
    };

    // Expected rewrite: flatten to 4 rows (2 * 2 leading dims; -1 infers the
    // remaining 3 * 4 * 5), reduce each row along axis 1 with `reduce_op`,
    // then reshape back to the pooled output shape {2, 2, 1, 1, 1}.
    auto opt_program = [&](const migraphx::operation& reduce_op) {
        migraphx::module m;
        auto input = m.add_parameter("x", s);
        auto rsp   = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
        auto rdm   = m.add_instruction(reduce_op, rsp);
        auto ret =
            m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
        m.add_return({ret});
        return m;
    };

    // Runs the pass on the pooling module and requires it to match the
    // expected module exactly.
    auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) {
        migraphx::module m1 = pooling_program(mode);
        migraphx::module m2 = opt_program(op);
        opt_pooling(m1);
        EXPECT(m1 == m2);
    };

    test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
    test_rewrite("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
}

// Not-applicable case: an average pooling with non-zero padding ({0, 1, 0})
// must be left untouched by rewrite_pooling.
TEST_CASE(rewrite_avepooling_na1_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
    auto pooling_program = [&]() {
        migraphx::module m;

        auto input = m.add_parameter("x", s);
        auto ret   = m.add_instruction(migraphx::make_op("pooling",
                                                       {{"mode", "average"},
                                                        {"padding", {0, 1, 0}},
                                                        {"stride", {1, 1, 1}},
                                                        {"lengths", {3, 4, 5}}}),
                                     input);
        m.add_return({ret});
        return m;
    };

    migraphx::module m1 = pooling_program();
    // Snapshot taken before the pass runs; equality afterwards means no rewrite.
    migraphx::module m2 = m1;

    opt_pooling(m1);
    EXPECT(m1 == m2);
}

// Not-applicable case: an average pooling with a non-unit stride ({1, 2, 1})
// must be left untouched by rewrite_pooling.
TEST_CASE(rewrite_avepooling_na2_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
    auto pooling_program = [&]() {
        migraphx::module m;

        auto input = m.add_parameter("x", s);
        auto ret   = m.add_instruction(migraphx::make_op("pooling",
                                                       {{"mode", "average"},
                                                        {"padding", {0, 0, 0}},
                                                        {"stride", {1, 2, 1}},
                                                        {"lengths", {3, 4, 5}}}),
                                     input);
        m.add_return({ret});
        return m;
    };

    migraphx::module m1 = pooling_program();
    // Snapshot taken before the pass runs; equality afterwards means no rewrite.
    migraphx::module m2 = m1;

    opt_pooling(m1);
    EXPECT(m1 == m2);
}

// Not-applicable case: the window lengths {3, 3, 5} do not cover the full
// {3, 4, 5} spatial extent, so rewrite_pooling must leave the module untouched.
// NOTE(review): despite the "avepooling" name this case uses "max" mode; the
// test name is kept as-is so existing test filters keep working.
TEST_CASE(rewrite_avepooling_na3_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
    auto pooling_program = [&]() {
        migraphx::module m;

        auto input = m.add_parameter("x", s);
        auto ret   = m.add_instruction(migraphx::make_op("pooling",
                                                       {{"mode", "max"},
                                                        {"padding", {0, 0, 0}},
                                                        {"stride", {1, 1, 1}},
                                                        {"lengths", {3, 3, 5}}}),
                                     input);
        m.add_return({ret});
        return m;
    };

    migraphx::module m1 = pooling_program();
    // Snapshot taken before the pass runs; equality afterwards means no rewrite.
    migraphx::module m2 = m1;

    opt_pooling(m1);
    EXPECT(m1 == m2);
}

// Numeric check of the equivalence the rewrite relies on: a full-window
// pooling over a literal input must produce the same values (per
// migraphx::verify_range) as reshape -> reduce -> reshape, when both programs
// are compiled and evaluated on the ref target.
// NOTE(review): rewrite_pooling is not applied explicitly to p1 here; unless
// the ref target's compile pipeline runs it, this compares the two
// formulations rather than the pass output — confirm this is intended.
TEST_CASE(literal_rewrite_pooling_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
    // Input values 1, 2, 3, ... so each pooled row has distinct contents.
    std::vector<float> data(s.elements());
    std::iota(data.begin(), data.end(), 1.0f);

    // Program: literal input followed by a single full-window pooling.
    auto pooling_program = [&](const std::string& mode) {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_literal(migraphx::literal(s, data));
        auto ret   = mm->add_instruction(migraphx::make_op("pooling",
                                                         {{"mode", mode},
                                                          {"padding", {0, 0, 0}},
                                                          {"stride", {1, 1, 1}},
                                                          {"lengths", {3, 4, 5}}}),
                                       input);
        mm->add_return({ret});
        return p;
    };

    // Program: same literal, flattened to 4 rows, reduced with `op`, and
    // reshaped back to the pooled output shape {2, 2, 1, 1, 1}.
    auto opt_program = [&](const migraphx::operation& op) {
        migraphx::program p;
        auto* mm   = p.get_main_module();
        auto input = mm->add_literal(migraphx::literal(s, data));
        auto rsp   = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
        auto rdm   = mm->add_instruction(op, rsp);
        auto ret =
            mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
        mm->add_return({ret});
        return p;
    };

    // Compiles and evaluates both programs on ref and compares the last output.
    auto test_rewrite_pooling = [&](const std::string& mode, const migraphx::operation& op) {
        migraphx::program p1 = pooling_program(mode);
        migraphx::program p2 = opt_program(op);
        p1.compile(migraphx::ref::target{});
        p2.compile(migraphx::ref::target{});
        auto result1 = p1.eval({}).back();
        auto result2 = p2.eval({}).back();
        visit_all(result1,
                  result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
    };

    test_rewrite_pooling("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
    test_rewrite_pooling("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
}

// Test-binary entry point: forwards the command-line arguments unchanged to
// test::run, the driver provided by test.hpp.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}