Unverified commit dd6523c9, authored by Shucai Xiao and committed by GitHub

Gather elements operator (#549)



* code backup

* clang format

* fix compiling errors

* clang format

* rename a few files

* rename a few files

* fix variable bugs

* clang format

* add an operator to shift input sequences

* clang format

* fixed a bug

* clang format

* fixed a bug

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* refine code related to the lstm operator optimization

* clang format

* fix various bugs

* clang format

* fixed a bug in rewrite_lstm

* clang format

* fixed another bug

* refine two operator names

* clang format

* refine file names

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* fixed review comments

* clang format

* add unit tests

* clang format

* add unit tests

* clang format

* refine unit tests for better coverage

* clang format

* fixed a bug

* fix cppcheck error

* fix review comments

* clang format

* rename two operators according to review comments

* clang format

* add parsing the operator GatherElements

* clang format

* add onnx unit tests for the gather_elements operator

* clang format

* clang format

* remove unnecessary files

* remove unnecessary files

* add an onnx verify unit test for the gather_elements operator

* clang format
Co-authored-by: Shucai Xiao <scxiao@prj47-rack-99.local.lan>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 1cc724ee
@@ -34,7 +34,8 @@ struct reshape
         std::vector<std::size_t> rdims(dims.begin(), dims.end());
         auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
         if(n_neg_dims > 1)
-            MIGRAPHX_THROW("Dimensions for reshape can only have one -1 dim");
+            MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
         for(std::size_t i = 0; i < dims.size(); i++)
         {
             if(dims[i] == 0)
@@ -45,6 +46,7 @@ struct reshape
             if(dims[i] == -1)
                 rdims[i] = 1;
         }
         if(n_neg_dims > 0)
         {
             size_t missing_dim =
@@ -59,15 +61,17 @@ struct reshape
         shape s{inputs.front().type(), rdims};
         if(s.elements() != inputs.front().elements())
-            MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " +
+            MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
                            std::to_string(s.elements()) + " elements whereas the input has " +
                            std::to_string(inputs.front().elements()));
         return s;
     }

     argument compute(shape output_shape, std::vector<argument> args) const
     {
         return {std::move(output_shape), std::move(args.front().data)};
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
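Aside (not part of the diff): the hunks above only prefix reshape's error messages with "Reshape:", but the shape logic they sit in is what resolves a 0 entry (keep the input dimension) and a single -1 entry (infer it from the element count). Below is a minimal standalone sketch of that resolution, using a hypothetical helper name rather than MIGraphX's API.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper, not part of MIGraphX: resolve reshape dims where 0 means
// "keep the input dimension" and a single -1 is inferred so that the element
// count matches the input tensor.
std::vector<std::size_t> infer_reshape_dims(const std::vector<std::size_t>& in_lens,
                                            const std::vector<std::int64_t>& dims)
{
    std::vector<std::size_t> out(dims.size(), 1);
    bool has_neg       = false;
    std::size_t neg_at = 0;
    for(std::size_t i = 0; i < dims.size(); ++i)
    {
        if(dims[i] == -1)
        {
            has_neg = true;
            neg_at  = i; // temporarily left as 1, filled in below
        }
        else if(dims[i] == 0)
        {
            out[i] = in_lens.at(i); // 0 copies the corresponding input dimension
        }
        else
        {
            out[i] = static_cast<std::size_t>(dims[i]);
        }
    }
    std::size_t in_elems =
        std::accumulate(in_lens.begin(), in_lens.end(), std::size_t{1}, std::multiplies<>{});
    if(has_neg)
    {
        std::size_t known =
            std::accumulate(out.begin(), out.end(), std::size_t{1}, std::multiplies<>{});
        out[neg_at] = in_elems / known; // the -1 dimension absorbs the remainder
    }
    return out; // e.g. in_lens = {3, 4}, dims = {2, -1} -> {2, 6}
}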
@@ -104,6 +104,7 @@ struct onnx_parser
         add_mem_op("Expand", &onnx_parser::parse_expand);
         add_mem_op("Flatten", &onnx_parser::parse_flatten);
         add_mem_op("Gather", &onnx_parser::parse_gather);
+        add_mem_op("GatherElements", &onnx_parser::parse_gather_elements);
         add_mem_op("Gemm", &onnx_parser::parse_gemm);
         add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
         add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
@@ -909,6 +910,64 @@ struct onnx_parser
         return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
     }

+    instruction_ref
+    parse_gather_elements(const std::string&, node_info info, std::vector<instruction_ref> args)
+    {
+        int axis = 0;
+        if(contains(info.attributes, "axis"))
+        {
+            axis = parse_value(info.attributes.at("axis")).at<int>();
+        }
+
+        // standardize the input data and indices
+        auto arg_data = make_contiguous(args[0]);
+        auto arg_ind  = make_contiguous(args[1]);
+        auto data_s   = arg_data->get_shape();
+        auto ind_s    = arg_ind->get_shape();
+
+        if(data_s.lens().size() != ind_s.lens().size())
+        {
+            MIGRAPHX_THROW("PARSE_GATHER_ELEMENTS: input data and index must have the same rank!");
+        }
+
+        int n_rank            = static_cast<int>(data_s.lens().size());
+        int tuned_axis        = (axis < 0) ? (axis + n_rank) : axis;
+        auto axis_stride      = data_s.strides()[tuned_axis];
+        int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
+        // reshape the input data to one dimension so it can be used as the
+        // data input of the gather operator
+        arg_data = prog.add_instruction(op::reshape{{data_elem_num}}, arg_data);
+
+        std::size_t elem_num = ind_s.elements();
+        std::vector<int> ind_index(elem_num);
+        std::iota(ind_index.begin(), ind_index.end(), 0);
+
+        // convert each position in the indices tensor to its linear offset in the data tensor
+        std::vector<int> data_indices(elem_num);
+        std::transform(ind_index.begin(), ind_index.end(), data_indices.begin(), [&](auto i) {
+            return data_s.index(ind_s.multi(i));
+        });
+
+        // the coordinate of each indices position along the gather axis
+        std::vector<int> vec_axis_ind(elem_num);
+        std::transform(ind_index.begin(), ind_index.end(), vec_axis_ind.begin(), [&](auto i) {
+            return ind_s.multi(i)[tuned_axis];
+        });
+
+        auto l_shape_idx =
+            prog.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
+        auto l_dim_idx = prog.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
+        auto l_stride  = prog.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
+        l_stride       = prog.add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
+
+        // final linear index = base offset + (runtime index - axis coordinate) * axis stride
+        auto dim_diff = prog.add_instruction(op::sub{}, arg_ind, l_dim_idx);
+        auto delta    = prog.add_instruction(op::mul{}, dim_diff, l_stride);
+        auto ind      = prog.add_instruction(op::add{}, l_shape_idx, delta);
+
+        op::gather op{0};
+        return prog.add_instruction(op, arg_data, ind);
+    }
+
     instruction_ref
     parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
     {
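The parse_gather_elements function above lowers ONNX GatherElements onto the existing gather operator: the data tensor is flattened, each position p of the indices tensor gets a precomputed base offset data_s.index(ind_s.multi(p)), and at run time that offset is shifted along the chosen axis by (indices[p] - multi(p)[tuned_axis]) * axis_stride. The following host-side sketch of the same arithmetic for 2-D, row-major tensors is my own illustration (the name gather_elements_2d is hypothetical, and non-negative indices are assumed).

#include <cassert>
#include <cstddef>
#include <vector>

// Reference GatherElements for 2-D row-major tensors, mirroring the index
// arithmetic built by parse_gather_elements: the output at position (i, j)
// reads data at the same multi-index, except that the coordinate along `axis`
// is replaced by ind[i][j].
std::vector<float> gather_elements_2d(const std::vector<float>& data,
                                      std::size_t rows,
                                      std::size_t cols,
                                      const std::vector<int>& ind,
                                      std::size_t ind_rows,
                                      std::size_t ind_cols,
                                      int axis) // 0 or 1; negative axes not handled here
{
    assert(data.size() == rows * cols);
    assert(ind.size() == ind_rows * ind_cols);
    const long strides[2] = {static_cast<long>(cols), 1}; // row-major strides of data
    std::vector<float> out(ind_rows * ind_cols);
    for(std::size_t i = 0; i < ind_rows; ++i)
    {
        for(std::size_t j = 0; j < ind_cols; ++j)
        {
            std::size_t p = i * ind_cols + j;
            long multi[2] = {static_cast<long>(i), static_cast<long>(j)};
            // base offset: position (i, j) mapped into data's layout
            long base = multi[0] * strides[0] + multi[1] * strides[1];
            // shift along `axis`: replace the axis coordinate with ind[p]
            long delta = (ind[p] - multi[axis]) * strides[axis];
            out[p]     = data[static_cast<std::size_t>(base + delta)];
        }
    }
    return out;
}

The parse tests and the verify test added below exercise exactly this arithmetic through the parsed program.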
(binary file added: gather_elements_axis1_test.onnx — a serialized ONNX protobuf whose graph contains a single GatherElements node with inputs "data" and "indices", output "y", and an axis attribute; the raw bytes are not reproduced here.)
@@ -1131,6 +1131,38 @@ def gather_test():
     return ([node], [x, i], [y])


+@onnx_test
+def gather_elements_axis0_test():
+    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4])
+    i = helper.make_tensor_value_info('indices', TensorProto.INT32, [2, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3])
+    node = onnx.helper.make_node(
+        'GatherElements',
+        inputs=['data', 'indices'],
+        outputs=['y'],
+        axis=0,
+    )
+
+    return ([node], [x, i], [y])
+
+
+@onnx_test
+def gather_elements_axis1_test():
+    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4])
+    i = helper.make_tensor_value_info('indices', TensorProto.INT32, [2, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3])
+    node = onnx.helper.make_node(
+        'GatherElements',
+        inputs=['data', 'indices'],
+        outputs=['y'],
+        axis=1,
+    )
+
+    return ([node], [x, i], [y])
+
+
 @onnx_test
 def gemm_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7])
@@ -777,6 +777,60 @@ TEST_CASE(gather_test)
     EXPECT(p == prog);
 }

+TEST_CASE(gather_elements_axis0_test)
+{
+    migraphx::program p;
+    auto data    = p.add_parameter("data", {migraphx::shape::float_type, {3, 4}});
+    auto indices = p.add_parameter("indices", {migraphx::shape::int32_type, {2, 3}});
+    std::vector<int> ind_indices{0, 1, 2, 4, 5, 6};
+    std::vector<int> ind_axis_indices{0, 0, 0, 1, 1, 1};
+    migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}};
+    auto l_data_indices =
+        p.add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
+    auto l_ind_axis_indices =
+        p.add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
+    auto l_stride    = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}});
+    auto rsp_data    = p.add_instruction(migraphx::op::reshape{{12}}, data);
+    auto lbst_stride = p.add_instruction(migraphx::op::multibroadcast{ind_s.lens()}, l_stride);
+    auto axis_delta  = p.add_instruction(migraphx::op::sub{}, indices, l_ind_axis_indices);
+    auto mul_delta   = p.add_instruction(migraphx::op::mul{}, axis_delta, lbst_stride);
+    auto ind         = p.add_instruction(migraphx::op::add{}, l_data_indices, mul_delta);
+    auto ret         = p.add_instruction(migraphx::op::gather{0}, rsp_data, ind);
+    p.add_return({ret});
+
+    auto prog = migraphx::parse_onnx("gather_elements_axis0_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(gather_elements_axis1_test)
+{
+    migraphx::program p;
+    auto data    = p.add_parameter("data", {migraphx::shape::float_type, {3, 4}});
+    auto indices = p.add_parameter("indices", {migraphx::shape::int32_type, {2, 3}});
+    std::vector<int> ind_indices{0, 1, 2, 4, 5, 6};
+    std::vector<int> ind_axis_indices{0, 1, 2, 0, 1, 2};
+    migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}};
+    auto l_data_indices =
+        p.add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
+    auto l_ind_axis_indices =
+        p.add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
+    auto l_stride    = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}});
+    auto rsp_data    = p.add_instruction(migraphx::op::reshape{{12}}, data);
+    auto lbst_stride = p.add_instruction(migraphx::op::multibroadcast{ind_s.lens()}, l_stride);
+    auto axis_delta  = p.add_instruction(migraphx::op::sub{}, indices, l_ind_axis_indices);
+    auto mul_delta   = p.add_instruction(migraphx::op::mul{}, axis_delta, lbst_stride);
+    auto ind         = p.add_instruction(migraphx::op::add{}, l_data_indices, mul_delta);
+    auto ret         = p.add_instruction(migraphx::op::gather{0}, rsp_data, ind);
+    p.add_return({ret});
+
+    auto prog = migraphx::parse_onnx("gather_elements_axis1_test.onnx");
+    EXPECT(p == prog);
+}
+
 TEST_CASE(gemm_test)
 {
     migraphx::program p;
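The literal constants in the two gather_elements parse tests above can be recomputed by hand: data is 3 x 4 and row-major, so its strides are {4, 1}, and the indices tensor is 2 x 3. A short check (my own illustration, not part of the commit):

#include <cstdio>

// For every position (i, j) of the 2x3 indices tensor:
//  - the base offset into the flattened 3x4 data is 4*i + j  -> {0, 1, 2, 4, 5, 6}
//  - the coordinate along axis 0 is i                        -> {0, 0, 0, 1, 1, 1}
//  - the coordinate along axis 1 is j                        -> {0, 1, 2, 0, 1, 2}
// and the broadcast stride literal is 4 for axis 0 and 1 for axis 1.
int main()
{
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 3; ++j)
            std::printf("(%d,%d): base=%d axis0_coord=%d axis1_coord=%d\n", i, j, 4 * i + j, i, j);
    return 0;
}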
@@ -39,4 +39,27 @@ TEST_CASE(instance_norm_test)
     EXPECT(migraphx::verify_range(result_vector, gold));
 }

+TEST_CASE(gather_elements)
+{
+    migraphx::program p = migraphx::parse_onnx("gather_elements_axis0_test.onnx");
+    p.compile(migraphx::cpu::target{});
+
+    migraphx::shape s_data{migraphx::shape::float_type, {3, 4}};
+    std::vector<float> data = {
+        0.25, 0.75, 0.9375, 0.4375, 0.6875, 0.5625, -0.875, 0.1875, -0.125, 0.5, -0.9375, -0.0625};
+
+    migraphx::shape s_ind{migraphx::shape::int32_type, {2, 3}};
+    std::vector<int> ind = {2, 1, 2, 0, 1, 0};
+
+    migraphx::program::parameter_map pp;
+    pp["data"]    = migraphx::argument(s_data, data.data());
+    pp["indices"] = migraphx::argument(s_ind, ind.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+
+    std::vector<float> gold = {-0.125, 0.5625, -0.9375, 0.25, 0.5625, 0.9375};
+    EXPECT(migraphx::verify_range(result_vector, gold));
+}
+
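For axis 0, GatherElements computes out[i][j] = data[indices[i][j]][j], which is where the gold vector in the verify test above comes from. A standalone check (my own, not part of the commit):

#include <cstdio>
#include <vector>

// Recompute the gold values of the verify test: data is a 3x4 row-major
// matrix, indices has shape {2, 3}, and the gather axis is 0.
int main()
{
    std::vector<float> data = {
        0.25, 0.75, 0.9375, 0.4375, 0.6875, 0.5625, -0.875, 0.1875, -0.125, 0.5, -0.9375, -0.0625};
    std::vector<int> ind = {2, 1, 2, 0, 1, 0}; // shape {2, 3}
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 3; ++j)
            std::printf("%g ", data[ind[i * 3 + j] * 4 + j]);
    // prints: -0.125 0.5625 -0.9375 0.25 0.5625 0.9375  (the gold vector)
    std::printf("\n");
    return 0;
}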
int main(int argc, const char* argv[]) { test::run(argc, argv); }