Commit 78c83426 authored by Paul Fultz II, committed by mvermeulen

Fix accuracy issue in resnet50 (#395)

* Fix bug in eliminate_concat

* Formatting

* Skip context_free operators

* Formatting

* Fix unit test

* Formatting
parent ca17bcd6
@@ -5,6 +5,7 @@
 #include <migraphx/op/load.hpp>
 #include <migraphx/op/identity.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/ranges.hpp>
 #include <migraphx/dfor.hpp>
 
 namespace migraphx {
@@ -16,9 +17,14 @@ void eliminate_concat::apply(program& p) const
         // Look for the concat operator
         if(ins->name() != concat_opt.name())
             continue;
-        // If any inputs are literals then abort
-        if(std::any_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
-               return arg->name() == "@literal";
+        // If any inputs are builtin or context free then abort
+        // If any inputs are used more than once, then abort since there could
+        // be errors due to aliasing
+        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto arg) {
+               return arg->name().front() == '@' or
+                      (arg->get_operator().is_context_free() and
+                       not contains({"concat", "identity"}, arg->name())) or
+                      arg->outputs().size() > 1;
            }))
             continue;
         // We can only do this optimization when concat axis is either the leftmost
...
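For readers skimming the hunk above, the new condition can be read as a standalone predicate over a single concat input. The sketch below is illustrative only and is not part of the commit; the helper name is_unsafe_concat_input is hypothetical, while the headers and the instruction API it relies on (name(), get_operator().is_context_free(), outputs(), contains()) are taken directly from the diff.

#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {

// Hypothetical helper, mirroring the check added above: an input blocks the
// in-place concat rewrite when it is a builtin instruction (name starts with
// '@'), a context-free operator other than concat or identity, or is consumed
// by more than one instruction, since rewriting its buffer in place could
// then alias another use.
inline bool is_unsafe_concat_input(instruction_ref arg)
{
    return arg->name().front() == '@' or
           (arg->get_operator().is_context_free() and
            not contains({"concat", "identity"}, arg->name())) or
           arg->outputs().size() > 1;
}

} // namespace migraphx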
@@ -59,10 +59,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         propagate_constant{},
         dead_code_elimination{},
         lowering{ctx},
-        eliminate_concat{concat_gpu_optimization{}},
-        dead_code_elimination{},
         eliminate_contiguous{},
         dead_code_elimination{},
+        eliminate_concat{concat_gpu_optimization{}},
+        dead_code_elimination{},
         adjust_allocation{},
         dead_code_elimination{},
         pack_int8_args{},
...
@@ -849,11 +849,11 @@ struct test_conv_add : verify_program<test_conv_add>
     {
         migraphx::program p;
         auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}});
-        auto w =
-            p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}));
+        auto w = p.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1));
         auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
-        auto v =
-            p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}));
+        auto v = p.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2));
         auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w);
         auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
         auto sum   = p.add_instruction(migraphx::op::add{}, conv1, conv2);
@@ -868,11 +868,11 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides>
     {
         migraphx::program p;
         auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}});
-        auto w =
-            p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}));
+        auto w = p.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1));
         auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
-        auto v =
-            p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}));
+        auto v = p.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2));
         auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w);
         auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
         auto sum   = p.add_instruction(migraphx::op::add{}, conv1, conv2);
@@ -881,6 +881,45 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides>
     }
 };
 
+struct test_conv_bn_add : verify_program<test_conv_bn_add>
+{
+    static migraphx::instruction_ref add_bn(migraphx::program& p,
+                                            migraphx::instruction_ref x,
+                                            std::size_t channels,
+                                            std::size_t seed = 1)
+    {
+        migraphx::shape vars{migraphx::shape::float_type, {channels}};
+        auto scale    = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed)));
+        auto bias     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + seed)));
+        auto mean     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + seed)));
+        auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed)));
+        return p.add_instruction(
+            migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
+    }
+
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        std::size_t ichannels = 64;
+        std::size_t ochannels = 256;
+        auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
+        auto w = p.add_literal(migraphx::generate_literal(
+            {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
+        auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
+        auto v = p.add_literal(migraphx::generate_literal(
+            {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
+        auto relu1 = p.add_instruction(migraphx::op::relu{}, x);
+        auto conv1 = p.add_instruction(migraphx::op::convolution{}, relu1, w);
+        auto bn1   = add_bn(p, conv1, ochannels, 1);
+        auto relu2 = p.add_instruction(migraphx::op::relu{}, y);
+        auto conv2 = p.add_instruction(migraphx::op::convolution{}, relu2, v);
+        auto bn2   = add_bn(p, conv2, ochannels, 1);
+        auto sum   = p.add_instruction(migraphx::op::add{}, bn1, bn2);
+        p.add_instruction(migraphx::op::relu{}, sum);
+        return p;
+    }
+};
+
 struct test_add_relu : verify_program<test_add_relu>
 {
     migraphx::program create_program() const
...