Unverified Commit 8fa33f1a authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Improve parsing slice (#467)



* refine slice implementation

* clang format

* fix cppcheck error

* clang format

* fix a bug in parsing slice

* add unit tests for slice operator

* clang format

* fix cppcheck error

* add the numpy package in the Dockerfile

* add missing onnx file

* fix cppcheck error

* fix cppcheck error

* add one more unit test related to slice and split

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 0325c1a4
...@@ -52,7 +52,7 @@ ENV LC_ALL=C.UTF-8 ...@@ -52,7 +52,7 @@ ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8 ENV LANG=C.UTF-8
# Install cget # Install cget
RUN pip3 install cget RUN pip3 install cget && pip3 install numpy
# Install rclone # Install rclone
RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SLICE_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP #define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#include <array>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <vector>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -30,6 +28,54 @@ struct slice ...@@ -30,6 +28,54 @@ struct slice
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
/// Normalize the slice attributes in place against the input dimensions.
///
/// - Every axis must lie in [-rank, rank); negative axes are wrapped by adding rank.
/// - Every start/end is clamped to [-dim, dim] of its axis, then negative values
///   are wrapped by adding dim.
///
/// @param tuned_axes   axes to slice; rewritten to non-negative axis indices
/// @param tuned_starts start index per entry of tuned_axes; rewritten in place
/// @param tuned_ends   end index per entry of tuned_axes; rewritten in place
/// @param lens         dimensions of the tensor being sliced
///
/// Raises through MIGRAPHX_THROW when the three attribute vectors differ in
/// size, when an axis is out of range, or when any end is smaller than the
/// start of the same axis.
void tune_attributes(std::vector<int64_t>& tuned_axes,
                     std::vector<int64_t>& tuned_starts,
                     std::vector<int64_t>& tuned_ends,
                     const std::vector<std::size_t>& lens) const
{
    // The element-wise transforms below walk starts/ends in lock-step with the
    // per-axis lengths, so a size mismatch would read past the end of a vector.
    if(tuned_starts.size() != tuned_axes.size() || tuned_axes.size() != tuned_ends.size())
    {
        MIGRAPHX_THROW("SLICE: inconsistent sizes");
    }

    // tune axes
    const int64_t n_rank = static_cast<int64_t>(lens.size());
    const bool axes_in_range =
        std::all_of(tuned_axes.begin(), tuned_axes.end(), [=](int64_t axis) {
            return axis < n_rank and axis >= -n_rank;
        });
    if(!axes_in_range)
    {
        MIGRAPHX_THROW("SLICE: input axis " + to_string_range(tuned_axes) + " out of range");
    }
    std::transform(tuned_axes.begin(), tuned_axes.end(), tuned_axes.begin(), [=](int64_t axis) {
        return (axis < 0) ? (axis + n_rank) : axis;
    });

    // Length of the sliced dimension for each entry of tuned_axes.
    std::vector<int64_t> axis_lens(tuned_axes.size());
    std::transform(tuned_axes.begin(), tuned_axes.end(), axis_lens.begin(), [&](int64_t axis) {
        return static_cast<int64_t>(lens[axis]);
    });

    // Clamp an index to [-dim, dim], then map negative values to dim + index.
    const auto normalize_index = [](int64_t index, int64_t dim) {
        const int64_t clamped = std::max(-dim, std::min(index, dim));
        return (clamped < 0) ? (clamped + dim) : clamped;
    };

    // tune starts
    std::transform(tuned_starts.begin(),
                   tuned_starts.end(),
                   axis_lens.begin(),
                   tuned_starts.begin(),
                   normalize_index);

    // tune ends
    std::transform(
        tuned_ends.begin(), tuned_ends.end(), axis_lens.begin(), tuned_ends.begin(), normalize_index);

    // Compare axis by axis: operator>= on std::vector is lexicographic and
    // would accept e.g. starts = {0, 5}, ends = {1, 2}.
    const bool ends_after_starts =
        std::equal(tuned_ends.begin(),
                   tuned_ends.end(),
                   tuned_starts.begin(),
                   [](int64_t end, int64_t start) { return end >= start; });
    if(!ends_after_starts)
    {
        MIGRAPHX_THROW("SLICE: starts and ends does not match");
    }
}
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{ {
int64_t r = std::min(index, static_cast<int64_t>(lens[axis])); int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
...@@ -40,22 +86,27 @@ struct slice ...@@ -40,22 +86,27 @@ struct slice
auto compute_offset(const shape& s) const auto compute_offset(const shape& s) const
{ {
std::vector<int64_t> tuned_axes = axes;
std::vector<int64_t> tuned_starts = starts;
std::vector<int64_t> tuned_ends = ends;
const std::vector<std::size_t>& lens = s.lens(); const std::vector<std::size_t>& lens = s.lens();
tune_attributes(tuned_axes, tuned_starts, tuned_ends, lens);
const std::vector<std::size_t>& strides = s.strides(); const std::vector<std::size_t>& strides = s.strides();
auto offset = 0; auto offset = 0;
if(!axes.empty()) if(!tuned_axes.empty())
{ {
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < tuned_axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = tuned_axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis]; offset += fix_index(lens, axis, tuned_starts[i]) * strides[axis];
} }
} }
else else
{ {
for(std::size_t axis = 0; axis < lens.size(); axis++) for(std::size_t axis = 0; axis < lens.size(); axis++)
{ {
offset += fix_index(lens, axis, starts[axis]) * strides[axis]; offset += fix_index(lens, axis, tuned_starts[axis]) * strides[axis];
} }
} }
return offset; return offset;
...@@ -69,14 +120,19 @@ struct slice ...@@ -69,14 +120,19 @@ struct slice
const auto& old_strides = input_shape.strides(); const auto& old_strides = input_shape.strides();
if(starts.size() != axes.size() || axes.size() != ends.size()) if(starts.size() != axes.size() || axes.size() != ends.size())
{ {
MIGRAPHX_THROW("inconsistent sizes"); MIGRAPHX_THROW("SLICE: inconsistent sizes");
} }
std::vector<int64_t> tuned_axes = axes;
std::vector<int64_t> tuned_starts = starts;
std::vector<int64_t> tuned_ends = ends;
tune_attributes(tuned_axes, tuned_starts, tuned_ends, old_lens);
std::vector<std::size_t> new_lens = old_lens; std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < tuned_axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = tuned_axes[i];
new_lens[axis] = new_lens[axis] = fix_index(old_lens, axis, tuned_ends[i]) -
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]); fix_index(old_lens, axis, tuned_starts[i]);
} }
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
......
...@@ -776,28 +776,56 @@ struct onnx_parser ...@@ -776,28 +776,56 @@ struct onnx_parser
parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args) parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
op::slice op; op::slice op;
std::vector<size_t> dims = args[0]->get_shape().lens();
size_t num_dims = dims.size(); // slice can have up to 5 inputs, we first check the 5th one
if(contains(info.attributes, "axes")) // to decide whether MIGRAPHX can handle this slice
if(args.size() == 5)
{
migraphx::argument step_arg = args.back()->eval();
check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
std::vector<int> steps;
step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
{
MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
}
}
if(args.size() >= 4)
{
migraphx::argument axes_arg = args.at(3)->eval();
check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
}
else if(contains(info.attributes, "axes"))
{ {
literal s = parse_value(info.attributes.at("axes")); literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
} }
else
if(args.size() >= 3)
{ {
op.axes = std::vector<int64_t>(num_dims); migraphx::argument end_arg = args.at(2)->eval();
std::iota(op.axes.begin(), op.axes.end(), 0); check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
} }
else if(contains(info.attributes, "ends"))
if(contains(info.attributes, "ends"))
{ {
op.ends = get_indices(info.attributes.at("ends")); op.ends = get_indices(info.attributes.at("ends"));
} }
if(contains(info.attributes, "starts"))
if(args.size() >= 2)
{
migraphx::argument start_arg = args.at(1)->eval();
check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
}
else if(contains(info.attributes, "starts"))
{ {
literal s = parse_value(info.attributes.at("starts")); literal s = parse_value(info.attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
} }
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
......
...@@ -1783,6 +1783,59 @@ def sinh_test(): ...@@ -1783,6 +1783,59 @@ def sinh_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def slice_5arg_test():
    """Build a Slice node whose starts, ends, axes and steps all come from Constant inputs.

    The input is 5x5. Axis -1 is sliced over [-5, -1) (4 elements) and axis -2
    over [-3, -1) (2 elements), so the declared output shape is [2, 4].
    """
    def int32_constant(name, values):
        # Constant node named 'arg_<name>' wrapping an INT32 tensor called <name>.
        data = np.array(values)
        tensor = helper.make_tensor(name=name,
                                    data_type=TensorProto.INT32,
                                    dims=data.shape,
                                    vals=data.astype(int))
        return helper.make_node("Constant",
                                inputs=[],
                                outputs=['arg_' + name],
                                value=tensor)

    arg_step = int32_constant("step", [1, 1])
    arg_axis = int32_constant("axis", [-1, -2])
    arg_end = int32_constant("end", [-1, -1])
    arg_start = int32_constant("start", [-5, -3])

    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5])
    # starts[i]/ends[i] apply to axes[i]: dim 1 keeps 4 elements, dim 0 keeps 2.
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 4])

    node = onnx.helper.make_node(
        'Slice',
        inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'],
        outputs=['1'])

    return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y])
@onnx_test @onnx_test
def slice_test(): def slice_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2])
...@@ -1808,6 +1861,23 @@ def softmax_test(): ...@@ -1808,6 +1861,23 @@ def softmax_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def split_minus_axis_test():
    """Build a Split node over a 10x15 input using axis=-1 and three 10x5 outputs."""
    output_names = ['y1', 'y2', 'y3']

    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
    outputs = [
        helper.make_tensor_value_info(out_name, TensorProto.FLOAT, [10, 5])
        for out_name in output_names
    ]

    node = onnx.helper.make_node(
        'Split',
        inputs=['x'],
        outputs=list(output_names),
        axis=-1,
    )

    return ([node], [x], outputs)
@onnx_test @onnx_test
def split_test(): def split_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......
...@@ -1376,6 +1376,22 @@ TEST_CASE(sinh_test) ...@@ -1376,6 +1376,22 @@ TEST_CASE(sinh_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
// Builds the program expected from slice_5arg_test.onnx by hand and checks
// that parse_onnx produces an equal program.
TEST_CASE(slice_5arg_test)
{
migraphx::program p;
// 5x5 float input named "0".
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::int32_type == migraphx::shape::int32_type ? migraphx::shape::float_type : migraphx::shape::float_type, {5, 5}});
// Four 2-element int32 literals are expected in the program ahead of the slice.
p.add_literal({{migraphx::shape::int32_type, {2}}, {-5, -3}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -1}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {1, 1}});
// The slice keeps the negative values unchanged; the brace groups are
// presumably {axes, starts, ends} in op::slice member order -- confirm
// against the op::slice definition.
auto ret = p.add_instruction(migraphx::op::slice{{-1, -2}, {-5, -3}, {-1, -1}}, l0);
p.add_return({ret});
auto prog = migraphx::parse_onnx("slice_5arg_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_test) TEST_CASE(slice_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -1396,6 +1412,20 @@ TEST_CASE(softmax_test) ...@@ -1396,6 +1412,20 @@ TEST_CASE(softmax_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
// Builds the program expected from split_minus_axis_test.onnx by hand and
// checks that parse_onnx produces an equal program.
TEST_CASE(split_minus_axis_test)
{
migraphx::program p;
// 10x15 float input named "x".
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
// One slice per expected output, all on the same input with {-1} as the
// first brace group and consecutive 5-wide ranges (0-5, 5-10, 10-15) in the
// other two; presumably {axes, starts, ends} -- confirm against op::slice.
auto r1 = p.add_instruction(migraphx::op::slice{{-1}, {0}, {5}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{-1}, {5}, {10}}, input);
auto r3 = p.add_instruction(migraphx::op::slice{{-1}, {10}, {15}}, input);
p.add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("split_minus_axis_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(split_test) TEST_CASE(split_test)
{ {
migraphx::program p; migraphx::program p;
......
slice_5arg_test:
0arg_step"Constant*
value**Bstep
Barg_axis"Constant*,
value* *Baxis
@arg_end"Constant*+
value**Bend
D arg_start"Constant*-
value*!*Bstart
5
0
arg_start
arg_end
arg_axis
arg_step1"Sliceslice_5arg_testZ
0


b
1


B
\ No newline at end of file
split_minus_axis_test:
,
xy1y2y3"Split*
axissplit_minus_axis_testZ
x


b
y1


b
y2


b
y3


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment