Unverified commit a4957ab2 authored by Chris Austen, committed by GitHub

Update Constant parsing to support more attributes #2141 (#2216)

Add parsing support for the value_float, value_floats, value_int, and value_ints attributes.
Disable failing tests.
Resolves test failures due to IndexError: _Map_base::at (migraphx-benchmark/AMDMIGraphX#76).
parent ac310ae5
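For context, the ONNX Constant operator stores its payload in exactly one of several attributes (value, value_float, value_floats, value_int, value_ints, among others). A minimal sketch of the newly supported forms, built with the standard onnx Python helpers (the output names here are illustrative, not part of this commit):

import onnx

# Constant carrying a single float via the value_float attribute.
node_f = onnx.helper.make_node('Constant', inputs=[], outputs=['cf'], value_float=1.0)

# Constant carrying an integer list via the value_ints attribute.
node_i = onnx.helper.make_node('Constant', inputs=[], outputs=['ci'], value_ints=[1, 2, 3])

Per the ONNX specification exactly one of these attributes may be set; the parser change below rejects a node that sets none, or more than one.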
@@ -25,6 +25,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/stringutils.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
                           onnx_parser::node_info info,
                           const std::vector<instruction_ref>& /*args*/) const
     {
-        literal v = parser.parse_value(info.attributes.at("value"));
+        static const std::vector<std::string> attributes = {
+            "value", "value_float", "value_floats", "value_int", "value_ints"};
+
+        std::vector<std::string> present_attributes;
+        std::copy_if(attributes.begin(),
+                     attributes.end(),
+                     std::back_inserter(present_attributes),
+                     [&](const std::string& a) { return contains(info.attributes, a); });
+
+        if(present_attributes.empty())
+        {
+            MIGRAPHX_THROW("Constant node does not contain any supported attribute");
+        }
+
+        if(present_attributes.size() > 1)
+        {
+            MIGRAPHX_THROW("Constant contains multiple attributes: " +
+                           join_strings(std::move(present_attributes), ", "));
+        }
+
+        // cppcheck-suppress accessMoved
+        auto&& attr = info.attributes[present_attributes[0]];
+        literal v = parser.parse_value(attr);
+
         // return empty literal
         if(v.get_shape().elements() == 0)
        {
             return info.add_literal(literal{v.get_shape().type()});
         }
-        auto dim_size = info.attributes.at("value").t().dims_size();
+
         // if dim_size is 0, it is a scalar
-        if(dim_size == 0)
+        if(attr.has_t() and attr.t().dims_size() == 0)
         {
             migraphx::shape scalar_shape{v.get_shape().type()};
             return info.add_literal(migraphx::literal{scalar_shape, v.data()});
...
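In effect, the parser now scans for whichever supported attribute is present, throws if none or more than one is found, and only consults t().dims_size() when the payload is actually a tensor (attr.has_t()). A rough sketch of how the new error path surfaces through the migraphx Python bindings, assuming the bindings are installed and that MIGRAPHX_THROW is translated to a Python RuntimeError by pybind11 (the model file is one generated by the new tests below):

import migraphx

# A Constant setting both value_floats and value_ints should now be
# rejected with "Constant contains multiple attributes: ..."
try:
    migraphx.parse_onnx('constant_multiple_attributes_test.onnx')
except RuntimeError as err:
    print(err)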
[Binary files added: constant_no_attributes_test.onnx, constant_value_int_test.onnx, constant_value_ints_test.onnx (serialized ONNX models for the new Constant tests)]
@@ -825,6 +825,76 @@ def constant_test():
     return ([node], [], [y])
 
 
+@onnx_test()
+def constant_value_float_test():
+    node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=[],
+                                 value_float=[1.0])
+
+    return ([node], [], [])
+
+
+@onnx_test()
+def constant_value_floats_test():
+    node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=[],
+                                 value_floats=[1.0, 2.0, 3.0])
+
+    return ([node], [], [])
+
+
+@onnx_test()
+def constant_value_int_test():
+    node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=[],
+                                 value_int=[1])
+
+    return ([node], [], [])
+
+
+@onnx_test()
+def constant_value_ints_test():
+    node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=[],
+                                 value_ints=[1, 2, 3])
+
+    return ([node], [], [])
+
+
+@onnx_test()
+def constant_no_attributes_test():
+    node = onnx.helper.make_node('Constant', inputs=[], outputs=[])
+
+    return ([node], [], [])
+
+
+@onnx_test()
+def constant_multiple_attributes_test():
+    x = np.array([0, 1, 2])
+    node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=[],
+                                 value_floats=[1.0, 2.0],
+                                 value_ints=[1, 2],
+                                 value=onnx.helper.make_tensor(
+                                     name='const_tensor',
+                                     data_type=TensorProto.FLOAT,
+                                     dims=x.shape,
+                                     vals=x.flatten().astype(float)))
+
+    return ([node], [], [])
+
+
 @onnx_test()
 def constant_fill_test():
     value = helper.make_tensor_value_info('value', TensorProto.FLOAT, [2, 3])
...
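Each generator above returns a (nodes, inputs, outputs) triple, and the @onnx_test() decorator already used in this file turns that into a serialized .onnx fixture. A rough standalone equivalent of that serialization step, using only the public onnx helper API (the save_test_model helper is illustrative, not part of the repository):

import onnx
from onnx import helper


def save_test_model(name, nodes, inputs, outputs):
    # Build a GraphProto from the triple and write it next to the tests.
    graph = helper.make_graph(nodes, name, inputs, outputs)
    onnx.save(helper.make_model(graph), name + '.onnx')


node = helper.make_node('Constant', inputs=[], outputs=[], value_ints=[1, 2, 3])
save_test_model('constant_value_ints_test', [node], [], [])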
@@ -930,6 +930,58 @@ TEST_CASE(constant_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(constant_value_float_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1.0f}});
+    auto prog = optimize_onnx("constant_value_float_test.onnx");
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(constant_value_floats_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    mm->add_literal(
+        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {1.0f, 2.0f, 3.0f}});
+    auto prog = optimize_onnx("constant_value_floats_test.onnx");
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(constant_value_int_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {1}}, {1}});
+    auto prog = optimize_onnx("constant_value_int_test.onnx");
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(constant_value_ints_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    mm->add_literal(
+        migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {3}}, {1, 2, 3}});
+    auto prog = optimize_onnx("constant_value_ints_test.onnx");
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(constant_no_attributes_test)
+{
+    EXPECT(test::throws([&] { optimize_onnx("constant_no_attributes_test.onnx"); }));
+}
+
+TEST_CASE(constant_multiple_attributes_test)
+{
+    EXPECT(test::throws([&] { optimize_onnx("constant_multiple_attributes_test.onnx"); }));
+}
+
 TEST_CASE(constant_fill_test)
 {
     migraphx::program p;
...
@@ -108,6 +108,34 @@ def disabled_tests_onnx_1_12_0(backend_test):
     backend_test.exclude(r'test_scatter_elements_with_duplicate_indices_cpu')
 
 
+def disabled_tests_onnx_1_13_0(backend_test):
+    # The following tests fail due to the CastLike operator being unsupported
+    backend_test.exclude(r'test_elu_default_expanded_ver18_cpu')
+    backend_test.exclude(r'test_elu_example_expanded_ver18_cpu')
+    backend_test.exclude(r'test_elu_expanded_ver18_cpu')
+    backend_test.exclude(r'test_hardsigmoid_default_expanded_ver18_cpu')
+    backend_test.exclude(r'test_hardsigmoid_example_expanded_ver18_cpu')
+    backend_test.exclude(r'test_hardsigmoid_expanded_ver18_cpu')
+    backend_test.exclude(r'test_leakyrelu_default_expanded_cpu')
+    backend_test.exclude(r'test_leakyrelu_example_expanded_cpu')
+    backend_test.exclude(r'test_leakyrelu_expanded_cpu')
+    backend_test.exclude(r'test_selu_default_expanded_ver18_cpu')
+    backend_test.exclude(r'test_selu_example_expanded_ver18_cpu')
+    backend_test.exclude(r'test_selu_expanded_ver18_cpu')
+    backend_test.exclude(r'test_thresholdedrelu_default_expanded_ver18_cpu')
+    backend_test.exclude(r'test_thresholdedrelu_example_expanded_ver18_cpu')
+    backend_test.exclude(r'test_thresholdedrelu_expanded_ver18_cpu')
+    backend_test.exclude(r'test_relu_expanded_ver18_cpu')
+    backend_test.exclude(r'test_softsign_example_expanded_ver18_cpu')
+    backend_test.exclude(r'test_softsign_expanded_ver18_cpu')
+
+
+def disabled_tests_onnx_1_14_0(backend_test):
+    # The following tests fail due to the CastLike operator being unsupported
+    backend_test.exclude(r'test_softplus_example_expanded_ver18_cpu')
+    backend_test.exclude(r'test_softplus_expanded_ver18_cpu')
+
+
 def create_backend_test(testname=None, target_device=None):
     if target_device is not None:
         c2.set_device(target_device)
@@ -334,6 +362,12 @@ def create_backend_test(testname=None, target_device=None):
     if version.parse(onnx.__version__) >= version.parse("1.12.0"):
         disabled_tests_onnx_1_12_0(backend_test)
 
+    if version.parse(onnx.__version__) >= version.parse("1.13.0"):
+        disabled_tests_onnx_1_13_0(backend_test)
+
+    if version.parse(onnx.__version__) >= version.parse("1.14.0"):
+        disabled_tests_onnx_1_14_0(backend_test)
+
     # import all test cases at global scope to make
     # them visible to python.unittest.
...