"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "bf5485479c2e495413b59d98a5aef36ba99ed944"
Unverified Commit 70d0e816 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Add opset-13 support for parse_split (#1429)

Newer split moves the split attribute to an input. In this case we check the
number of input args then.
parent ba0913b1
......@@ -26,6 +26,9 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,12 +58,12 @@ struct parse_split : op_parser<parse_split>
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
}
}
else if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "Split: dynamic shape is not supported");
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
// no split attribute, input is equally divided
else
......@@ -74,6 +77,15 @@ struct parse_split : op_parser<parse_split>
vec_splits.resize(info.num_outputs, dl);
}
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
std::to_string(lens[tuned_axis]) + " Output " + to_string_range(vec_splits) +
" Rank " + std::to_string(n_rank) + " Len outs " + to_string_range(lens));
}
std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
......
......@@ -5687,6 +5687,92 @@ def split_test_default():
return ([node], [x], [y1, y2])
@onnx_test
def split_test_no_attribute():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
split = np.ones(4) * 75
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)
node = onnx.helper.make_node(
'Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3', 'y4'],
)
return ([const_node, node], [x], [y1, y2, y3, y4])
@onnx_test
def split_test_no_attribute_invalid_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
split = np.ones(4)
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)
node = onnx.helper.make_node(
'Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3', 'y4'],
)
return ([const_node, node], [x], [y1, y2, y3, y4])
@onnx_test
def split_test_invalid_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[1, 1, 1])
return ([node], [x], [y1, y2, y3])
@onnx_test
def split_test_no_attribute_invalid_input_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[])
return ([node], [x], [y1, y2, y3])
@onnx_test
def sqrt_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......
......@@ -5537,6 +5537,31 @@ TEST_CASE(split_test)
EXPECT(p == prog);
}
TEST_CASE(split_test_no_attribute)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape si{migraphx::shape::int64_type, {4}, {1}};
std::vector<int> ind = {75, 75, 75, 75};
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {300, 15}});
mm->add_literal(migraphx::literal(si, ind));
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {75}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {75}}, {"ends", {150}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {150}}, {"ends", {225}}}), input);
auto r4 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {225}}, {"ends", {300}}}), input);
mm->add_return({r1, r2, r3, r4});
auto prog = migraphx::parse_onnx("split_test_no_attribute.onnx");
EXPECT(p == prog);
}
TEST_CASE(split_test_default)
{
migraphx::program p;
......@@ -5552,6 +5577,23 @@ TEST_CASE(split_test_default)
EXPECT(p == prog);
}
TEST_CASE(split_test_no_attribute_invalid_split)
{
EXPECT(
test::throws([&] { migraphx::parse_onnx("split_test_no_attribute_invalid_split.onnx"); }));
}
TEST_CASE(split_test_invalid_split)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_split.onnx"); }));
}
TEST_CASE(split_test_no_attribute_invalid_input_split)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("split_test_no_attribute_invalid_input_split.onnx"); }));
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
......
split_test_invalid_split:
5
xy1y2y3"Split*
axis*
split@@@split_test_invalid_splitZ
x


b
y1


b
y2


b
y3


B
\ No newline at end of file
split_test_no_attribute:
0split"Constant*
value*:KKKKBsplit
!
x
splity1y2y3y4"Splitsplit_test_no_attributeZ
x


b
y1

K
b
y2

K
b
y3

K
b
y4

K
B
\ No newline at end of file
+split_test_no_attribute_invalid_input_split:
/
xy1y2y3"Split*
axis*
split+split_test_no_attribute_invalid_input_splitZ
x


b
y1


b
y2


b
y3


B
\ No newline at end of file
%split_test_no_attribute_invalid_split:
0split"Constant*
value*:Bsplit
!
x
splity1y2y3y4"Split%split_test_no_attribute_invalid_splitZ
x


b
y1

K
b
y2

K
b
y3

K
b
y4

K
B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment