Unverified Commit ccf491a4 authored by kahmed10, committed by GitHub

Slice axes fix (#498)

* fix axes bug, reorder and add test

* formatting

* add missing test file
parent 3f3885ac
@@ -814,6 +814,13 @@ struct onnx_parser
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }
        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }
        return prog.add_instruction(op, args[0]);
    }
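Per the ONNX specification, a Slice whose axes attribute is omitted applies starts/ends to every dimension in order, i.e. axes defaults to [0, 1, ..., rank-1]; the std::iota block added above fills in that default. A rough illustration of the same defaulting logic in plain numpy (the helper name here is made up, not MIGraphX code):

import numpy as np

def slice_with_default_axes(x, starts, ends, axes=None):
    # When axes is omitted or empty, default to every dimension in order,
    # mirroring the std::iota defaulting added to the parser above.
    if not axes:
        axes = list(range(x.ndim))
    index = [slice(None)] * x.ndim
    for axis, start, end in zip(axes, starts, ends):
        index[axis] = slice(start, end)
    return x[tuple(index)]

x = np.arange(6, dtype=np.float32).reshape(3, 2)
# Same starts/ends as slice_test in gen_onnx.py, with axes left out.
print(slice_with_default_axes(x, starts=[1, 0], ends=[2, 2]).shape)  # (1, 2)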
@@ -1797,6 +1797,53 @@ def sinh_test():
    return ([node], [x], [y])


@onnx_test
def slice_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2])

    node = onnx.helper.make_node('Slice',
                                 inputs=['0'],
                                 axes=[0, 1],
                                 starts=[1, 0],
                                 ends=[2, 2],
                                 outputs=['1'])

    return ([node], [x], [y])


@onnx_test
def slice_3arg_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5])
    start = np.array([0, 0])
    start_tensor = helper.make_tensor(name="start",
                                      data_type=TensorProto.INT32,
                                      dims=start.shape,
                                      vals=start.astype(int))

    arg_start = helper.make_node("Constant",
                                 inputs=[],
                                 outputs=['arg_start'],
                                 value=start_tensor)

    end = np.array([2, 5])
    end_tensor = helper.make_tensor(name="end",
                                    data_type=TensorProto.INT32,
                                    dims=end.shape,
                                    vals=end.astype(int))

    arg_end = helper.make_node("Constant",
                               inputs=[],
                               outputs=['arg_end'],
                               value=end_tensor)

    node = onnx.helper.make_node('Slice',
                                 inputs=['0', 'arg_start', 'arg_end'],
                                 outputs=['1'])

    return ([arg_start, arg_end, node], [x], [y])


@onnx_test
def slice_5arg_test():
    step = np.array([1, 1])
@@ -1865,21 +1912,6 @@ def slice_max_end_test():
    return ([node], [x], [y])


@onnx_test
def slice_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2])

    node = onnx.helper.make_node('Slice',
                                 inputs=['0'],
                                 axes=[0, 1],
                                 starts=[1, 0],
                                 ends=[2, 2],
                                 outputs=['1'])

    return ([node], [x], [y])


@onnx_test
def softmax_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3])
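The @onnx_test functions return a ([nodes], [inputs], [outputs]) tuple; presumably the decorator in gen_onnx.py assembles that into the .onnx file the C++ tests load (hence the "add missing test file" note in the commit message). As an illustration only, a hand-rolled sketch of building and checking an equivalent slice_test model with the stock onnx helper API:

import onnx
from onnx import TensorProto, helper

x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2])
# Attribute-style Slice, so the model must stay below opset 10.
node = helper.make_node('Slice',
                        inputs=['0'],
                        outputs=['1'],
                        axes=[0, 1],
                        starts=[1, 0],
                        ends=[2, 2])

graph = helper.make_graph([node], 'slice_test', [x], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 9)])
onnx.checker.check_model(model)      # validates the node against the Slice-1 schema
onnx.save(model, 'slice_test.onnx')  # same file name the C++ test expects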
@@ -1387,6 +1387,30 @@ TEST_CASE(sinh_test)
    EXPECT(p == prog);
}

TEST_CASE(slice_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
    p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0);
    auto prog = optimize_onnx("slice_test.onnx");

    EXPECT(p == prog);
}

TEST_CASE(slice_3arg_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
    p.add_literal({{migraphx::shape::int32_type, {2}}, {0, 0}});
    p.add_literal({{migraphx::shape::int32_type, {2}}, {2, 5}});
    auto ret = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 5}}, l0);
    p.add_return({ret});

    auto prog = migraphx::parse_onnx("slice_3arg_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(slice_5arg_test)
{
    migraphx::program p;

@@ -1413,16 +1437,6 @@ TEST_CASE(slice_max_end_test)
    EXPECT(p == prog);
}

TEST_CASE(slice_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
    p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0);
    auto prog = optimize_onnx("slice_test.onnx");

    EXPECT(p == prog);
}

TEST_CASE(softmax_test)
{
    migraphx::program p;
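For a quick sanity check of the shapes the expected programs above encode: slice_test keeps rows 1:2 and columns 0:2 of a 3x2 input, and slice_3arg_test keeps rows 0:2 and columns 0:5 of a 5x5 input. A minimal numpy confirmation (illustrative only):

import numpy as np

a = np.zeros((3, 2), dtype=np.float32)
b = np.zeros((5, 5), dtype=np.float32)

print(a[1:2, 0:2].shape)  # (1, 2)  -> slice_test / TEST_CASE(slice_test)
print(b[0:2, 0:5].shape)  # (2, 5)  -> slice_3arg_test / TEST_CASE(slice_3arg_test)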