Unverified Commit bf548547 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Parse slice default steps fix (#2306)

Fixes an issue that comes up for variable input slice with steps set manually in ONNX to default 1's.
parent b8202d61
...@@ -144,16 +144,15 @@ struct parse_slice : op_parser<parse_slice> ...@@ -144,16 +144,15 @@ struct parse_slice : op_parser<parse_slice>
sd.op.axes = axes; sd.op.axes = axes;
} }
if(not sd.steps.empty()) if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return s != 1; }))
{ {
if(sd.op.starts.empty() or sd.op.ends.empty()) if(sd.op.starts.empty() or sd.op.ends.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported"); MIGRAPHX_THROW(
"PARSE_SLICE: steps and variable starts and/or ends is not supported");
if(sd.op.axes.empty()) if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported"); MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
} }
assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
// If any axes have negative step, prepare to add a "reverse" op // If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(sd.steps.size())) for(auto i : range(sd.steps.size()))
{ {
......
...@@ -8006,6 +8006,32 @@ def slice_var_input_dyn1(): ...@@ -8006,6 +8006,32 @@ def slice_var_input_dyn1():
return ([node], [data, starts, ends, axes], [output]) return ([node], [data, starts, ends, axes], [output])
@onnx_test()
def slice_var_input_default_steps():
step = np.array([1, 1])
step_tensor = helper.make_tensor(name="step",
data_type=TensorProto.INT64,
dims=step.shape,
vals=step.astype(int))
arg_step = helper.make_node("Constant",
inputs=[],
outputs=['arg_step'],
value=step_tensor)
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
starts = helper.make_tensor_value_info('starts', TensorProto.INT64, [2])
ends = helper.make_tensor_value_info('ends', TensorProto.INT64, [2])
axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
node = onnx.helper.make_node(
'Slice',
inputs=['data', 'starts', 'ends', 'axes', 'arg_step'],
outputs=['output'])
return ([arg_step, node], [data, starts, ends, axes], [output])
@onnx_test() @onnx_test()
def slice_var_input_steps_error(): def slice_var_input_steps_error():
step = np.array([2, 1]) step = np.array([2, 1])
...@@ -8019,9 +8045,9 @@ def slice_var_input_steps_error(): ...@@ -8019,9 +8045,9 @@ def slice_var_input_steps_error():
value=step_tensor) value=step_tensor)
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2]) data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
starts = helper.make_tensor_value_info('starts', TensorProto.FLOAT, [2]) starts = helper.make_tensor_value_info('starts', TensorProto.INT64, [2])
ends = helper.make_tensor_value_info('ends', TensorProto.FLOAT, [2]) ends = helper.make_tensor_value_info('ends', TensorProto.INT64, [2])
axes = helper.make_tensor_value_info('axes', TensorProto.FLOAT, [2]) axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2]) output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
node = onnx.helper.make_node( node = onnx.helper.make_node(
......
...@@ -7653,6 +7653,25 @@ TEST_CASE(slice_var_input_dyn1) ...@@ -7653,6 +7653,25 @@ TEST_CASE(slice_var_input_dyn1)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(slice_var_input_default_steps)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {{3, 8}, {2, 2}}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int64_type, {2}});
auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int64_type, {2}});
auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {2}});
mm->add_literal({{migraphx::shape::int64_type, {2}}, {1, 1}});
auto ret = mm->add_instruction(migraphx::make_op("slice"), data, starts, ends, axes);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {3, 8};
auto prog = parse_onnx("slice_var_input_default_steps.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_steps_error) TEST_CASE(slice_var_input_steps_error)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); }));
......
reshape_variable_input_test0:q

0
12"Reshapereshape_variable_input_test0Z
0



Z
1

b
2


B
\ No newline at end of file
...@@ -226,7 +226,6 @@ TEST_CASE(reshape_2in_test1) ...@@ -226,7 +226,6 @@ TEST_CASE(reshape_2in_test1)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reshape_2in_elements_runtime_error) TEST_CASE(reshape_2in_elements_runtime_error)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment