Commit e08b425f authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into refactor_dynamic_compute

parents fbe13c96 5fa42993
...@@ -27,22 +27,40 @@ ...@@ -27,22 +27,40 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{ {
static migraphx::instruction_ref static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x)
add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
{ {
auto* mm = p.get_main_module(); auto bn_lens = x->get_shape().lens();
migraphx::shape vars{migraphx::shape::float_type, {channels}}; auto c_len = bn_lens.at(1);
auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + channels)));
auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + channels))); migraphx::shape vars{migraphx::shape::float_type, {c_len}};
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + channels))); auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len)));
auto variance = auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len)));
mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + channels))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len)));
return mm->add_instruction( auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len)));
migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance);
auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto usq_scale =
m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias);
auto usq_mean = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var =
m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), variance);
auto numer = add_common_op(m, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(m, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(m, migraphx::make_op("pow"), {var_eps, rt});
auto div0 = add_common_op(m, migraphx::make_op("div"), {numer, denom});
auto r0 = add_common_op(m, migraphx::make_op("mul"), {div0, usq_scale});
return add_common_op(m, migraphx::make_op("add"), {r0, usq_bias});
} }
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
...@@ -59,7 +77,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -59,7 +77,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
x1, x1,
w1); w1);
auto bn1 = add_bn(p, conv1, 2048); auto bn1 = add_bn(*mm, conv1);
auto x2 = mm->add_parameter("x2", xs2); auto x2 = mm->add_parameter("x2", xs2);
auto w2 = mm->add_parameter("w2", ws2); auto w2 = mm->add_parameter("w2", ws2);
auto conv2 = mm->add_instruction( auto conv2 = mm->add_instruction(
...@@ -67,7 +85,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -67,7 +85,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), {{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x2, x2,
w2); w2);
auto bn2 = add_bn(p, conv2, 2048); auto bn2 = add_bn(*mm, conv2);
auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2); auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
auto relu = mm->add_instruction(migraphx::make_op("relu"), add); auto relu = mm->add_instruction(migraphx::make_op("relu"), add);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment