Unverified Commit 31906785 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

ReverseSequence op (#1177)

Implements the ReverseSequence ONNX operator as a parser.

This parser can only handle a constant sequence_lens input. This is the same as what is handled for TensorRT as far as I can tell.
We could handle a variable sequence_lens input; that would require ref and GPU implementations of the operator.
The ONNX backend tests are disabled because this does not handle variable sequence_lens.
parent 764273e4
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
rest of the sequence in the original order. Variable sequence_lens is not supported in this
version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The
batch axis and time axis must be [0, 1] and not the same.
*/
struct parse_reversesequence : op_parser<parse_reversesequence>
{
std::vector<op_desc> operators() const { return {{"ReverseSequence"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int batch_axis = 1;
if(contains(info.attributes, "batch_axis"))
{
batch_axis = info.attributes.at("batch_axis").i();
}
if(batch_axis != 0 and batch_axis != 1)
{
MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1");
}
int time_axis = 0;
if(contains(info.attributes, "time_axis"))
{
time_axis = info.attributes.at("time_axis").i();
}
if(time_axis != 0 and time_axis != 1)
{
MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1");
}
if(time_axis == batch_axis)
{
MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same");
}
auto input = args[0];
auto input_lens = input->get_shape().lens();
if(input_lens.size() < 2)
{
MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2");
}
std::vector<int64_t> sequence_lens;
if(args.size() == 2)
{
migraphx::argument seq_lens_arg = args.back()->eval();
check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens");
seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); });
}
else if(contains(info.attributes, "sequence_lens"))
{
literal s = parser.parse_value(info.attributes.at("sequence_lens"));
s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); });
}
auto batch_size = input_lens[batch_axis];
auto time_size = input_lens[time_axis];
// this condition may still work if sequence_len's shape was incorrect
if(sequence_lens.size() != batch_size)
{
MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape");
}
instruction_ref ret;
auto add_slice = [&info, &input, batch_axis, time_axis](int b, int t_start, int t_end) {
return info.add_instruction(make_op("slice",
{{"axes", {batch_axis, time_axis}},
{"starts", {b, t_start}},
{"ends", {b + 1, t_end}}}),
input);
};
for(int b = 0; b < batch_size; ++b)
{
instruction_ref s0;
if(sequence_lens[b] > 1)
{
s0 = add_slice(b, 0, sequence_lens[b]);
s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0);
// if reversed less than whole batch, concat rest of batch
if(sequence_lens[b] < time_size)
{
auto s1 = add_slice(b, sequence_lens[b], time_size);
s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1);
}
}
else
{ // cases where nothing changes
s0 = add_slice(b, 0, time_size);
}
if(b == 0)
{
ret = s0;
}
else
{
ret = info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, s0);
}
}
return ret;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -4385,6 +4385,142 @@ def resize_upsample_pc_test(): ...@@ -4385,6 +4385,142 @@ def resize_upsample_pc_test():
return ([node], [X], [Y], [scale_tensor]) return ([node], [X], [Y], [scale_tensor])
@onnx_test
def reversesequence_4D_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2, 2])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=0,
batch_axis=1,
sequence_lens=[2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_batch_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
seq_lens = np.array([1, 2, 3, 4])
seq_lens_tensor = helper.make_tensor(
name="sequence_lens",
data_type=TensorProto.INT64,
dims=seq_lens.shape,
vals=seq_lens.astype(np.int64),
)
arg_seq_lens = helper.make_node(
"Constant",
inputs=[],
outputs=['arg_seq_lens'],
value=seq_lens_tensor,
)
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x', 'arg_seq_lens'],
outputs=['y'],
time_axis=1,
batch_axis=0,
)
return ([arg_seq_lens, node], [x], [y])
@onnx_test
def reversesequence_batch_axis_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=0,
batch_axis=2,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_rank_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_sequence_lens_shape_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
sequence_lens=[4, 3, 2],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_same_axis_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=1,
batch_axis=1,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_time_axis_err_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2, 3])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=3,
batch_axis=0,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test
def reversesequence_time_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4])
node = onnx.helper.make_node(
'ReverseSequence',
inputs=['x'],
outputs=['y'],
time_axis=0,
batch_axis=1,
sequence_lens=[4, 3, 2, 1],
)
return ([node], [x], [y])
@onnx_test @onnx_test
def roialign_default_test(): def roialign_default_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8])
......
...@@ -4222,6 +4222,126 @@ TEST_CASE(resize_upsample_pf_test) ...@@ -4222,6 +4222,126 @@ TEST_CASE(resize_upsample_pf_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(reversesequence_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
int batch_axis = 0;
int time_axis = 1;
migraphx::shape sx{migraphx::shape::float_type, {4, 4}};
auto input = mm->add_parameter("x", sx);
std::vector<int64_t> sequence_lens = {1, 2, 3, 4};
mm->add_literal({{migraphx::shape::int64_type, {4}}, sequence_lens});
int batch_size = sx.lens()[batch_axis];
int time_size = sx.lens()[time_axis];
auto add_slice =
[&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) {
return mm->add_instruction(migraphx::make_op("slice",
{{"axes", {batch_axis, time_axis}},
{"starts", {b_start, t_start}},
{"ends", {b_end, t_end}}}),
input);
};
auto ret = add_slice(0, 1, 0, time_size);
for(int b = 1; b < batch_size; ++b)
{
auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]);
s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0);
if(sequence_lens[b] < time_size)
{
auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size);
s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1);
}
ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0);
}
mm->add_return({ret});
auto prog = migraphx::parse_onnx("reversesequence_batch_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reversesequence_batch_axis_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_batch_axis_err_test.onnx"); }));
}
TEST_CASE(reversesequence_rank_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_rank_err_test.onnx"); }));
}
TEST_CASE(reversesequence_sequence_lens_shape_err_test)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("reversesequence_sequence_lens_shape_err_test.onnx"); }));
}
TEST_CASE(reversesequence_same_axis_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_same_axis_err_test.onnx"); }));
}
TEST_CASE(reversesequence_time_axis_err_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_time_axis_err_test.onnx"); }));
}
TEST_CASE(reversesequence_time_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
int batch_axis = 1;
int time_axis = 0;
migraphx::shape sx{migraphx::shape::float_type, {4, 4}};
auto input = mm->add_parameter("x", sx);
int batch_size = sx.lens()[batch_axis];
int time_size = sx.lens()[time_axis];
std::vector<int64_t> sequence_lens = {4, 3, 2, 1};
auto add_slice =
[&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) {
return mm->add_instruction(migraphx::make_op("slice",
{{"axes", {batch_axis, time_axis}},
{"starts", {b_start, t_start}},
{"ends", {b_end, t_end}}}),
input);
};
migraphx::instruction_ref ret;
for(int b = 0; b < batch_size - 1; ++b)
{
auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]);
s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0);
if(sequence_lens[b] < time_size)
{
auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size);
s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1);
}
if(b == 0)
{
ret = s0;
}
else
{
ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0);
}
}
auto s0 = add_slice(batch_size - 1, batch_size, 0, time_size);
ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("reversesequence_time_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(roialign_default_test) TEST_CASE(roialign_default_test)
{ {
migraphx::shape sx{migraphx::shape::float_type, {10, 4, 7, 8}}; migraphx::shape sx{migraphx::shape::float_type, {10, 4, 7, 8}};
......
reversesequence_rank_err_test:v
3
xy"ReverseSequence*
sequence_lens@@@@reversesequence_rank_err_testZ
x

b
y

B
\ No newline at end of file
"reversesequence_same_axis_err_test:
X
xy"ReverseSequence*
batch_axis*
sequence_lens@@@@*
time_axis"reversesequence_same_axis_err_testZ
x


b
y


B
\ No newline at end of file
,reversesequence_sequence_lens_shape_err_test:‹
1
xy"ReverseSequence*
sequence_lens@@@ ,reversesequence_sequence_lens_shape_err_testZ
x


b
y


B
\ No newline at end of file
...@@ -698,6 +698,69 @@ TEST_CASE(resize_upsample_pf_test) ...@@ -698,6 +698,69 @@ TEST_CASE(resize_upsample_pf_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(reversesequence_4D_verify_test)
{
migraphx::program p = migraphx::parse_onnx("reversesequence_4D_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape xs{migraphx::shape::float_type, {2, 2, 2, 2}};
std::vector<float> x_data = {
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0};
migraphx::parameter_map param_map;
param_map["x"] = migraphx::argument(xs, x_data.data());
auto result = p.eval(param_map).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(reversesequence_batch_verify_test)
{
migraphx::program p = migraphx::parse_onnx("reversesequence_batch_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape xs{migraphx::shape::float_type, {4, 4}};
std::vector<float> x_data = {
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0};
migraphx::parameter_map param_map;
param_map["x"] = migraphx::argument(xs, x_data.data());
auto result = p.eval(param_map).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(reversesequence_time_verify_test)
{
migraphx::program p = migraphx::parse_onnx("reversesequence_time_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape xs{migraphx::shape::float_type, {4, 4}};
std::vector<float> x_data = {
0.0, 4.0, 8.0, 12.0, 1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0};
migraphx::parameter_map param_map;
param_map["x"] = migraphx::argument(xs, x_data.data());
auto result = p.eval(param_map).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(selu_test) TEST_CASE(selu_test)
{ {
migraphx::program p = migraphx::parse_onnx("selu_test.onnx"); migraphx::program p = migraphx::parse_onnx("selu_test.onnx");
......
...@@ -178,6 +178,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -178,6 +178,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_reduce.*') backend_test.include(r'.*test_reduce.*')
backend_test.include(r'.*test_ReLU*') backend_test.include(r'.*test_ReLU*')
backend_test.include(r'.*test_relu.*') backend_test.include(r'.*test_relu.*')
#backend_test.include(r'.*test_reversesequence.*')
backend_test.include(r'.*test_RoiAlign*') backend_test.include(r'.*test_RoiAlign*')
backend_test.include(r'.*test_roialign.*') backend_test.include(r'.*test_roialign.*')
backend_test.include(r'.*test_scatter.*') backend_test.include(r'.*test_scatter.*')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment