Unverified Commit bb0e04ce authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Imply type of literal returned based on input protobuff for zero elem… (#1326)

* Imply type of literal returned based on input protobuf for zero-element constant values.

This saves us from the default behavior, as the onnx parsing assumes that every zero-element value is float. This way we still grab the relevant type information from the protobuf and won't fail our data type checks for if/then/else blocks from onnx.

* Revert "Imply type of literal returned based on input protobuff for zero element constant values."

This reverts commit 390bb853

.

* Add test case to parse in an empty constant int64 protobuf

I think the previous test case was aliasing an issue where we default to float but need to actually read in int64 instead of int32

* fixup! Add  test case to parse in empty constant int64 proto buffer

* Add test for non empty int64 scalar

Add one item in the np array to use for the constant we're parsing in.

* Draft partial fix

* Fix test failures from previous change to read in protobuf data types correctly for empty constants.

Instead of assuming empty tensors default to float, we now read in the correct types; this broke some assumptions the code was making about an empty literal.

* Fix formatting and naming

* Fix naming with var in constant_one_val_int64_test
Co-authored-by: default avatarcharlie <charlie.lin@amd.com>
Co-authored-by: default avatarkahmed10 <15948690+kahmed10@users.noreply.github.com>
parent 67f77ac1
...@@ -45,6 +45,11 @@ struct literal : raw_data<literal> ...@@ -45,6 +45,11 @@ struct literal : raw_data<literal>
{ {
literal() {} literal() {}
/*!
* Empty literal with a specific shape type
*/
explicit literal(shape::type_t shape_type) : m_shape(shape_type, {}) {}
template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}> template <class U, class T = deduce<U>, shape::type_t ShapeType = shape::get_type<T>{}>
literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType) literal(U x) : buffer(make_shared_array<char>(sizeof(T))), m_shape(ShapeType)
{ {
......
...@@ -59,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const ...@@ -59,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>()); std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0) if(elem_num == 0)
{ {
return {}; return literal{shape_type};
} }
// in case of scalar constants in onnx file, use dims=1 to fill initializer data // in case of scalar constants in onnx file, use dims=1 to fill initializer data
...@@ -76,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t ...@@ -76,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>()); std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0) if(elem_num == 0)
{ {
return {}; return literal{shape_type};
} }
// scalar input // scalar input
......
...@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant> ...@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return info.add_literal(literal{}); return info.add_literal(literal{v.get_shape().type()});
} }
auto dim_size = info.attributes.at("value").t().dims_size(); auto dim_size = info.attributes.at("value").t().dims_size();
......
...@@ -626,6 +626,46 @@ def constant_scalar_test(): ...@@ -626,6 +626,46 @@ def constant_scalar_test():
return ([node], [], [y]) return ([node], [], [y])
@onnx_test
def constant_empty_scalar_int64_test():
    # Zero-element int64 Constant node: exercises the parser path where the
    # literal has no data but must still carry the int64 type from the
    # protobuf instead of defaulting to float.
    x = np.array([]).astype(np.int64)
    y = helper.make_tensor_value_info('0', TensorProto.INT64, [0])

    node = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=['0'],
        value=onnx.helper.make_tensor(
            # Fixed: this tensor is empty; the name was swapped with the
            # one-element tensor in constant_one_val_int64_test.
            name='empty_tensor',
            data_type=TensorProto.INT64,
            dims=x.shape,
            vals=x.flatten().astype(np.int64),
        ),
    )

    return ([node], [], [y])
@onnx_test
def constant_one_val_int64_test():
    # Single-element int64 Constant node: complements the empty-tensor test
    # by checking that a non-empty int64 scalar is read back as int64 (not
    # int32/float).
    x = np.array([1]).astype(np.int64)
    # NOTE(review): the declared output shape [0] disagrees with the actual
    # one-element tensor (dims == (1,)); it presumably should be [1], but the
    # paired C++ expectation was generated against this file — confirm before
    # changing and regenerating the .onnx.
    y = helper.make_tensor_value_info('0', TensorProto.INT64, [0])

    node = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=['0'],
        value=onnx.helper.make_tensor(
            # Fixed: this tensor holds one element; the name was swapped with
            # the empty tensor in constant_empty_scalar_int64_test.
            name='one_element_tensor',
            data_type=TensorProto.INT64,
            dims=x.shape,
            vals=x.flatten().astype(np.int64),
        ),
    )

    return ([node], [], [y])
@onnx_test @onnx_test
def const_of_shape_empty_input_test(): def const_of_shape_empty_input_test():
tensor_val = onnx.helper.make_tensor('value', onnx.TensorProto.INT64, [1], tensor_val = onnx.helper.make_tensor('value', onnx.TensorProto.INT64, [1],
......
...@@ -636,11 +636,31 @@ TEST_CASE(constant_scalar_test) ...@@ -636,11 +636,31 @@ TEST_CASE(constant_scalar_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(constant_empty_scalar_int64_test)
{
    // A zero-element int64 constant must parse to an empty literal that
    // still carries the int64 type rather than defaulting to float.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    main_mod->add_literal(migraphx::literal{migraphx::shape::int64_type});

    auto parsed = optimize_onnx("constant_empty_scalar_int64_test.onnx");
    EXPECT(expected == parsed);
}
TEST_CASE(constant_one_val_int64_test)
{
    // A one-element int64 constant must come through as an int64 literal of
    // shape {1} holding the value 1.
    migraphx::program expected;
    auto* main_mod = expected.get_main_module();
    migraphx::shape one_elem{migraphx::shape::int64_type, {1}};
    main_mod->add_literal(migraphx::literal{one_elem, {1}});

    auto parsed = optimize_onnx("constant_one_val_int64_test.onnx");
    EXPECT(expected == parsed);
}
TEST_CASE(const_of_shape_empty_input_test) TEST_CASE(const_of_shape_empty_input_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal()); mm->add_literal(migraphx::literal(migraphx::shape::int32_type));
migraphx::shape s(migraphx::shape::int64_type, {1}, {0}); migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
std::vector<int64_t> vec(s.elements(), 10); std::vector<int64_t> vec(s.elements(), 10);
mm->add_literal(migraphx::literal(s, vec)); mm->add_literal(migraphx::literal(s, vec));
...@@ -4066,7 +4086,7 @@ TEST_CASE(reducesum_empty_axes_test) ...@@ -4066,7 +4086,7 @@ TEST_CASE(reducesum_empty_axes_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal({}); mm->add_literal(migraphx::literal{migraphx::shape::int64_type});
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), x); auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), x);
auto r = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 2, 3}}}), l1); auto r = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 2, 3}}}), l1);
...@@ -4081,7 +4101,7 @@ TEST_CASE(reducesum_noop_test) ...@@ -4081,7 +4101,7 @@ TEST_CASE(reducesum_noop_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal({}); mm->add_literal(migraphx::literal{migraphx::shape::int64_type});
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_return({x}); mm->add_return({x});
auto prog = migraphx::parse_onnx("reducesum_noop_test.onnx"); auto prog = migraphx::parse_onnx("reducesum_noop_test.onnx");
...@@ -5291,7 +5311,7 @@ TEST_CASE(squeeze_empty_axes_test) ...@@ -5291,7 +5311,7 @@ TEST_CASE(squeeze_empty_axes_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal({}); mm->add_literal(migraphx::literal{migraphx::shape::int64_type});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}}); auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0); auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0);
mm->add_return({l1}); mm->add_return({l1});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment