Unverified commit f16215d6, authored by Shucai Xiao, committed by GitHub
Browse files

Nonzero op (#594)



* add parsing NonZero operator

* clang format

* add unit test

* clang format

* add unit tests for more code coverage

* clang format

* fix an error message

* fix review comments

* fix review comments
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent bd974b2b
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -119,6 +121,7 @@ struct onnx_parser ...@@ -119,6 +121,7 @@ struct onnx_parser
add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>); add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>); add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
add_mem_op("MaxPool", &onnx_parser::parse_pooling); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("NonZero", &onnx_parser::parse_nonzero);
add_mem_op("OneHot", &onnx_parser::parse_onehot); add_mem_op("OneHot", &onnx_parser::parse_onehot);
add_mem_op("Pad", &onnx_parser::parse_pad); add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("Range", &onnx_parser::parse_range); add_mem_op("Range", &onnx_parser::parse_range);
...@@ -2320,6 +2323,49 @@ struct onnx_parser ...@@ -2320,6 +2323,49 @@ struct onnx_parser
MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator"); MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator");
} }
// Collect the flat (row-major) positions of all nonzero entries of `data`.
// float_equal is used so floating-point values compare robustly against 0.
template <class T>
std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
{
    std::vector<std::size_t> result;
    std::size_t pos = 0;
    for(const auto& elem : data)
    {
        if(not float_equal(elem, 0))
        {
            result.push_back(pos);
        }
        ++pos;
    }
    return result;
}
// Parse the ONNX NonZero operator. The input must be a compile-time
// constant; the result is folded into an int64 literal of shape
// (rank, num_nonzero) holding the multi-dimensional index of every
// nonzero element, matching the ONNX NonZero output layout.
instruction_ref
parse_nonzero(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
    auto const_data = args.back()->eval();
    check_arg_empty(const_data, "PARSE_NONZERO: cannot support non-constant input!");

    // Find flat indices of nonzero elements, independent of element type.
    std::vector<std::size_t> nz_positions;
    const_data.visit([&](auto view) {
        using elem_type = std::remove_cv_t<typename decltype(view)::value_type>;
        std::vector<elem_type> elems(view.begin(), view.end());
        nz_positions = this->nonzero_indices(elems);
    });

    shape input_shape         = args[0]->get_shape();
    const std::size_t n_dims  = input_shape.lens().size();
    shape output_shape{shape::int64_type, {n_dims, nz_positions.size()}};

    // Column i of the output is the multi-index of the i-th nonzero element.
    std::vector<int64_t> output(output_shape.elements());
    std::size_t col = 0;
    for(auto flat_idx : nz_positions)
    {
        auto multi_idx = input_shape.multi(flat_idx);
        for(std::size_t dim = 0; dim < n_dims; ++dim)
        {
            output[output_shape.index({dim, col})] = multi_idx[dim];
        }
        ++col;
    }

    return prog.add_literal(literal(output_shape, output));
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
...@@ -2497,7 +2543,7 @@ struct onnx_parser ...@@ -2497,7 +2543,7 @@ struct onnx_parser
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data()); case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::INT8: case onnx::TensorProto::INT8:
case onnx::TensorProto::UINT16: case onnx::TensorProto::UINT16:
case onnx::TensorProto::INT16: case onnx::TensorProto::INT16: return create_literal(shape::int16_type, dims, s.data());
case onnx::TensorProto::INT32: case onnx::TensorProto::INT32:
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT8: case onnx::TensorProto::UINT8:
......
...@@ -1722,6 +1722,38 @@ def no_pad_test(): ...@@ -1722,6 +1722,38 @@ def no_pad_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def nonzero_test():
    # NonZero over a constant float input: the parser should constant-fold
    # the op into an int64 tensor of shape (rank, num_nonzero).
    data1 = np.array([[1., 0.], [1., 1.]])
    # Use np.float32 (not the removed `np.float` alias, gone since NumPy
    # 1.24) so the values also match the declared TensorProto.FLOAT dtype.
    data = helper.make_tensor(name='data',
                              data_type=TensorProto.FLOAT,
                              dims=data1.shape,
                              vals=data1.flatten().astype(np.float32))
    y = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 3])

    node = onnx.helper.make_node('NonZero',
                                 inputs=['data'],
                                 outputs=['indices'])

    return ([node], [], [y], [data])
@onnx_test
def nonzero_int_test():
    # NonZero over a constant int16 input; the expected output is an int64
    # tensor of shape (rank, num_nonzero) holding the nonzero coordinates.
    values = np.array([[1, 1, 0], [1, 0, 1]])
    const_tensor = helper.make_tensor(name='data',
                                      data_type=TensorProto.INT16,
                                      dims=values.shape,
                                      vals=values.flatten().astype(np.int16))
    out_info = helper.make_tensor_value_info('indices', TensorProto.INT64,
                                             [2, 4])

    nz_node = onnx.helper.make_node('NonZero',
                                    inputs=['data'],
                                    outputs=['indices'])

    return ([nz_node], [], [out_info], [const_tensor])
@onnx_test @onnx_test
def onehot_test(): def onehot_test():
axis_value = 0 axis_value = 0
......
...@@ -1348,6 +1348,38 @@ TEST_CASE(neg_test) ...@@ -1348,6 +1348,38 @@ TEST_CASE(neg_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
// Parsing NonZero on a constant float tensor should constant-fold to an
// int64 literal containing the multi-indices of the nonzero elements.
TEST_CASE(nonzero_test)
{
    migraphx::program p;

    migraphx::shape input_shape{migraphx::shape::float_type, {2, 2}};
    std::vector<float> input_vals = {1, 0, 1, 1};
    p.add_literal(migraphx::literal(input_shape, input_vals));

    migraphx::shape index_shape{migraphx::shape::int64_type, {2, 3}};
    std::vector<int64_t> expected_indices = {0, 1, 1, 0, 0, 1};
    auto result = p.add_literal(migraphx::literal(index_shape, expected_indices));
    p.add_return({result});

    auto prog = migraphx::parse_onnx("nonzero_test.onnx");
    EXPECT(p == prog);
}
// Same as nonzero_test but with an integer input tensor; the folded
// result is still an int64 literal of shape (rank, num_nonzero).
TEST_CASE(nonzero_int_test)
{
    migraphx::program p;

    migraphx::shape input_shape{migraphx::shape::int32_type, {2, 3}};
    std::vector<int> input_vals = {1, 1, 0, 1, 0, 1};
    p.add_literal(migraphx::literal(input_shape, input_vals.begin(), input_vals.end()));

    migraphx::shape index_shape{migraphx::shape::int64_type, {2, 4}};
    std::vector<int64_t> expected_indices = {0, 0, 1, 1, 0, 1, 0, 2};
    auto result = p.add_literal(migraphx::literal(index_shape, expected_indices));
    p.add_return({result});

    auto prog = migraphx::parse_onnx("nonzero_int_test.onnx");
    EXPECT(p == prog);
}
TEST_CASE(onehot_test) TEST_CASE(onehot_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment