"driver/vscode:/vscode.git/clone" did not exist on "19a93dac051f3b5200fe00151b8fa5994aa890dd"
Unverified Commit 9169fbb3 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Rewrite avepooling bug (#584)



* fix a bug in rewrite_pooling pass

* clang format

* add unit tests for rewrite_pooling

* clang format

* add rewrite pooling to support maxpooling

* clang format

* remove a redundant unit test

* add one more unit test
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 177deb2c
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
namespace migraphx { namespace migraphx {
...@@ -15,28 +16,40 @@ void rewrite_pooling::apply(program& prog) const ...@@ -15,28 +16,40 @@ void rewrite_pooling::apply(program& prog) const
{ {
if(ins->name() != "pooling") if(ins->name() != "pooling")
continue; continue;
if(ins->get_shape().lens().size() != 4)
continue;
if(ins->inputs().empty()) if(ins->inputs().empty())
continue; continue;
auto&& s = ins->inputs().front()->get_shape(); auto&& s = ins->inputs().front()->get_shape();
if(not s.standard()) if(not s.standard())
continue; continue;
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode != "average") if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue;
if(op.padding[0] != 0 and op.padding[1] != 0)
continue; continue;
if(op.stride[0] != 1 and op.stride[1] != 1) if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
continue; continue;
if(s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1]) auto lens = s.lens();
if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue; continue;
std::int64_t n = s.lens()[0]; std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1]; std::int64_t c = s.lens()[1];
auto reshape = auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front()); prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape); instruction_ref pooling{};
prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
// average pooling
if(op.mode == "average")
{
pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
}
// max pooling
else
{
pooling = prog.insert_instruction(ins, op::reduce_max{{1}}, reshape);
}
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
prog.replace_instruction(ins, op::reshape{rsp_lens}, pooling);
} }
} }
......
...@@ -1188,6 +1188,19 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d> ...@@ -1188,6 +1188,19 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d>
} }
}; };
struct test_avg_pooling_3d_opt : verify_program<test_avg_pooling_3d_opt>
{
migraphx::program create_program() const
{
migraphx::program p;
auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 2, 3, 3, 3}});
auto op = migraphx::op::pooling{"average", {0, 0, 0}, {1, 1, 1}, {3, 3, 3}};
p.add_instruction(op, input);
return p;
}
};
struct test_gemm : verify_program<test_gemm> struct test_gemm : verify_program<test_gemm>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>
bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; }
static void opt_pooling(migraphx::program& prog)
{
migraphx::rewrite_pooling rp;
migraphx::dead_code_elimination dce;
rp.apply(prog);
dce.apply(prog);
}
TEST_CASE(rewrite_pooling_test)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&](const std::string& mode) {
migraphx::program p;
auto input = p.add_parameter("x", s);
auto ret =
p.add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}}, input);
p.add_return({ret});
return p;
};
auto opt_program = [&](const migraphx::operation& reduce_op) {
migraphx::program p;
auto input = p.add_parameter("x", s);
auto rsp = p.add_instruction(migraphx::op::reshape{{4, -1}}, input);
auto rdm = p.add_instruction(reduce_op, rsp);
auto ret = p.add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm);
p.add_return({ret});
return p;
};
auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) {
migraphx::program p1 = pooling_program(mode);
migraphx::program p2 = opt_program(op);
opt_pooling(p1);
EXPECT(p1 == p2);
};
test_rewrite("average", migraphx::op::reduce_mean{{1}});
test_rewrite("max", migraphx::op::reduce_max{{1}});
}
TEST_CASE(rewrite_avepooling_na1_test)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&]() {
migraphx::program p;
auto input = p.add_parameter("x", s);
auto ret = p.add_instruction(
migraphx::op::pooling{"average", {0, 1, 0}, {1, 1, 1}, {3, 4, 5}}, input);
p.add_return({ret});
return p;
};
migraphx::program p1 = pooling_program();
migraphx::program p2 = p1;
opt_pooling(p1);
EXPECT(p1 == p2);
}
TEST_CASE(rewrite_avepooling_na2_test)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&]() {
migraphx::program p;
auto input = p.add_parameter("x", s);
auto ret = p.add_instruction(
migraphx::op::pooling{"average", {0, 0, 0}, {1, 2, 1}, {3, 4, 5}}, input);
p.add_return({ret});
return p;
};
migraphx::program p1 = pooling_program();
migraphx::program p2 = p1;
opt_pooling(p1);
EXPECT(p1 == p2);
}
TEST_CASE(rewrite_avepooling_na3_test)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&]() {
migraphx::program p;
auto input = p.add_parameter("x", s);
auto ret =
p.add_instruction(migraphx::op::pooling{"max", {0, 0, 0}, {1, 1, 1}, {3, 3, 5}}, input);
p.add_return({ret});
return p;
};
migraphx::program p1 = pooling_program();
migraphx::program p2 = p1;
opt_pooling(p1);
EXPECT(p1 == p2);
}
TEST_CASE(literal_rewrite_pooling_test)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
std::vector<float> data(s.elements());
std::iota(data.begin(), data.end(), 1.0f);
auto pooling_program = [&](const std::string& mode) {
migraphx::program p;
auto input = p.add_literal(migraphx::literal(s, data));
auto ret =
p.add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}}, input);
p.add_return({ret});
return p;
};
auto opt_program = [&](const migraphx::operation& op) {
migraphx::program p;
auto input = p.add_literal(migraphx::literal(s, data));
auto rsp = p.add_instruction(migraphx::op::reshape{{4, -1}}, input);
auto rdm = p.add_instruction(op, rsp);
auto ret = p.add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm);
p.add_return({ret});
return p;
};
auto test_rewrite_pooling = [&](const std::string& mode, const migraphx::operation& op) {
migraphx::program p1 = pooling_program(mode);
migraphx::program p2 = opt_program(op);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1,
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
};
test_rewrite_pooling("max", migraphx::op::reduce_max{{1}});
test_rewrite_pooling("average", migraphx::op::reduce_mean{{1}});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment