#include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { //! Parser for ReverseSequence ONNX operator. /*! Reverses the data along the time axis for the batches along the batch axis. The sequence lengths can be given to reverse up to the given length for each batch, keeping the rest of the sequence in the original order. Variable sequence_lens is not supported in this version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The batch axis and time axis must be [0, 1] and not the same. */ struct parse_reversesequence : op_parser { std::vector operators() const { return {{"ReverseSequence"}}; } instruction_ref parse(const op_desc& /*opd*/, const onnx_parser& parser, const onnx_parser::node_info& info, std::vector args) const { int batch_axis = 1; if(contains(info.attributes, "batch_axis")) { batch_axis = info.attributes.at("batch_axis").i(); } if(batch_axis != 0 and batch_axis != 1) { MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1"); } int time_axis = 0; if(contains(info.attributes, "time_axis")) { time_axis = info.attributes.at("time_axis").i(); } if(time_axis != 0 and time_axis != 1) { MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1"); } if(time_axis == batch_axis) { MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same"); } auto input = args[0]; auto input_lens = input->get_shape().lens(); if(input_lens.size() < 2) { MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2"); } std::vector sequence_lens; if(args.size() == 2) { migraphx::argument seq_lens_arg = args.back()->eval(); check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens"); seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); }); } else if(contains(info.attributes, "sequence_lens")) { literal s = parser.parse_value(info.attributes.at("sequence_lens")); s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); }); } auto batch_size = input_lens[batch_axis]; auto time_size = input_lens[time_axis]; // this condition may still work if sequence_len's shape was incorrect if(sequence_lens.size() != batch_size) { MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape"); } instruction_ref ret; auto add_slice = [&info, &input, batch_axis, time_axis](int b, int t_start, int t_end) { return info.add_instruction(make_op("slice", {{"axes", {batch_axis, time_axis}}, {"starts", {b, t_start}}, {"ends", {b + 1, t_end}}}), input); }; for(int b = 0; b < batch_size; ++b) { instruction_ref s0; if(sequence_lens[b] > 1) { s0 = add_slice(b, 0, sequence_lens[b]); s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0); // if reversed less than whole batch, concat rest of batch if(sequence_lens[b] < time_size) { auto s1 = add_slice(b, sequence_lens[b], time_size); s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1); } } else { // cases where nothing changes s0 = add_slice(b, 0, time_size); } if(b == 0) { ret = s0; } else { ret = info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, s0); } } return ret; } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx