Unverified Commit 372206b3 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into fix_parse_if

parents fdc182c8 01d0ecfc
...@@ -54,18 +54,19 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -54,18 +54,19 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1"); MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
} }
if(x_lens.size() == 1) auto x_rank = x_lens.size();
if(x_rank == 1 or x_rank == 2)
{ {
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}}); auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto n0 = info.add_broadcastable_binary_op("sub", args[0], args[3]); auto numer = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto d0 = info.add_broadcastable_binary_op("add", args[4], eps); auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
auto d1 = info.add_broadcastable_binary_op("pow", d0, rt); auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
auto div0 = info.add_broadcastable_binary_op("div", n0, d1); auto div0 = info.add_broadcastable_binary_op("div", numer, denom);
auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]); auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]);
return info.add_broadcastable_binary_op("add", r0, args[2]); return info.add_broadcastable_binary_op("add", r0, args[2]);
} }
else if(x_lens.size() > 2) else if(x_rank > 2)
{ {
// unsqueeze tensors of shape (C) to broadcast correctly // unsqueeze tensors of shape (C) to broadcast correctly
std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2); std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
...@@ -89,7 +90,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -89,7 +90,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
} }
else else
{ {
// num dims either 0 or 2 // rank == 0
MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) + MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
" input tensor, unhandled data format"); " input tensor, unhandled data format");
} }
......
batch_norm_invalid_rank_test: batch_norm_rank_2_test:
7 J
x x
scale scale
bias bias
mean mean
variancey"BatchNormalizationbatch_norm_invalid_rank_testZ variancey"BatchNormalization*
epsilon75batch_norm_rank_2_testZ
x x
 
 
Z Z
scale scale
 
Z Z
bias bias
 
Z Z
mean mean
 
Z Z
variance variance
 
b b
y y
 
 
B B
\ No newline at end of file \ No newline at end of file
...@@ -331,6 +331,24 @@ def batch_norm_flat_test(): ...@@ -331,6 +331,24 @@ def batch_norm_flat_test():
return ([node], [x, scale, bias, mean, var], [out]) return ([node], [x, scale, bias, mean, var], [out])
@onnx_test
def batch_norm_rank_2_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5])
scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [5])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [5])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [5])
var = helper.make_tensor_value_info('variance', TensorProto.FLOAT, [5])
out = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 5])
node = onnx.helper.make_node(
'BatchNormalization',
inputs=['x', 'scale', 'bias', 'mean', 'variance'],
outputs=['y'],
epsilon=1e-6)
return ([node], [x, scale, bias, mean, var], [out])
@onnx_test @onnx_test
def batch_norm_1d_test(): def batch_norm_1d_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 3, 4]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 3, 4])
...@@ -385,23 +403,6 @@ def batch_norm_3d_test(): ...@@ -385,23 +403,6 @@ def batch_norm_3d_test():
return ([node], [x, scale, bias, mean, var], [out]) return ([node], [x, scale, bias, mean, var], [out])
@onnx_test
def batch_norm_invalid_rank_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [8, 8])
scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [8])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [8])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [8])
var = helper.make_tensor_value_info('variance', TensorProto.FLOAT, [8])
out = helper.make_tensor_value_info('y', TensorProto.FLOAT, [8, 8])
node = onnx.helper.make_node(
'BatchNormalization',
inputs=['x', 'scale', 'bias', 'mean', 'variance'],
outputs=['y'])
return ([node], [x, scale, bias, mean, var], [out])
@onnx_test @onnx_test
def batch_norm_invalid_bias_rank_test(): def batch_norm_invalid_bias_rank_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3, 4, 4]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3, 4, 4])
......
...@@ -394,6 +394,31 @@ TEST_CASE(batch_norm_flat_test) ...@@ -394,6 +394,31 @@ TEST_CASE(batch_norm_flat_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(batch_norm_rank_2_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 5}});
auto scale = mm->add_parameter("scale", {migraphx::shape::float_type, {5}});
auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {5}});
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {5}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {5}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}});
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt});
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, scale});
add_common_op(*mm, migraphx::make_op("add"), {r0, bias});
auto prog = optimize_onnx("batch_norm_rank_2_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(batch_norm_1d_test) TEST_CASE(batch_norm_1d_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -115,6 +115,43 @@ TEST_CASE(batch_norm_flat_test) ...@@ -115,6 +115,43 @@ TEST_CASE(batch_norm_flat_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(batch_norm_rank_2_test)
{
migraphx::program p = migraphx::parse_onnx("batch_norm_rank_2_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape x_shape{migraphx::shape::float_type, {2, 5}};
migraphx::shape c_shape(migraphx::shape::float_type, {5});
std::vector<float> x_data = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
std::vector<float> scale_data(5, 1.);
std::vector<float> bias_data(5, 0.);
std::vector<float> mean_data = {1., 2., 1., 2., 1.};
std::vector<float> variance_data(5, 0.5);
migraphx::parameter_map params;
params["x"] = migraphx::argument(x_shape, x_data.data());
params["scale"] = migraphx::argument(c_shape, scale_data.data());
params["bias"] = migraphx::argument(c_shape, bias_data.data());
params["mean"] = migraphx::argument(c_shape, mean_data.data());
params["variance"] = migraphx::argument(c_shape, variance_data.data());
auto result = p.eval(params).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.,
0.,
2.8284243,
2.8284243,
5.65684859,
7.07106074,
7.07106074,
9.89948504,
9.89948504,
12.72790933};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(batch_norm_1d_test) TEST_CASE(batch_norm_1d_test)
{ {
migraphx::program p = migraphx::parse_onnx("batch_norm_1d_test.onnx"); migraphx::program p = migraphx::parse_onnx("batch_norm_1d_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment