Unverified Commit e4ef64f4 authored by Charlie Lin, committed by GitHub

Variable input slice (#2039)

Allows slice to work with variable starts, ends, and axes inputs.
Outputs a dynamic shape even with a static-shape data input when the starts and ends are variable.
parent 7ed4954a
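
For illustration, a minimal sketch of what this change enables, modeled on the new reference tests added in this commit. The helper function name, literal values, and the include list are assumptions based on how the existing tests use the API, not part of the commit itself:

#include <numeric>
#include <vector>
#include <migraphx/argument.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/shape.hpp>

// Sketch: a slice whose starts/ends are runtime parameters instead of constant attributes.
void slice_with_variable_starts_and_ends()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape data_shape{migraphx::shape::int32_type, {2, 2, 3}};
    migraphx::shape idx_shape{migraphx::shape::int32_type, {1}};
    auto data   = mm->add_parameter("data", data_shape);
    auto starts = mm->add_parameter("starts", idx_shape);
    auto ends   = mm->add_parameter("ends", idx_shape);
    // axes stay a constant attribute; starts and ends are now variable inputs
    mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), data, starts, ends);
    p.compile(migraphx::make_target("ref"));

    // The slice window is only known at evaluation time.
    migraphx::parameter_map params;
    std::vector<int32_t> data_vals(2 * 2 * 3);
    std::iota(data_vals.begin(), data_vals.end(), 0);
    std::vector<int32_t> start_vals = {1};
    std::vector<int32_t> end_vals   = {3};
    params["data"]   = migraphx::argument(data_shape, data_vals.data());
    params["starts"] = migraphx::argument(idx_shape, start_vals.data());
    params["ends"]   = migraphx::argument(idx_shape, end_vals.data());
    auto result = p.eval(params).back(); // elements {1, 2, 4, 5, 7, 8, 10, 11}
    (void)result;
}

With a single data input (or constant starts/ends) behavior is unchanged; the variable-input paths are the new 3- and 4-argument forms handled in op::slice below.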
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
...
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
@@ -28,6 +28,7 @@
#include <migraphx/shape.hpp>
#include <cstring>
#include <vector>
#include <migraphx/op/normalize_attribute.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -42,6 +43,36 @@ struct select_dependent_type
template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type;

/**
 * Used to normalize variable input axes at model runtime.
 * Example: the axes input of the slice operator.
 *
 * \param axes the axes to normalize
 * \param input_shape shape of the input tensor
 * \param attr_val the normalize_axes attributes from the operator
 * \param prefix error message prefix
 */
std::vector<int64_t> normalize_axes(const std::vector<int64_t>& axes,
                                    const shape& input_shape,
                                    const value& attr_val,
                                    const std::string& prefix = "");

/**
 * Used to normalize variable input indices at model runtime.
 * Example: the starts and ends inputs of the slice operator.
 *
 * \param indices the indices to normalize
 * \param axes which axes the indices apply over
 * \param input_shape shape of the input tensor
 * \param attr_val the normalize_axes attributes from the operator
 * \param prefix error message prefix
 */
std::vector<int64_t> normalize_indices(const std::vector<int64_t>& indices,
                                       const std::vector<int64_t>& axes,
                                       const shape& input_shape,
                                       const value& attr_val,
                                       const std::string& prefix = "");

MIGRAPHX_EXPORT
bool normalize_attributes(operation& op, const shape& input_shape);
...
@@ -27,19 +27,34 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/**
 * Slice operator that accepts variable axes, starts, and ends.
 *
 * Attributes:
 *   axes: constant axes to slice over (optional)
 *   starts: constant slice starting indices (optional)
 *   ends: constant slice ending indices (optional)
 *
 * Parameters:
 *   data: the input tensor to slice (dynamic or static shape)
 *   input_starts: starting indices of the slice (optional, static shape)
 *   input_ends: ending indices of the slice (optional, static shape)
 *   input_axes: axes to slice over (optional, static shape)
 */
struct slice
{
    std::vector<int64_t> axes{};
    std::vector<int64_t> starts{};
    std::vector<int64_t> ends{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
@@ -48,8 +63,8 @@ struct slice
    }

    /**
     * Ensure that attribute vectors axes, starts, and ends are all the same size and values are
     * within limits.
     */
    value attributes() const
    {
@@ -70,100 +85,235 @@ struct slice
    std::string name() const { return "slice"; }

    /**
     * Computes the slice output shape dimensions for the given starts, ends, and axes.
     * Templated to also handle tensor views.
     * The types of [in_starts, in_ends] and [in_axes] may differ when in_axes is this
     * object's axes attribute. Assumes in_starts and in_ends are normalized and in_axes are valid.
     */
    template <class A, class B>
    std::vector<std::size_t>
    lens_calc(const std::vector<std::size_t>& lengths, A in_starts, A in_ends, B in_axes) const
    {
        auto new_lens = lengths;
        for(std::size_t i = 0; i < in_axes.size(); ++i)
        {
            auto axis      = in_axes[i];
            new_lens[axis] = in_ends[i] - in_starts[i];
        }
        return new_lens;
    }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this, true}.has(1, 3, 4);
        auto input_shape = inputs[0];
        if(inputs.size() == 1)
        {
            auto t = input_shape.type();
            // TODO: When support for dynamic shapes is added to normalize_attributes,
            // remove this restriction.
            if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
                   return not input_shape.dyn_dims()[axis].is_fixed();
               }))
            {
                MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
            }
            // Doesn't handle optimals
            if(input_shape.dynamic())
            {
                return shape{t,
                             lens_calc(input_shape.min_lens(), starts, ends, axes),
                             lens_calc(input_shape.max_lens(), starts, ends, axes),
                             {}};
            }
            else
            {
                return shape{
                    t, lens_calc(input_shape.lens(), starts, ends, axes), input_shape.strides()};
            }
        }
        else
        {
            // check that starts, ends, and optionally input_axes are all 1D, have the same
            // dimension, and are static
            check_shapes{inputs.begin() + 1,
                         inputs.end(),
                         std::string("SLICE: inputs (starts, ends, and input_axes)"),
                         false}
                .only_dims(1)
                .same_dims();
            auto dds = input_shape.to_dynamic().dyn_dims();
            if(inputs.size() == 3)
            {
                if(inputs[1].lens().at(0) != axes.size())
                {
                    MIGRAPHX_THROW("SLICE: inputs starts and ends do not have the same dimension "
                                   "as the axes attribute");
                }
                std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
                    dds.at(axis) = {0, dds.at(axis).max};
                });
            }
            else
            {
                // if axes is an input, then all the output dimensions could be 0 to the max value
                std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
                    return shape::dynamic_dimension{0, dd.max};
                });
            }
            return shape{input_shape.type(), dds};
        }
    }

    /**
     * Calculates the starting offset for the sliced tensor.
     * Used in compute when there is only the data input and all other information is in the
     * attributes.
     *
     * \param s static input shape
     */
    auto compute_offset(const shape& s) const
    {
        const std::vector<std::size_t>& lens    = s.lens();
        const std::vector<std::size_t>& strides = s.strides();
        auto offset = 0;
        if(not axes.empty())
        {
            for(std::size_t i = 0; i < axes.size(); i++)
            {
                auto axis = axes[i];
                offset += starts[i] * strides[axis];
            }
        }
        else
        {
            for(std::size_t axis = 0; axis < lens.size(); axis++)
            {
                offset += starts[axis] * strides[axis];
            }
        }
        return offset * s.type_size();
    }

    /**
     * Calculates the starting offset for the sliced tensor (for aliasing).
     * Used when the starts and/or the axes are inputs.
     *
     * \param s static input shape
     * \param input_starts starting indices of the slice
     * \param ax_vec axes to slice on
     */
    template <class IndView, class Axes>
    auto compute_offset(const shape& s, const IndView& input_starts, const Axes& ax_vec) const
    {
        auto ret = 0;
        for(std::size_t i = 0; i < ax_vec.size(); ++i)
        {
            auto axis = ax_vec[i];
            ret += input_starts[i] * s.strides().at(axis);
        }
        return ret * s.type_size();
    }

    std::unordered_map<std::string, std::vector<int64_t>>
    normalize_inputs(const shape& input_shape,
                     const std::vector<int64_t>& input_starts,
                     const std::vector<int64_t>& input_ends) const
    {
        auto attrs = this->attributes().at("normalize_axes");
        return {{"input_starts",
                 normalize_indices(input_starts,
                                   this->axes,
                                   input_shape,
                                   attrs.at("starts"),
                                   "Slice variable input_starts")},
                {"input_ends",
                 normalize_indices(input_ends,
                                   this->axes,
                                   input_shape,
                                   attrs.at("ends"),
                                   "Slice variable input_ends")}};
    }

    /**
     * Three-input version of normalize_inputs.
     * This one also checks that the input_axes are valid.
     */
    std::unordered_map<std::string, std::vector<int64_t>>
    normalize_inputs(shape input_shape,
                     const std::vector<int64_t>& input_starts,
                     const std::vector<int64_t>& input_ends,
                     const std::vector<int64_t>& input_axes) const
    {
        auto attrs = this->attributes().at("normalize_axes");
        auto norm_axes =
            normalize_axes(input_axes, input_shape, attrs.at("axes"), "Slice variable input_axes");
        return {{"input_starts",
                 normalize_indices(input_starts,
                                   norm_axes,
                                   input_shape,
                                   attrs.at("starts"),
                                   "Slice variable input_starts")},
                {"input_ends",
                 normalize_indices(input_ends,
                                   norm_axes,
                                   input_shape,
                                   attrs.at("ends"),
                                   "Slice variable input ends")},
                {"input_axes", norm_axes}};
    }

    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
    {
        auto input       = args[0];
        auto input_shape = input.get_shape();
        switch(args.size())
        {
        case 1: {
            std::size_t offset = compute_offset(input_shape);
            return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
        }
        case 3: {
            shape calc_shape;
            std::size_t offset = 0;
            visit_all(args[1], args[2])([&](auto input_starts, auto input_ends) {
                auto norm_inputs = normalize_inputs(input_shape,
                                                    input_starts.template to_vector<int64_t>(),
                                                    input_ends.template to_vector<int64_t>());
                offset     = compute_offset(input_shape, norm_inputs.at("input_starts"), this->axes);
                calc_shape = {input_shape.type(),
                              lens_calc(input_shape.lens(),
                                        norm_inputs.at("input_starts"),
                                        norm_inputs.at("input_ends"),
                                        this->axes),
                              input_shape.strides()};
            });
            return {calc_shape, [=] { return input.data() + offset; }};
        }
        case 4: {
            shape calc_shape;
            std::size_t offset = 0;
            visit_all(args[1], args[2], args[3])(
                [&](auto input_starts, auto input_ends, auto input_axes) {
                    auto norm_inputs = normalize_inputs(input_shape,
                                                        input_starts.template to_vector<int64_t>(),
                                                        input_ends.template to_vector<int64_t>(),
                                                        input_axes.template to_vector<int64_t>());
                    offset     = compute_offset(
                        input_shape, norm_inputs.at("input_starts"), norm_inputs.at("input_axes"));
                    calc_shape = shape{input_shape.type(),
                                       lens_calc(input_shape.lens(),
                                                 norm_inputs.at("input_starts"),
                                                 norm_inputs.at("input_ends"),
                                                 norm_inputs.at("input_axes")),
                                       input_shape.strides()};
                });
            return {calc_shape, [=] { return input.data() + offset; }};
        }
        default: {
            // Should never get here; covering in case some code change occurs
            MIGRAPHX_THROW("SLICE: invalid number of inputs");
        }
        }
    }

    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
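
To make the shape behavior concrete: with variable starts/ends the operator reports a dynamic output shape even when the data input is static, as the commit message notes and the new op_shape tests later in this commit verify. A hedged sketch (same assumed headers as the example above; dimensions taken from those tests):

// Static {3, 4, 4} data sliced on axes {1, 2} with runtime starts/ends: each sliced
// dimension can end up anywhere in [0, 4], so the instruction's shape becomes dynamic.
migraphx::program p;
auto* mm    = p.get_main_module();
auto data   = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 4}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int64_type, {2}});
auto ends   = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int64_type, {2}});
auto sl = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1, 2}}}), data, starts, ends);
// sl->get_shape() is {float_type, {{3, 3}, {0, 4}, {0, 4}}}: axis 0 stays fixed,
// while the sliced axes become {0, 4} dynamic dimensions.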
...
@@ -49,6 +49,10 @@ auto tune_attribute(const std::vector<int64_t>& vec,
                    Message m)
{
    std::vector<int64_t> result(vec);
    if(result.empty())
    {
        return result;
    }
    int64_t n_rank = input_shape.ndim();
    std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
    if(contains(vec_attrs, op::normalize_attribute::use_output))
@@ -251,5 +255,22 @@ bool normalize_attributes(operation& op, const shape& input_shape)
    return tuned;
}

std::vector<int64_t> normalize_axes(const std::vector<int64_t>& axes,
                                    const shape& input_shape,
                                    const value& attr_val,
                                    const std::string& prefix)
{
    return tune_attribute(axes, {}, attr_val, input_shape, [&] { return prefix; });
}

std::vector<int64_t> normalize_indices(const std::vector<int64_t>& indices,
                                       const std::vector<int64_t>& axes,
                                       const shape& input_shape,
                                       const value& attr_val,
                                       const std::string& prefix)
{
    return tune_attribute(indices, axes, attr_val, input_shape, [&] { return prefix; });
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -34,16 +34,63 @@ namespace onnx {
struct parse_slice : op_parser<parse_slice>
{
    std::vector<op_desc> operators() const { return {{"Slice"}}; }

    struct slice_desc
    {
        op::slice op;
        std::vector<instruction_ref> op_args;
        std::vector<int64_t> steps;
        std::vector<int64_t> raxes;

        // If the argument evaluates to a constant, return its values; otherwise record it
        // as a variable input to the slice operator and return an empty vector.
        std::vector<int64_t> insert(instruction_ref arg)
        {
            std::vector<int64_t> result;
            migraphx::argument arg_value = arg->eval();
            if(arg_value.empty())
            {
                op_args.insert(op_args.begin(), arg);
            }
            else
            {
                arg_value.visit([&](auto s) { result.assign(s.begin(), s.end()); });
            }
            return result;
        }
    };

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        auto sd  = construct_slice_desc(parser, info, args);
        auto ins = info.add_instruction(sd.op, sd.op_args);
        if(not sd.raxes.empty())
        {
            ins = info.add_instruction(make_op("reverse", {{"axes", sd.raxes}}), ins);
        }
        // If any steps are other than the default 1, add a "step" op
        if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return std::abs(s) != 1; }))
        {
            std::vector<int64_t> nsteps;
            std::transform(sd.steps.begin(),
                           sd.steps.end(),
                           std::back_inserter(nsteps),
                           [](auto s) { return std::abs(s); });
            return ins = info.add_instruction(
                       make_op("step", {{"axes", sd.op.axes}, {"steps", nsteps}}), ins);
        }
        else
            return ins;
    }

    slice_desc construct_slice_desc(const onnx_parser& parser,
                                    onnx_parser::node_info info,
                                    std::vector<instruction_ref> args) const
    {
        slice_desc sd;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice.
@@ -51,89 +98,73 @@ struct parse_slice : op_parser<parse_slice>
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            step_arg.visit([&](auto s) { sd.steps.assign(s.begin(), s.end()); });
        }

        if(args.size() >= 4)
        {
            sd.op.axes = sd.insert(args.at(3));
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parser.parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.axes)); });
        }

        if(args.size() >= 3)
        {
            sd.op.ends = sd.insert(args.at(2));
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parser.parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.ends)); });
        }

        if(args.size() >= 2)
        {
            sd.op.starts = sd.insert(args.at(1));
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parser.parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.starts)); });
        }

        // data input argument
        sd.insert(args.at(0));

        // If axes arg is not given, the default is all of them.
        if(sd.op.axes.empty() and sd.op_args.size() < 3)
        {
            std::vector<int64_t> axes(args[0]->get_shape().ndim());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            sd.op.axes = axes;
        }

        if(not sd.steps.empty())
        {
            if(sd.op.starts.empty() or sd.op.ends.empty())
                MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported");
            if(sd.op.axes.empty())
                MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
        }

        assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());

        // If any axes have negative step, prepare to add a "reverse" op
        for(auto i : range(sd.steps.size()))
        {
            if(sd.steps[i] >= 0)
                continue;
            sd.op.starts[i] += 1;
            if(sd.op.starts[i] == 0)
                sd.op.starts[i] = INT_MAX;
            sd.op.ends[i] += 1;
            sd.raxes.push_back(sd.op.axes[i]);
            std::swap(sd.op.starts[i], sd.op.ends[i]);
        }

        return sd;
    }
};
...
@@ -6746,6 +6746,92 @@ def slice_max_end_test():
    return ([node], [x], [y])


@onnx_test()
def slice_var_input_static0():
    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
    starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2])
    ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])

    node = onnx.helper.make_node('Slice',
                                 inputs=['data', 'starts', 'ends'],
                                 axes=[0, 1],
                                 outputs=['output'])

    return ([node], [data, starts, ends], [output])


@onnx_test()
def slice_var_input_static1():
    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
    starts = helper.make_tensor_value_info('starts', TensorProto.INT64, [2])
    ends = helper.make_tensor_value_info('ends', TensorProto.INT64, [2])
    axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [2])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])

    node = onnx.helper.make_node('Slice',
                                 inputs=['data', 'starts', 'ends', 'axes'],
                                 outputs=['output'])

    return ([node], [data, starts, ends, axes], [output])


@onnx_test()
def slice_var_input_dyn0():
    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
    starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2])
    ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])

    node = onnx.helper.make_node('Slice',
                                 inputs=['data', 'starts', 'ends'],
                                 axes=[0, 1],
                                 outputs=['output'])

    return ([node], [data, starts, ends], [output])


@onnx_test()
def slice_var_input_dyn1():
    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
    starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2])
    ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2])
    axes = helper.make_tensor_value_info('axes', TensorProto.INT32, [2])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])

    node = onnx.helper.make_node('Slice',
                                 inputs=['data', 'starts', 'ends', 'axes'],
                                 outputs=['output'])

    return ([node], [data, starts, ends, axes], [output])


@onnx_test()
def slice_var_input_steps_error():
    step = np.array([2, 1])
    step_tensor = helper.make_tensor(name="step",
                                     data_type=TensorProto.INT32,
                                     dims=step.shape,
                                     vals=step.astype(int))
    arg_step = helper.make_node("Constant",
                                inputs=[],
                                outputs=['arg_step'],
                                value=step_tensor)

    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
    starts = helper.make_tensor_value_info('starts', TensorProto.FLOAT, [2])
    ends = helper.make_tensor_value_info('ends', TensorProto.FLOAT, [2])
    axes = helper.make_tensor_value_info('axes', TensorProto.FLOAT, [2])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])

    node = onnx.helper.make_node(
        'Slice',
        inputs=['data', 'starts', 'ends', 'axes', 'arg_step'],
        outputs=['output'])

    return ([arg_step, node], [data, starts, ends, axes], [output])


@onnx_test()
def softmax_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3])
...
@@ -6426,6 +6426,74 @@ TEST_CASE(slice_max_end_test)
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_static0)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 2}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int32_type, {2}});
auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int32_type, {2}});
mm->add_instruction(migraphx::make_op("slice", {{"axes", {0, 1}}}), data, starts, ends);
auto prog = optimize_onnx("slice_var_input_static0.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_static1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 2}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int64_type, {2}});
auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int64_type, {2}});
auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {2}});
mm->add_instruction(migraphx::make_op("slice"), data, starts, ends, axes);
auto prog = optimize_onnx("slice_var_input_static1.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_dyn0)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {{3, 8}, {2, 2}}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int32_type, {2}});
auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int32_type, {2}});
auto ret =
mm->add_instruction(migraphx::make_op("slice", {{"axes", {0, 1}}}), data, starts, ends);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {3, 8};
auto prog = parse_onnx("slice_var_input_dyn0.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_dyn1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {{3, 8}, {2, 2}}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int32_type, {2}});
auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int32_type, {2}});
auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int32_type, {2}});
auto ret = mm->add_instruction(migraphx::make_op("slice"), data, starts, ends, axes);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {3, 8};
auto prog = parse_onnx("slice_var_input_dyn1.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_steps_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); }));
}
TEST_CASE(softmax_test)
{
migraphx::program p;
...

[Binary ONNX protobuf test files added for the new cases: slice_var_input_static1.onnx and slice_var_input_steps_error.onnx; contents are not human-readable.]
@@ -2822,7 +2822,7 @@ TEST_CASE(select_module_dyn)
input);
}

TEST_CASE(slice_static_shape)
{
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
@@ -2840,6 +2840,67 @@ TEST_CASE(slice_static_shape)
input);
}
TEST_CASE(slice_var_inputs_static_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"axes", {1, 2}}}),
input,
starts,
ends);
}
TEST_CASE(slice_var_inputs_static_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice"),
input,
starts,
ends,
axes);
}
TEST_CASE(slice_var_inputs_static_error0)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {3}};
throws_shape(migraphx::make_op("slice"), input, starts, ends, axes);
}
TEST_CASE(slice_var_inputs_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"axes", {1, 2}}}),
input,
starts,
ends);
}
TEST_CASE(slice_var_inputs_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 4}, {0, 4}}},
migraphx::make_op("slice"),
input,
starts,
ends,
axes);
}
TEST_CASE(slice_dyn_shape0)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
@@ -2870,7 +2931,7 @@ TEST_CASE(slice_dyn_shape2)
TEST_CASE(slice_dyn_shape3)
{
// TODO: When non-fixed dimension slicing is allowed, Slice to a size smaller than min.
// Until then, this action is an error.
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 8}, {2, 3}}};
throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
...
@@ -8153,6 +8153,115 @@ TEST_CASE(slice_test)
}
}
TEST_CASE(slice_var_inputs_static0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {1};
std::vector<int32_t> end_data = {3};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {-2};
std::vector<int32_t> end_data = {2831};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::float_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int64_type, {3}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice"), l0, starts, ends, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int64_t> start_data = {0, 0, 0};
std::vector<int64_t> end_data = {2, 2, 2};
std::vector<int64_t> axes_data = {0, 1, 2};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<float> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> start_data = {1};
std::vector<int> end_data = {3};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_dyn_test0)
{
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
...
@@ -37,7 +37,6 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/serialize.hpp>
@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
    mm->add_literal(migraphx::literal{s0, {1, 0}});
    mm->add_literal(migraphx::literal{s0, {2, -1}});
    mm->add_instruction(
        migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0);
    auto prog = optimize_tf("slice_test.pb", false);
    EXPECT(p == prog);
@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
    auto l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
    auto l2 = mm->add_instruction(
        migraphx::make_op(
            "slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}),
        l1);
    auto shrink_axis = 1;
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
    auto prog = optimize_tf("stridedslice_test.pb", true);
@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});

    // add literals for starts, ends, and strides in tf (NHWC format)
    mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
                    std::vector<int>{0, 1, 1, 0});
@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
    auto l1 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
    auto l2 = mm->add_instruction(
        migraphx::make_op(
            "slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}),
        l1);
    auto l3 =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
    mm->add_return({l3});
...