Unverified commit 69925294, authored by kahmed10, committed by GitHub
Browse files

Instance norm kdims (#620)



* fix parsing to kdims

* add 5d size

* fix assert

* add 3d test

* formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9dabe26b
...@@ -1338,8 +1338,8 @@ struct onnx_parser ...@@ -1338,8 +1338,8 @@ struct onnx_parser
parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args) parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({H, W}, x) // mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({H, W}, (x - mean)^2) // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
float epsilon = 1e-5f; float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
...@@ -1350,11 +1350,17 @@ struct onnx_parser ...@@ -1350,11 +1350,17 @@ struct onnx_parser
auto scale = args[1]; auto scale = args[1];
auto bias = args[2]; auto bias = args[2];
auto dims = x->get_shape().lens(); auto dims = x->get_shape().lens();
auto ndims = dims.size();
assert(ndims >= 2);
auto kdims = ndims - 2;
auto mean = prog.add_instruction(make_op("reduce_mean", {{"axes", {2, 3}}}), x); std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
auto mean = prog.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = prog.add_instruction(op::multibroadcast{dims}, mean); auto mean_bcast = prog.add_instruction(op::multibroadcast{dims}, mean);
auto l0 = prog.add_instruction(make_op("sqdiff"), x, mean_bcast); auto l0 = prog.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = prog.add_instruction(make_op("reduce_mean", {{"axes", {2, 3}}}), l0); auto variance = prog.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = prog.add_instruction(make_op("sub"), x, mean_bcast); auto l1 = prog.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = prog.add_literal(epsilon); auto epsilon_literal = prog.add_literal(epsilon);
auto epsilon_bcast = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal); auto epsilon_bcast = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
......
...@@ -34,6 +34,11 @@ constexpr void visit_tensor_size(index_int n, F f) ...@@ -34,6 +34,11 @@ constexpr void visit_tensor_size(index_int n, F f)
f(std::integral_constant<index_int, 4>{}); f(std::integral_constant<index_int, 4>{});
break; break;
} }
case 5:
{
f(std::integral_constant<index_int, 5>{});
break;
}
default: throw std::runtime_error("Tensor dims " + std::to_string(n) + " out of range"); default: throw std::runtime_error("Tensor dims " + std::to_string(n) + " out of range");
} }
} }
......
...@@ -1514,6 +1514,36 @@ def instance_norm_val_test(): ...@@ -1514,6 +1514,36 @@ def instance_norm_val_test():
return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor]) return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor])
@onnx_test
def instance_norm_val_3d_test():
    # Builds an InstanceNormalization ONNX test case over a 5-D input of
    # shape (N, C, D, H, W) = (1, 2, 2, 2, 2); normalization is computed
    # per channel over the three spatial dims (D, H, W).
    # Returns the tuple consumed by @onnx_test:
    # ([nodes], [inputs], [outputs], [initializers]).
    x = np.array([[[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
                   [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]]])
    scale = np.array([1, 2])  # one scale per channel
    bias = np.array([0, 1])   # one bias per channel
    # NOTE: np.float was a deprecated alias for the builtin float (float64)
    # and was removed in NumPy 1.24; use np.float64 explicitly instead.
    x_tensor = helper.make_tensor(name='x_tensor',
                                  data_type=TensorProto.FLOAT,
                                  dims=x.shape,
                                  vals=x.flatten().astype(np.float64))
    scale_tensor = helper.make_tensor(name='scale_tensor',
                                      data_type=TensorProto.FLOAT,
                                      dims=scale.shape,
                                      vals=scale.flatten().astype(np.float64))
    bias_tensor = helper.make_tensor(name='bias_tensor',
                                     data_type=TensorProto.FLOAT,
                                     dims=bias.shape,
                                     vals=bias.flatten().astype(np.float64))
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 2, 2, 2])

    node = onnx.helper.make_node(
        'InstanceNormalization',
        inputs=['x_tensor', 'scale_tensor', 'bias_tensor'],
        outputs=['y'])

    return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor])
@onnx_test @onnx_test
def layernorm_test(): def layernorm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 1, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 1, 5])
......
...@@ -39,6 +39,35 @@ TEST_CASE(instance_norm_test) ...@@ -39,6 +39,35 @@ TEST_CASE(instance_norm_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
// Verify InstanceNormalization on a 5-D input (three spatial dims) by
// parsing the generated ONNX model, running it on the CPU target, and
// comparing against precomputed reference values.
TEST_CASE(instance_norm_3d_test)
{
    migraphx::program p = migraphx::parse_onnx("instance_norm_val_3d_test.onnx");
    p.compile(migraphx::cpu::target{});

    auto result = p.eval({}).back();
    std::vector<float> output;
    result.visit([&](auto out) { output.assign(out.begin(), out.end()); });

    // Channel 0 uses scale=1, bias=0; channel 1 uses scale=2, bias=1,
    // which doubles the spread and shifts the normalized values by one.
    const std::vector<float> gold = {-1.52752f,
                                     -1.09109f,
                                     -0.654653f,
                                     -0.218218f,
                                     0.218218f,
                                     0.654653f,
                                     1.09109f,
                                     1.52752f,
                                     -2.05505f,
                                     -1.18218f,
                                     -0.309306f,
                                     0.563565f,
                                     1.43644f,
                                     2.30931f,
                                     3.18218f,
                                     4.05505f};
    EXPECT(migraphx::verify_range(output, gold));
}
TEST_CASE(averagepool_notset_test) TEST_CASE(averagepool_notset_test)
{ {
auto p = migraphx::parse_onnx("averagepool_notset_test.onnx"); auto p = migraphx::parse_onnx("averagepool_notset_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment