Unverified Commit c4e53a33 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Add parsing split operator (#460)



* change parse operator function signature

* clang format

* add parsing the split operator

* clang format

* add parsing split operator

* make squeeze/unsqueeze inputs to standard shape

* add unit tests for the split operator

* clang format

* fix cppcheck error

* clang format

* update tests for multiple program outputs

* clang format

* fix cppcheck error

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 330224aa
This diff is collapsed.
......@@ -1771,6 +1771,37 @@ def softmax_test():
return ([node], [x], [y])
@onnx_test
def split_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[7, 4, 4])
return ([node], [x], [y1, y2, y3])
@onnx_test
def split_test_default():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [5, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [5, 15])
node = onnx.helper.make_node(
'Split',
inputs=['x'],
outputs=['y1', 'y2'],
)
return ([node], [x], [y1, y2])
@onnx_test
def sqrt_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......
......@@ -1356,6 +1356,31 @@ TEST_CASE(softmax_test)
EXPECT(p == prog);
}
TEST_CASE(split_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = p.add_instruction(migraphx::op::slice{{1}, {0}, {7}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{1}, {7}, {11}}, input);
auto r3 = p.add_instruction(migraphx::op::slice{{1}, {11}, {15}}, input);
p.add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("split_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(split_test_default)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = p.add_instruction(migraphx::op::slice{{0}, {0}, {5}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{0}, {5}, {10}}, input);
p.add_return({r1, r2});
auto prog = migraphx::parse_onnx("split_test_default.onnx");
EXPECT(p == prog);
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
......

split_test:
5
xy1y2y3"Split*
axis*
split@@@
split_testZ
x


b
y1


b
y2


b
y3


B
\ No newline at end of file
split_test_default:i

xy1y2"Splitsplit_test_defaultZ
x


b
y1


b
y2


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment