// rewrite_pooling_test.cpp
// Unit tests for the migraphx::rewrite_pooling pass.
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>

#include <numeric>
#include <string>
#include <vector>

bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; }
// Applies the rewrite_pooling pass to the main module of `prog`, then runs
// dead_code_elimination on the same module. `prog` is modified in place.
// Running DCE afterwards removes instructions the rewrite left without users,
// so the result can be compared structurally against a hand-built program.
static void opt_pooling(migraphx::program& prog)
{
    auto* mm = prog.get_main_module();

    migraphx::rewrite_pooling rp;
    migraphx::dead_code_elimination dce;

    rp.apply(*mm);
    dce.apply(*mm);
}

// A pooling whose window {3, 4, 5} equals the full spatial extent of the
// {2, 2, 3, 4, 5} input (zero padding, unit stride) should be rewritten into
//   reshape{4, -1} -> <reduce op over axis 1> -> reshape{2, 2, 1, 1, 1}
// i.e. reduce_mean for "average" pooling and reduce_max for "max" pooling.
TEST_CASE(rewrite_pooling_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    // Builds a program holding a single pooling instruction of the given mode.
    auto pooling_program = [&](const std::string& mode) {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_parameter("x", s);
        auto ret = mm->add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}},
                                       input);
        mm->add_return({ret});
        return p;
    };

    // Builds the program the pass is expected to produce: the first two dims
    // (2 * 2 = 4) are kept, the remaining 3 * 4 * 5 elements are flattened and
    // reduced with `reduce_op`, and the result is reshaped back to rank 5.
    auto opt_program = [&](const migraphx::operation& reduce_op) {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_parameter("x", s);
        auto rsp   = mm->add_instruction(migraphx::op::reshape{{4, -1}}, input);
        auto rdm   = mm->add_instruction(reduce_op, rsp);
        auto ret   = mm->add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm);
        mm->add_return({ret});
        return p;
    };

    // Runs the pass on the pooling program and compares it structurally with
    // the hand-built expected program.
    auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) {
        migraphx::program p1 = pooling_program(mode);
        migraphx::program p2 = opt_program(op);
        opt_pooling(p1);
        EXPECT(p1 == p2);
    };

    test_rewrite("average", migraphx::op::reduce_mean{{1}});
    test_rewrite("max", migraphx::op::reduce_max{{1}});
}

// Negative case: average pooling with non-zero padding ({0, 1, 0}). The
// program after the pass must be identical to the untouched copy, i.e. the
// pass must leave this pooling alone ("na" presumably = not applicable).
TEST_CASE(rewrite_avepooling_na1_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    // Single pooling instruction: padding {0, 1, 0}, stride {1, 1, 1},
    // window {3, 4, 5}.
    auto pooling_program = [&]() {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_parameter("x", s);
        auto ret   = mm->add_instruction(
            migraphx::op::pooling{"average", {0, 1, 0}, {1, 1, 1}, {3, 4, 5}}, input);
        mm->add_return({ret});
        return p;
    };

    migraphx::program p1 = pooling_program();
    // Snapshot taken before the pass runs; p1 must still equal it afterwards.
    migraphx::program p2 = p1;

    opt_pooling(p1);
    EXPECT(p1 == p2);
}

// Negative case: average pooling with a non-unit stride ({1, 2, 1}). The
// program after the pass must be identical to the untouched copy, i.e. the
// pass must leave this pooling alone.
TEST_CASE(rewrite_avepooling_na2_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    // Single pooling instruction: padding {0, 0, 0}, stride {1, 2, 1},
    // window {3, 4, 5}.
    auto pooling_program = [&]() {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_parameter("x", s);
        auto ret   = mm->add_instruction(
            migraphx::op::pooling{"average", {0, 0, 0}, {1, 2, 1}, {3, 4, 5}}, input);
        mm->add_return({ret});
        return p;
    };

    migraphx::program p1 = pooling_program();
    // Snapshot taken before the pass runs; p1 must still equal it afterwards.
    migraphx::program p2 = p1;

    opt_pooling(p1);
    EXPECT(p1 == p2);
}

// Negative case: the pooling window {3, 3, 5} does not cover the full
// {3, 4, 5} spatial extent, so this is not a global pooling and the pass must
// leave the program unchanged.
// NOTE(review): despite the "avepooling" test name, this case uses "max"
// pooling. Confirm whether that is intentional.
TEST_CASE(rewrite_avepooling_na3_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};

    // Single pooling instruction: padding {0, 0, 0}, stride {1, 1, 1},
    // window {3, 3, 5}.
    auto pooling_program = [&]() {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_parameter("x", s);
        auto ret   = mm->add_instruction(
            migraphx::op::pooling{"max", {0, 0, 0}, {1, 1, 1}, {3, 3, 5}}, input);
        mm->add_return({ret});
        return p;
    };

    migraphx::program p1 = pooling_program();
    // Snapshot taken before the pass runs; p1 must still equal it afterwards.
    migraphx::program p2 = p1;

    opt_pooling(p1);
    EXPECT(p1 == p2);
}

// Numerical check that global pooling and its reshape/reduce/reshape form
// compute the same values. The input is a 2x2x3x4x5 float literal filled with
// 1, 2, ..., 120. Both programs are compiled for the ref target and evaluated,
// and their results are compared with verify_range.
// Note: opt_pooling is NOT applied here. This test checks the equivalence of
// the two formulations; rewrite_pooling_test checks that the pass produces
// the second one.
TEST_CASE(literal_rewrite_pooling_test)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
    std::vector<float> data(s.elements());
    std::iota(data.begin(), data.end(), 1.0f);

    // Pooling over the full spatial extent of the literal input.
    auto pooling_program = [&](const std::string& mode) {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_literal(migraphx::literal(s, data));
        auto ret = mm->add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}},
                                       input);
        mm->add_return({ret});
        return p;
    };

    // Equivalent form: flatten to {4, 60}, reduce axis 1 with `op`, then
    // reshape back to {2, 2, 1, 1, 1}.
    auto opt_program = [&](const migraphx::operation& op) {
        migraphx::program p;

        auto* mm   = p.get_main_module();
        auto input = mm->add_literal(migraphx::literal(s, data));
        auto rsp   = mm->add_instruction(migraphx::op::reshape{{4, -1}}, input);
        auto rdm   = mm->add_instruction(op, rsp);
        auto ret   = mm->add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm);
        mm->add_return({ret});
        return p;
    };

    // Compiles and runs both programs (no parameters are needed, since the
    // input is a literal) and compares the last returned argument of each.
    auto test_rewrite_pooling = [&](const std::string& mode, const migraphx::operation& op) {
        migraphx::program p1 = pooling_program(mode);
        migraphx::program p2 = opt_program(op);
        p1.compile(migraphx::ref::target{});
        p2.compile(migraphx::ref::target{});
        auto result1 = p1.eval({}).back();
        auto result2 = p2.eval({}).back();
        visit_all(result1,
                  result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
    };

    test_rewrite_pooling("max", migraphx::op::reduce_max{{1}});
    test_rewrite_pooling("average", migraphx::op::reduce_mean{{1}});
}

// Test executable entry point: forwards the command-line arguments to the
// test framework's runner (test::run from test.hpp).
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}