Commit fa62ad24 authored by Paul's avatar Paul
Browse files

Fix slice error when data input is constant

parent 24d9ecbe
...@@ -44,6 +44,11 @@ struct parse_slice : op_parser<parse_slice> ...@@ -44,6 +44,11 @@ struct parse_slice : op_parser<parse_slice>
std::vector<int64_t> steps; std::vector<int64_t> steps;
std::vector<int64_t> raxes; std::vector<int64_t> raxes;
void always_insert(instruction_ref arg)
{
op_args.insert(op_args.begin(), arg);
}
std::vector<int64_t> insert(instruction_ref arg) std::vector<int64_t> insert(instruction_ref arg)
{ {
std::vector<int64_t> result; std::vector<int64_t> result;
...@@ -132,7 +137,7 @@ struct parse_slice : op_parser<parse_slice> ...@@ -132,7 +137,7 @@ struct parse_slice : op_parser<parse_slice>
} }
// data input argument // data input argument
sd.insert(args.at(0)); sd.always_insert(args.at(0));
// If axes arg is not given, the default is all of them. // If axes arg is not given, the default is all of them.
if(sd.op.axes.empty() and sd.op_args.size() < 3) if(sd.op.axes.empty() and sd.op_args.size() < 3)
......
...@@ -6413,6 +6413,28 @@ def slice_test(): ...@@ -6413,6 +6413,28 @@ def slice_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def slice_constant_test():
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2])
x_tensor = helper.make_tensor(name='x_tensor',
data_type=TensorProto.FLOAT,
dims=[3, 2],
vals=[0, 1, 2, 3, 4, 5])
x = onnx.helper.make_node('Constant',
inputs=[],
outputs=['x'],
value=x_tensor)
node = onnx.helper.make_node('Slice',
inputs=['x'],
axes=[0, 1],
starts=[1, 0],
ends=[2, 2],
outputs=['1'])
return ([x, node], [], [y])
@onnx_test() @onnx_test()
def slice_dyn_test(): def slice_dyn_test():
......
...@@ -6294,6 +6294,18 @@ TEST_CASE(slice_test) ...@@ -6294,6 +6294,18 @@ TEST_CASE(slice_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(slice_constant_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3, 2}}, {0, 1, 2, 3, 4, 5}});
mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {2, 2}}}), l0);
auto prog = optimize_onnx("slice_constant_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_dyn_test) TEST_CASE(slice_dyn_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment