Unverified Commit 4918d769 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Fix literal type in the instance_norm parsing (#1317)

instancenorm parser always creates literal of type float which would fail in type check while creating binary ops if model is fp16.
parent 77e80b8e
...@@ -32,9 +32,12 @@ namespace onnx { ...@@ -32,9 +32,12 @@ namespace onnx {
struct parse_instancenorm : op_parser<parse_instancenorm> struct parse_instancenorm : op_parser<parse_instancenorm>
{ {
const std::set<shape::type_t> valid_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; } std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
...@@ -52,6 +55,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -52,6 +55,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto scale = args[1]; auto scale = args[1];
auto bias = args[2]; auto bias = args[2];
auto dims = x->get_shape().lens(); auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
auto ndims = dims.size(); auto ndims = dims.size();
assert(ndims >= 2); assert(ndims >= 2);
auto kdims = ndims - 2; auto kdims = ndims - 2;
...@@ -65,7 +73,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -65,7 +73,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon); auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
auto epsilon_bcast = auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
......
...@@ -2496,6 +2496,62 @@ def instance_norm_test(): ...@@ -2496,6 +2496,62 @@ def instance_norm_test():
return ([node], [x, scale, bias], [y]) return ([node], [x, scale, bias], [y])
@onnx_test
def instance_norm_half_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [1, 2, 3, 3])
scale = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [2])
bias = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [2])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [1, 2, 3, 3])
node = onnx.helper.make_node('InstanceNormalization',
inputs=['0', '1', '2'],
outputs=['3'])
return ([node], [x, scale, bias], [y])
@onnx_test
def instance_norm_type_mismatch_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3, 3])
scale = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [2])
bias = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [2])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 2, 3, 3])
node = onnx.helper.make_node('InstanceNormalization',
inputs=['0', '1', '2'],
outputs=['3'])
return ([node], [x, scale, bias], [y])
@onnx_test
def instance_norm_invalid_type_test():
x = helper.make_tensor_value_info('0', TensorProto.INT32, [1, 2, 3, 3])
scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2])
bias = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 2, 3, 3])
node = onnx.helper.make_node('InstanceNormalization',
inputs=['0', '1', '2'],
outputs=['3'])
return ([node], [x, scale, bias], [y])
@onnx_test
def instance_norm_nonbroadcastable_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3, 3])
scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4])
bias = helper.make_tensor_value_info('2', TensorProto.FLOAT, [4])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 2, 3, 3])
node = onnx.helper.make_node('InstanceNormalization',
inputs=['0', '1', '2'],
outputs=['3'])
return ([node], [x, scale, bias], [y])
@onnx_test @onnx_test
def instance_norm_val_test(): def instance_norm_val_test():
x = np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]], x = np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
......
instance_norm_half_test:
#
0
1
23"InstanceNormalizationinstance_norm_half_testZ
0





Z
1


Z
2


b
3





B
\ No newline at end of file
instance_norm_invalid_type_test:
#
0
1
23"InstanceNormalizationinstance_norm_invalid_type_testZ
0




Z
1

Z
2

b
3




B
\ No newline at end of file
#instance_norm_nonbroadcastable_test:
#
0
1
23"InstanceNormalization#instance_norm_nonbroadcastable_testZ
0




Z
1

Z
2

b
3




B
\ No newline at end of file
 instance_norm_type_mismatch_test:
#
0
1
23"InstanceNormalization instance_norm_type_mismatch_testZ
0




Z
1


Z
2


b
3




B
\ No newline at end of file
...@@ -2370,7 +2370,8 @@ TEST_CASE(instance_norm_test) ...@@ -2370,7 +2370,8 @@ TEST_CASE(instance_norm_test)
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast); auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0); auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast); auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto epsilon_literal = mm->add_literal(1e-5f); auto epsilon_literal =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
auto epsilon_bcast = mm->add_instruction( auto epsilon_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
...@@ -2390,6 +2391,60 @@ TEST_CASE(instance_norm_test) ...@@ -2390,6 +2391,60 @@ TEST_CASE(instance_norm_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(instance_norm_half_test)
{
std::vector<size_t> dims{1, 2, 3, 3};
migraphx::shape s1{migraphx::shape::half_type, dims};
migraphx::shape s2{migraphx::shape::half_type, {2}};
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("0", s1);
auto scale = mm->add_parameter("1", s2);
auto bias = mm->add_parameter("2", s2);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto epsilon_literal =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {1e-5}});
auto epsilon_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3);
auto scale_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
auto prog = optimize_onnx("instance_norm_half_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(instance_norm_type_mismatch_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("instance_norm_type_mismatch_test.onnx"); }));
}
TEST_CASE(instance_norm_invalid_type_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("instance_norm_invalid_type_test.onnx"); }));
}
TEST_CASE(instance_norm_nonbroadcastable_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("instance_norm_nonbroadcastable_test.onnx"); }));
}
TEST_CASE(leaky_relu_test) TEST_CASE(leaky_relu_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment