Unverified Commit 41ba30d5 authored by Brian Pickrell, committed by GitHub

fix parse_instancenorm to create broadcast and multibroadcast instruc… (#1715)

* fix parse_instancenorm to create broadcast and multibroadcast instructions with two dynamic-shape arguments instead of one. Their make_op() functions don't support dynamic shapes when called with one input; this caused an error when parsing an ONNX 3D-UNet model. A sketch of the two-input pattern follows the commit header below.

* Use add_common_op() to create multibroadcast op.

* add verification and parsing test for instance_norm with dynamic input.  Parse test doesn't pass.

* fix for test; still doesn't pass

* another fix for test; still doesn't pass

* work in progress, instance_norm_dyn_batch_test works but instance_norm_test doesn't

* fix onnx instancenorm tests to match parser changes.  Passes all check tests

* Updated comments explaining usage of add_common_op()

* hand-merged conflicts with develop

* fix instance_norm_half_test after merge

* add Onnx test instance_norm_dyn_batch_half_test

* add shape test cases broadcast_1in_dyn_error and multibroadcast_1in_dyn_error_0
parent 5bf067ed
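
For context on the fix itself, here is a minimal sketch of the two-input pattern this change adopts. It uses the public migraphx C++ API as exercised by the tests in this diff; the shapes and parameter names are illustrative assumptions, not taken from the 3D-UNet model.

    #include <migraphx/make_op.hpp>
    #include <migraphx/program.hpp>
    #include <migraphx/shape.hpp>

    int main()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        // x has a dynamic batch dimension {min 1, max 4}; channel/spatial dims are fixed.
        migraphx::shape xs{migraphx::shape::float_type, {{1, 4}, {2, 2}, {3, 3}, {3, 3}}};
        migraphx::shape cs{migraphx::shape::float_type, {2}};
        auto x = mm->add_parameter("x", xs);
        auto c = mm->add_parameter("c", cs);

        // One-input form: needs a fixed "out_lens" attribute, which cannot be computed
        // at parse time when x is dynamic; after this change it throws for dynamic shapes.
        // mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", ...}}), c);

        // Two-input form: the output shape is taken from the second argument, so the
        // dynamic batch dimension is resolved at evaluation time.
        auto cb = mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), c, x);
        mm->add_instruction(migraphx::make_op("mul"), x, cb);
        return 0;
    }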
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal

@@ -148,6 +148,18 @@ shape common_shape(const std::vector<shape>& shapes)
     return {compute_common_types(shapes), compute_common_lens(shapes)};
 }
+/**
+ * @brief Creates and adds instructions to convert input arguments to common shapes and types
+ * by adding multibroadcast and type-convert operations. This is a utility function for creating
+ * operations where the shapes and types of the inputs must match. It supports both dynamic and
+ * static-shaped arguments.
+ *
+ * @param m containing module for the instruction
+ * @param ins insertion location in the instruction list
+ * @param inputs instructions to use as the argument list; the shapes
+ * attached to each instruction_ref are also considered for broadcasting
+ * @return std::vector<instruction_ref> a modified argument list
+ */
 std::vector<instruction_ref>
 insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs)
 {
@@ -158,7 +170,7 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
         if(inputs.size() != 2)
         {
             MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
-                           "inputs, only handle two inputs if any are dynamic shape");
+                           " inputs. Requires exactly two inputs if any are dynamic shape");
         }
     auto c_type = compute_common_types(to_shapes(inputs));
@@ -224,6 +236,9 @@ instruction_ref insert_common_op(module& m,
     return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs)));
 }
+/**
+ * Wrapper for insert_common_op() which inserts the operation at the end of the module.
+ */
 instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs)
 {
     return insert_common_op(m, m.end(), op, std::move(inputs));
...
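A short usage sketch of the helpers documented above, mirroring how the tests later in this diff call them; the header path and shapes are assumptions.

    #include <migraphx/common.hpp> // declares insert_common_args/add_common_op (assumed location)
    #include <migraphx/make_op.hpp>
    #include <migraphx/program.hpp>

    int main()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape xs{migraphx::shape::float_type, {2, 3, 3}};
        migraphx::shape ss{migraphx::shape::float_type, {1}};
        auto x = mm->add_parameter("x", xs);
        auto s = mm->add_parameter("s", ss);
        // Inserts a multibroadcast of s to {2, 3, 3} (and a convert, if the types
        // differed) and then appends the "add" itself at the end of the module.
        auto sum = migraphx::add_common_op(*mm, migraphx::make_op("add"), {x, s});
        (void)sum;
        return 0;
    }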
@@ -68,6 +68,9 @@ struct broadcast
         {
             // the ONNX broadcast op is deprecated now, so not handling the negative
             // value of axis anymore
+            if(s0.dynamic())
+                MIGRAPHX_THROW(
+                    "BROADCAST: Single dynamic input shape not supported. Use two inputs.");
             if(axis >= broadcast_lens.size())
             {
                 MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal

@@ -37,8 +37,10 @@ namespace op {
 /**
  * Broadcast multiple dimensions between two tensors.
  * Two versions of this operator: one input and two inputs.
- * One input version uses output_lens attribute and broadcasts to it.
- * Two inputs version broadcasts both inputs to the common shape at evaluation time.
+ * One input version uses output_lens attribute and broadcasts to it (does not support
+ * dynamic shape input).
+ *
+ * Two inputs version broadcasts the first input to the common shape of the two inputs.
  */
 struct multibroadcast
 {
@@ -81,6 +83,9 @@ struct multibroadcast
         if(inputs.size() == 1)
         {
+            if(s0.dynamic())
+                MIGRAPHX_THROW(
+                    "MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs.");
             if(s0.lens().size() > output_lens.size())
             {
                 MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
...
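A short sketch of the two-input form described above, which is what makes dynamic shapes workable. The shapes are illustrative; the static/dynamic pairing mirrors the multibroadcast_2in_* shape tests later in this diff.

    #include <migraphx/make_op.hpp>
    #include <migraphx/program.hpp>

    int main()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
        // b has a dynamic leading dimension {min 1, max 4}
        migraphx::shape b_shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}};
        auto a = mm->add_parameter("a", a_shape);
        auto b = mm->add_parameter("b", b_shape);
        // No "out_lens" attribute: a is broadcast to the common shape of a and b,
        // computed once the shapes are known.
        mm->add_instruction(migraphx::make_op("multibroadcast"), a, b);
        return 0;
    }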
@@ -183,7 +183,7 @@ struct shape
     const std::vector<std::size_t>& strides() const;
     /*!
-     * The number of dimensions in the shape.
+     * The number of dimensions in the shape, either static or dynamic.
      * Same as the number of indices required to get a data value.
      */
     std::size_t ndim() const;
...
@@ -149,6 +149,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
     return this->add_common_op(op_name, arg0, arg1);
 }
+/**
+ * @brief A wrapper for insert_common_args(), which constructs an argument list
+ * and inserts multibroadcast and convert ops as required to match the inputs to a common
+ * shape and type. The requested operation is placed after the added multibroadcast and
+ * convert ops, if any, so that their results are transparent to the programmer.
+ *
+ * Use add_common_op() to match input sizes when the inputs may be
+ * either static or dynamic.
+ *
+ * @param op_name string; name of the operation (op) to add; valid names are the same as
+ * for make_op()
+ *
+ * @param inputs vector of instruction_ref; the input instructions for the new
+ * operation. Any needed multibroadcast and convert operations are also deduced from these.
+ *
+ * @return instruction_ref the result of the requested operation
+ */
 instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
                                                       std::vector<instruction_ref> inputs) const
 {
...
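A small sketch of the convert-insertion behavior described above, using the module-level helper (the parser-level add_common_op defers to the same insert_common_args machinery). The types and shapes are illustrative assumptions; whether half or float wins as the common type follows compute_common_types().

    #include <migraphx/common.hpp> // assumed header location for add_common_op
    #include <migraphx/make_op.hpp>
    #include <migraphx/program.hpp>

    int main()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape xs{migraphx::shape::half_type, {2, 3}};
        migraphx::shape ys{migraphx::shape::float_type, {2, 3}};
        auto x = mm->add_parameter("x", xs);
        auto y = mm->add_parameter("y", ys);
        // The half input is converted to the common type before "mul" is added,
        // so the caller never has to insert the convert instruction by hand.
        migraphx::add_common_op(*mm, migraphx::make_op("mul"), {x, y});
        return 0;
    }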
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal

@@ -84,16 +84,17 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
             MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
                            ". Valid types are 1 (float), 10 (half), and 11 (double).");
-        auto ndims = dims.size();
+        bool dyn_input = x->get_shape().dynamic();
+        auto ndims     = x->get_shape().ndim();
         assert(ndims >= 2);
         auto kdims = ndims - 2;
         std::vector<int64_t> axes(kdims);
         std::iota(axes.begin(), axes.end(), 2);
         auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
-        auto mean_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
-        auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
+        // Use add_common_op() to insert multibroadcast/convert instructions where needed when
+        // inputs may be either static or dynamic.
+        auto l1 = info.add_common_op("sub", x, mean);
         // for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
         // reduce_sum to calculate variance i.e.
         // var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
@@ -107,23 +108,32 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
             });
             n = 1.0 / std::sqrt(n);
             auto n_literal = info.add_literal(literal{dtype, {n}});
-            mean_bcast = info.add_common_op("mul", {mean_bcast, n_literal});
-            x          = info.add_common_op("mul", {x, n_literal});
+            x = info.add_common_op("mul", {x, n_literal});
         }
-        auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
+        auto l0 = info.add_common_op("sqdiff", x, mean);
         auto variance = info.add_instruction(make_op(reduce_op_name, {{"axes", axes}}), l0);
         auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
-        auto epsilon_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
-        auto variance_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
-        auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
+        auto l2 = info.add_common_op("add", variance, epsilon_literal);
         auto l3 = info.add_instruction(make_op("rsqrt"), l2);
-        auto l4 = info.add_instruction(make_op("mul"), l1, l3);
-        auto scale_bcast =
-            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
-        auto bias_bcast =
-            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
+        auto l4 = info.add_common_op("mul", l1, l3);
+
+        // add_common_op() doesn't apply the plain broadcast op, so we add that op explicitly for
+        // both scale and bias.
+        instruction_ref scale_bcast;
+        instruction_ref bias_bcast;
+        if(dyn_input)
+        {
+            scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
+            bias_bcast  = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
+        }
+        else
+        {
+            scale_bcast = info.add_instruction(
+                make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
+            bias_bcast =
+                info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
+        }
         auto l5  = info.add_instruction(make_op("mul"), l4, scale_bcast);
         auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
         if(dtype == shape::half_type and convert_fp16)
...
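For reference, the computation this parser lowers is the standard ONNX InstanceNormalization, with mean and variance reduced over the spatial axes (axes 2 and up); the second equation is the fp16 rescaling trick from the comment in this hunk, with s_n and n as defined there:

    y = \mathrm{scale}\cdot\frac{x - \operatorname{mean}(x)}{\sqrt{\operatorname{var}(x) + \epsilon}} + \mathrm{bias},
    \qquad
    \operatorname{var}(x) = \operatorname{reduce\_sum}\!\left(\left(\frac{x}{s_n} - \frac{\operatorname{mean}(x)}{s_n}\right)^{2}\right),
    \quad s_n = \sqrt{n}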
@@ -3341,6 +3341,39 @@ def instance_norm_type_mismatch_test():
     return ([node], [x, scale, bias], [y])
 
+@onnx_test()
+def instance_norm_dyn_batch_test():
+    # the batch size is a dynamic dimension
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 2, 3, 3])
+    scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2])
+    bias = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2])
+    y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [None, 2, 3, 3])
+
+    node = onnx.helper.make_node('InstanceNormalization',
+                                 inputs=['0', '1', '2'],
+                                 outputs=['3'])
+
+    return ([node], [x, scale, bias], [y])
+
+
+@onnx_test()
+def instance_norm_dyn_batch_half_test():
+    # the batch size is a dynamic dimension
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT16,
+                                      [None, 2, 3, 3])
+    scale = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [2])
+    bias = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [2])
+    y = helper.make_tensor_value_info('3', TensorProto.FLOAT16,
+                                      [None, 2, 3, 3])
+
+    node = onnx.helper.make_node('InstanceNormalization',
+                                 inputs=['0', '1', '2'],
+                                 outputs=['3'])
+
+    return ([node], [x, scale, bias], [y])
+
+
 @onnx_test()
 def instance_norm_invalid_type_test():
     x = helper.make_tensor_value_info('0', TensorProto.INT32, [1, 2, 3, 3])
...
@@ -3174,28 +3174,64 @@ TEST_CASE(instance_norm_test)
     auto bias = mm->add_parameter("2", s2);
 
     auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
-    auto mean_bcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
-    auto l0 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
-    auto l1 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
-    auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l1);
+    auto l1 = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
+    auto l0 = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean});
+    auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
     auto epsilon_literal =
         mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
-    auto epsilon_bcast = mm->add_instruction(
-        migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
-    auto variance_bcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
-    auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
+    auto l2 = add_common_op(*mm, migraphx::make_op("add"), {variance, epsilon_literal});
     auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
-    auto l4 = mm->add_instruction(migraphx::make_op("mul"), l0, l3);
+    auto l4 = add_common_op(*mm, migraphx::make_op("mul"), {l1, l3});
     auto scale_bcast = mm->add_instruction(
         migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
     auto bias_bcast = mm->add_instruction(
         migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
     auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
-    mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
-
-    auto prog = optimize_onnx("instance_norm_test.onnx");
+    auto ret = mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    auto prog = migraphx::parse_onnx("instance_norm_test.onnx", options);
+    EXPECT(p == prog);
+}
+
+TEST_CASE(instance_norm_dyn_batch_test)
+{
+    // instancenorm with dynamic input in the 0'th (batch) dimension
+    migraphx::shape s1{migraphx::shape::float_type, {{1, 2, {2}}, {2, 2}, {3, 3}, {3, 3}}};
+    migraphx::shape s2{migraphx::shape::float_type, {2}};
+
+    migraphx::program p;
+    auto* mm   = p.get_main_module();
+    auto x     = mm->add_parameter("0", s1);
+    auto scale = mm->add_parameter("1", s2);
+    auto bias  = mm->add_parameter("2", s2);
+
+    auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
+    auto l1   = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
+    auto l0   = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean});
+    auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
+    auto epsilon_literal =
+        mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
+    auto l2 = add_common_op(*mm, migraphx::make_op("add"), {variance, epsilon_literal});
+    auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
+    auto l4 = add_common_op(*mm, migraphx::make_op("mul"), {l1, l3});
+    auto scale_bcast =
+        mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), scale, x);
+    auto bias_bcast = mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), bias, x);
+    auto l5  = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
+    auto ret = mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = {1, 2, {2}};
+    auto prog = migraphx::parse_onnx("instance_norm_dyn_batch_test.onnx", options);
+
     EXPECT(p == prog);
 }
@@ -3212,6 +3248,7 @@ TEST_CASE(instance_norm_half_test)
     auto scale_fp16 = mm->add_parameter("1", s2);
     auto bias_fp16  = mm->add_parameter("2", s2);
+    // conversion of half type to float is enabled by default
     auto x = mm->add_instruction(
         migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x_fp16);
     auto scale = mm->add_instruction(
@@ -3220,20 +3257,19 @@ TEST_CASE(instance_norm_half_test)
         migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), bias_fp16);
     auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
-    auto mean_bcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
-    auto l0 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
-    auto l1 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
+    auto l0 = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
+    auto l1 = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean});
     auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l1);
+    // type of epsilon_literal is same as 0'th input; convert instruction will be added by
+    // add_common_op
     auto epsilon_literal =
         mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
-    auto epsilon_bcast = mm->add_instruction(
-        migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
-    auto variance_bcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
-    auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
+    auto l2 = add_common_op(*mm, migraphx::make_op("add"), {variance, epsilon_literal});
     auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
-    auto l4 = mm->add_instruction(migraphx::make_op("mul"), l0, l3);
+    auto l4 = add_common_op(*mm, migraphx::make_op("mul"), {l0, l3});
     auto scale_bcast = mm->add_instruction(
         migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
     auto bias_bcast = mm->add_instruction(
@@ -3247,6 +3283,55 @@ TEST_CASE(instance_norm_half_test)
     EXPECT(p == prog);
 }
+TEST_CASE(instance_norm_dyn_batch_half_test)
+{
+    // instancenorm with half type, dynamic input in the 0'th (batch) dimension
+    migraphx::shape s1{migraphx::shape::half_type, {{1, 2, {2}}, {2, 2}, {3, 3}, {3, 3}}};
+    migraphx::shape s2{migraphx::shape::half_type, {2}};
+
+    migraphx::program p;
+    auto* mm        = p.get_main_module();
+    auto x_fp16     = mm->add_parameter("0", s1);
+    auto scale_fp16 = mm->add_parameter("1", s2);
+    auto bias_fp16  = mm->add_parameter("2", s2);
+    // conversion of half type to float is enabled by default
+    auto x = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x_fp16);
+    auto scale = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), scale_fp16);
+    auto bias = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), bias_fp16);
+    auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
+    auto l0   = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
+    auto l1   = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean});
+    auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l1);
+    // type of epsilon_literal is same as 0'th input; convert instruction will be added by
+    // add_common_op
+    auto epsilon_literal =
+        mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
+    auto l2 = add_common_op(*mm, migraphx::make_op("add"), {variance, epsilon_literal});
+    auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
+    auto l4 = add_common_op(*mm, migraphx::make_op("mul"), {l0, l3});
+    auto scale_bcast =
+        mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), scale, x);
+    auto bias_bcast = mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), bias, x);
+    auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
+    auto instance_norm_fp32 = mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
+    auto ret                = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}),
+        instance_norm_fp32);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = {1, 2, {2}};
+    auto prog = migraphx::parse_onnx("instance_norm_dyn_batch_half_test.onnx", options);
+    EXPECT(p == prog);
+}
+
 TEST_CASE(instance_norm_type_mismatch_test)
 {
     EXPECT(test::throws([&] { migraphx::parse_onnx("instance_norm_type_mismatch_test.onnx"); }));
...
@@ -880,6 +880,48 @@ TEST_CASE(instance_norm_test)
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
 
+TEST_CASE(instance_norm_dyn_batch_test)
+{
+    migraphx::program p = migraphx::parse_onnx("instance_norm_dyn_batch_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s0{migraphx::shape::float_type, {1, 2, 3, 3}};
+    std::vector<float> data0 = {0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8};
+    migraphx::shape s1{migraphx::shape::float_type, {2}};
+    std::vector<float> data1 = {1, 2};
+    migraphx::shape s2{migraphx::shape::float_type, {2}};
+    std::vector<float> data2 = {0, 1};
+
+    migraphx::parameter_map pp;
+    pp["0"] = migraphx::argument(s0, data0.data());
+    pp["1"] = migraphx::argument(s1, data1.data());
+    pp["2"] = migraphx::argument(s2, data2.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {-1.54919,
+                               -1.16189,
+                               -0.774596,
+                               -0.387298,
+                               0,
+                               0.387298,
+                               0.774596,
+                               1.16189,
+                               1.54919,
+                               -2.09838,
+                               -1.32379,
+                               -0.549192,
+                               0.225404,
+                               1,
+                               1.7746,
+                               2.54919,
+                               3.32379,
+                               4.09838};
+    EXPECT(migraphx::verify_range(result_vector, gold));
+}
+
 TEST_CASE(instance_norm_3d_test)
 {
     migraphx::program p = migraphx::parse_onnx("instance_norm_val_3d_test.onnx");
...
@@ -187,6 +187,14 @@ TEST_CASE(broadcast_axis_out_of_range_error)
     throws_shape(migraphx::make_op("broadcast", {{"axis", 4}, {"out_lens", lens}}), input);
 }
 
+TEST_CASE(broadcast_1in_dyn_error)
+{
+    // broadcast doesn't support a single dynamic shape input
+    std::vector<std::size_t> lens{3, 2, 4, 3};
+    migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}};
+    throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), input);
+}
+
 TEST_CASE(broadcast_2in_static_static)
 {
     migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
@@ -1434,6 +1442,14 @@ TEST_CASE(multibroadcast)
     }
 }
 
+TEST_CASE(multibroadcast_1in_dyn_error_0)
+{
+    // multibroadcast doesn't support a single dynamic shape input
+    std::vector<std::size_t> lens{4, 4, 1, 3};
+    migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}};
+    throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
+}
+
 TEST_CASE(multibroadcast_2in_static_dyn0)
 {
     migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
...