Unverified commit 90200619, authored by Shucai Xiao, committed by GitHub
Browse files

Rnn variable seq lengths (#517)



* code backup

* clang format

* fix compiling errors

* clang format

* rename a few files

* rename a few files

* fix variable bugs

* clang format

* add an operator to shift input sequences

* clang format

* fixed a bug

* clang format

* fixed a bug

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* refine code related lstm operator optimization

* clang format

* fix various bugs

* clang format

* fixed a bug in rewrite_lstm

* clang format

* fixed another bug

* refine two operator names

* clang format

* refine file names

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* fixed review comments

* clang format

* add unit tests

* clang format

* add unit tests

* clang format

* refine unit tests for better coverage

* clang format

* fixed a bug

* fix cppcheck error

* fix review comments

* clang format

* rename two operators according to review comments

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix review comments

* fix a cppcheck error

* clang format

* fix review comments

* clang format
Co-authored-by: Shucai Xiao <scxiao@prj47-rack-99.local.lan>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 369b9f60
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_CELL_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_CELL_OUTPUT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct lstm_last_cell_output
struct rnn_last_cell_output
{
std::string name() const { return "lstm_last_cell_output"; }
std::string name() const { return "rnn_last_cell_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
......
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_OUTPUT_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_HS_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_HS_OUTPUT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct rnn_last_output
struct rnn_last_hs_output
{
std::string name() const { return "rnn_last_output"; }
std::string name() const { return "rnn_last_hs_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
......
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_VAR_SL_LAST_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_VAR_SL_LAST_OUTPUT_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Selects the last valid hidden state of each batch from an RNN output with
// variable sequence lengths. Inputs: the hidden states and the per-batch
// sequence lengths.
struct rnn_var_sl_last_output
{
    rnn_direction direction = rnn_direction::forward;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.direction, "direction"));
    }

    std::string name() const { return "rnn_var_sl_last_output"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // validate the input count (hidden states + sequence lengths), like
        // the sibling rnn_var_sl_* operators do
        check_shapes{inputs, *this}.has(2);
        auto dims = inputs[0].lens();
        // remove the first (sequence) dimension; the remaining dims are the
        // output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_VARIABLE_SEQ_LENS_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_VARIABLE_SEQ_LENS_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Shifts an RNN output computed with variable sequence lengths so that every
// batch's valid outputs start at time step 0, zero-filling t >= seq_lens[b].
struct rnn_var_sl_shift_output
{
    // which RNN output this operator shifts ("hidden_states" or "cell_outputs")
    std::string output_name = "hidden_states";
    rnn_direction direction = rnn_direction::forward;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        // the reflection key must be the member's name ("output_name"), not
        // its default value, so attribute lookup and serialization round-trip
        return pack(f(self.output_name, "output_name"), f(self.direction, "direction"));
    }

    std::string name() const { return "rnn_var_sl_shift_output"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        return inputs[0];
    }

    // args[0]: RNN output, time-major with layout (seq, dir, batch, hidden)
    //          as implied by the index usage below
    // args[1]: per-batch sequence lengths
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        int64_t max_len = static_cast<int64_t>(output_shape.lens()[0]);
        visit_all(result, args[0])([&](auto output, auto input) {
            using value_type = typename decltype(output)::value_type;
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto i) {
                    auto idx      = output_shape.multi(i);
                    auto batch_id = idx[2];
                    auto d        = idx[1];
                    auto t        = idx[0];
                    auto sl       = seq_lens[batch_id];
                    // positions past the batch's sequence length stay zero
                    value_type val = value_type{0};
                    if(t < sl)
                    {
                        auto in_idx = idx;
                        // reverse-direction values sit at the padded end of
                        // the time axis, so read them with a padding offset
                        int offset = (direction == rnn_direction::reverse or d == 1) ? 1 : 0;
                        in_idx[0] += offset * (max_len - sl);
                        val = input(in_idx.begin(), in_idx.end());
                    }
                    output(idx.begin(), idx.end()) = val;
                });
            });
        });
        return result;
    }
};
// Shifts each batch's input sequence toward the end of the time axis so the
// valid elements occupy the last seq_lens[b] time steps; earlier time steps
// are filled with zeros.
struct rnn_var_sl_shift_sequence
{
    std::string name() const { return "rnn_var_sl_shift_sequence"; }

    // The shifted sequence keeps the shape of the input sequence.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        return inputs[0];
    }

    // args[0]: input sequence, time-major with the batch on axis 1
    // args[1]: per-batch sequence lengths
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto max_len = static_cast<int64_t>(output_shape.lens()[0]);
        visit_all(result, args[0])([&](auto output, auto input) {
            using type = typename decltype(output)::value_type;
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto elem) {
                    auto out_idx = output_shape.multi(elem);
                    auto time    = out_idx[0];
                    // amount this batch's data is pushed toward the end
                    auto shift = max_len - seq_lens[out_idx[1]];
                    if(time < shift)
                    {
                        // leading padding region
                        output(out_idx.begin(), out_idx.end()) = type{0};
                    }
                    else
                    {
                        auto src_idx = out_idx;
                        src_idx[0] -= shift;
                        output(out_idx.begin(), out_idx.end()) =
                            input(src_idx.begin(), src_idx.end());
                    }
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -66,7 +66,9 @@
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
......
......@@ -6,6 +6,7 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -57,6 +58,23 @@ struct rewrite_rnn
const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(program& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(program& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const;
std::size_t
get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -1486,7 +1486,7 @@ struct onnx_parser
std::move(args));
// second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
return {hidden_states, last_output};
}
......@@ -1608,11 +1608,96 @@ struct onnx_parser
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
return {hidden_states, last_output};
}
// Expands a partial list of LSTM activation-function names to the full set:
// 6 names for bidirectional, 3 otherwise. No spec is given by the ONNX
// operator, so missing entries are filled by repeating the provided ones
// (e.g. one name is repeated for every slot; with two names the second is
// repeated, then the pattern is mirrored for the reverse direction).
void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv_func_names)
{
    // rebuild the list by picking the given indices from the current names;
    // at() keeps the original bounds checking
    auto expand = [&](const std::vector<std::size_t>& indices) {
        std::vector<std::string> expanded;
        expanded.reserve(indices.size());
        for(auto index : indices)
            expanded.push_back(actv_func_names.at(index));
        actv_func_names = std::move(expanded);
    };

    const auto n = actv_func_names.size();
    if(dirct == op::rnn_direction::bidirectional)
    {
        switch(n)
        {
        case 1: expand({0, 0, 0, 0, 0, 0}); break;
        case 2: expand({0, 1, 1, 0, 1, 1}); break;
        case 3: expand({0, 1, 2, 0, 1, 2}); break;
        case 4: expand({0, 1, 2, 3, 3, 3}); break;
        case 5: expand({0, 1, 2, 3, 4, 4}); break;
        default: break; // 6 (or anything else) is left untouched
        }
    }
    else
    {
        switch(n)
        {
        case 1: expand({0, 0, 0}); break;
        case 2: expand({0, 1, 1}); break;
        default: break; // 3 (or anything else) is left untouched
        }
    }
}
std::vector<instruction_ref>
parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
{
......@@ -1664,83 +1749,7 @@ struct onnx_parser
});
}
// need 6 activation functions for bidirectional directions
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(1),
vec_names.at(0),
vec_names.at(1),
vec_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(0),
vec_names.at(1),
vec_names.at(2)};
break;
case 4:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(3),
vec_names.at(3)};
break;
case 5:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(4),
vec_names.at(4)};
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break;
default: break;
}
}
lstm_actv_functions(dirct, vec_names);
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
......@@ -1779,11 +1788,10 @@ struct onnx_parser
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
auto last_cell_output = prog.add_instruction(op::rnn_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
......
......@@ -9,7 +9,6 @@
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
......@@ -19,6 +18,8 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -181,19 +182,19 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
}
// search its output to find if there are rnn_last_output operator
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
// search its output to find if there are rnn_last_hs_output operator
// while loop to handle case of multiple rnn_last_hs_output operators
auto last_hs_output_it = ins->outputs().begin();
while(last_hs_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
last_hs_output_it = std::find_if(last_hs_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_hs_output";
});
if(last_output_it != ins->outputs().end())
if(last_hs_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
prog.replace_instruction(*last_hs_output_it, last_output);
last_hs_output_it++;
}
}
}
......@@ -456,20 +457,20 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
}
// replace the corresponding rnn_last_output instruction
// with the last_output, if rnn_last_output exists
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
// replace the corresponding rnn_last_hs_output instruction
// with the last_output, if rnn_last_hs_output exists
// while loop to handle case of multiple rnn_last_hs_output operators
auto last_hs_output_it = ins->outputs().begin();
while(last_hs_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
last_hs_output_it = std::find_if(last_hs_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_hs_output";
});
if(last_output_it != ins->outputs().end())
if(last_hs_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
prog.replace_instruction(*last_hs_output_it, last_output);
last_hs_output_it++;
}
}
}
......@@ -675,8 +676,19 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::rnn_direction dirct = lstm_op.direction;
instruction_ref last_output{};
// process sequence length
instruction_ref seq_lens = prog.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
instruction_ref last_hs_output{};
instruction_ref last_cell_output{};
instruction_ref hidden_state{};
instruction_ref cell_outputs{};
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
......@@ -734,45 +746,70 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
}
auto ret_forward = lstm_cell(
true,
auto ret_forward = lstm_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
{args[0],
w_forward,
r_forward,
bias_forward,
seq_lens,
ih_forward,
ic_forward,
pph_forward},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
auto ret_reverse = lstm_cell(
false,
if(variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
}
auto ret_reverse = lstm_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
{args[0],
w_reverse,
r_reverse,
bias_reverse,
seq_lens,
ih_reverse,
ic_reverse,
pph_reverse},
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(5));
auto concat_output =
auto concat_hs_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// last cell output
last_cell_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]);
auto concat_cell_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
last_hs_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_hs_output);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_cell_output);
// the following logic is to ensure the last instruction is a concat
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
cell_outputs = concat_cell_output;
}
else
{
ret_forward[0] =
ret_forward[1] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
ret_reverse[1] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
ret_forward[3] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_forward[3]);
ret_reverse[3] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[3], ret_reverse[2]);
cell_outputs =
prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
}
hidden_state =
prog.replace_instruction(ins, op::concat{1}, {ret_forward[1], ret_reverse[1]});
}
else
{
......@@ -817,60 +854,42 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
pph = args[7];
}
if(!is_forward and variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
}
auto ret = lstm_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih, ic, pph},
{args[0], w, r, bias, seq_lens, ih, ic, pph},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = ret[2];
last_hs_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[3]);
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
cell_outputs = ret[3];
hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
cell_outputs =
prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// replace the corresponding lstm_last_output instruction
// with the last_output, and the lstm_last_cell_output with
// the last_cell_output. The while loop is to handle the case
// of multiple lstm_last_output and lstm_last_cell_output
// operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
auto last_cell_output_it = ins->outputs().begin();
while(last_cell_output_it != ins->outputs().end())
{
last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "lstm_last_cell_output";
});
if(last_cell_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_cell_output_it, last_cell_output);
last_cell_output_it++;
}
}
ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
......@@ -882,22 +901,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
const operation& actv_func3) const
{
// must have 7 args in the input vector
assert(inputs.size() == 7);
assert(inputs.size() == 8);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
auto ic = inputs.at(5);
auto pph = inputs.at(6);
auto seq_lens = inputs.at(4);
auto ih = inputs.at(5);
auto ic = inputs.at(6);
auto pph = inputs.at(7);
instruction_ref hidden_states = prog.end();
instruction_ref last_output{};
instruction_ref cell_outputs = prog.end();
instruction_ref last_hs_output{};
instruction_ref last_cell_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long max_seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
auto bs = ih->get_shape().lens()[1];
......@@ -948,6 +970,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
......@@ -986,7 +1009,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct);
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
last_cell_output = cellt;
if(pph != prog.end())
{
......@@ -1002,27 +1024,47 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
sic = cellt;
sih = ht;
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
last_hs_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, cellt);
if(i < seq_len - 1)
{
if(i == 0)
{
hidden_states = last_output;
hidden_states = last_hs_output;
cell_outputs = last_cell_output;
}
else
{
auto concat_arg0 = is_forward ? hidden_states : last_output;
auto concat_arg1 = is_forward ? last_output : hidden_states;
auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
hidden_states =
prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
prog.insert_instruction(ins, op::concat{0}, concat_hs_arg0, concat_hs_arg1);
auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
cell_outputs =
prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);
}
}
}
last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output);
// condition of all sequence are of the same length and
// less than max_seq_len, we need to append the hs outputs
// In this case, the cell_output is not used at all, so
// no need to extand it to the avariable length
if(seq_len < max_seq_len)
{
auto s = last_hs_output->get_shape();
auto pad_lens = s.lens();
pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len);
shape pad_s{s.type(), pad_lens};
std::vector<float> data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, data.begin(), data.end());
hidden_states = prog.insert_instruction(ins, op::concat{0}, hidden_states, pl);
}
return {hidden_states, last_output, last_cell_output};
return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
}
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
......@@ -1099,6 +1141,172 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
// Returns true when the seq_lens input implies per-batch (variable) sequence
// lengths: either the lengths cannot be evaluated at compile time, or they
// evaluate to values that are not all equal.
bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const
{
    // no sequence-length input at all => fixed length
    if(seq_lens == prog.end())
        return false;

    // lengths unknown until runtime => treat as variable
    if(not seq_lens->can_eval())
        return true;

    std::vector<int64_t> lens;
    seq_lens->eval().visit([&](auto l) { lens.assign(l.begin(), l.end()); });
    if(lens.empty())
        return false;

    // variable iff some entry differs from the first
    auto first = lens.front();
    return std::any_of(lens.begin(), lens.end(), [&](auto v) { return v != first; });
}
// Returns the effective (uniform) sequence length: the constant seq_lens
// value when one exists and all batches agree, otherwise the time dimension
// of the input tensor.
std::size_t
rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const
{
    // default: the first (time) dimension of the input
    std::size_t length = input->get_shape().lens()[0];

    // a constant, uniform seq_lens input overrides the tensor dimension
    if(seq_lens != prog.end() and not is_variable_seq_lens(prog, seq_lens))
    {
        std::vector<std::size_t> lens;
        seq_lens->eval().visit([&](auto l) { lens.assign(l.begin(), l.end()); });
        if(not lens.empty())
            length = lens.front();
    }

    return length;
}
// Replaces readers of the rnn_last_hs_output placeholder with the real last
// hidden-state output. For variable sequence lengths the hidden states are
// first shifted per batch (rnn_var_sl_shift_output) and the last state is
// selected at runtime by rnn_var_sl_last_output; otherwise the precomputed
// last_hs_output is substituted directly.
// Returns the instruction now producing the (possibly shifted) hidden states.
instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
    instruction_ref result_ins{};
    if(variable_seq_len)
    {
        // shift the outputs so each batch's valid states start at t = 0, and
        // redirect all readers of ins to the shifted result
        result_ins = prog.insert_instruction(
            std::next(ins), op::rnn_var_sl_shift_output{"hidden_states", dirct}, ins, seq_lens);
        prog.replace_instruction(ins, result_ins);
        // correct the direction used for the operator
        // (while loop handles multiple rnn_last_hs_output readers; iterate
        // over result_ins' outputs since readers were just redirected to it)
        auto last_hs_output_it = result_ins->outputs().begin();
        while(last_hs_output_it != result_ins->outputs().end())
        {
            last_hs_output_it =
                std::find_if(last_hs_output_it, result_ins->outputs().end(), [](auto i) {
                    return i->name() == "rnn_last_hs_output";
                });
            if(last_hs_output_it != result_ins->outputs().end())
            {
                // swap in the variable-length-aware selector, which reads the
                // per-batch lengths at runtime
                auto inputs = (*last_hs_output_it)->inputs();
                prog.replace_instruction(*last_hs_output_it,
                                         op::rnn_var_sl_last_output{dirct},
                                         inputs.front(),
                                         seq_lens);
                last_hs_output_it++;
            }
        }
    }
    else
    {
        // fixed length: substitute the precomputed last hidden state for every
        // rnn_last_hs_output reader
        auto last_hs_output_it = ins->outputs().begin();
        while(last_hs_output_it != ins->outputs().end())
        {
            last_hs_output_it = std::find_if(last_hs_output_it, ins->outputs().end(), [](auto i) {
                return i->name() == "rnn_last_hs_output";
            });
            if(last_hs_output_it != ins->outputs().end())
            {
                prog.replace_instruction(*last_hs_output_it, last_hs_output);
                last_hs_output_it++;
            }
        }
        result_ins = ins;
    }
    return result_ins;
}
// Replaces readers of the rnn_last_cell_output placeholder with the real last
// cell output. For variable sequence lengths the cell outputs are shifted per
// batch (only when some reader actually needs them) and the last cell state
// is selected at runtime; otherwise the precomputed last_cell_output is
// substituted directly.
void rewrite_rnn::replace_last_cell_output(program& prog,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
    if(variable_seq_len)
    {
        // only insert the (costly) shift when at least one reader wants the
        // last cell output
        auto last_cell_output_it =
            std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
                return i->name() == "rnn_last_cell_output";
            });
        if(last_cell_output_it != ins->outputs().end())
        {
            cell_outputs =
                prog.insert_instruction(std::next(ins),
                                        op::rnn_var_sl_shift_output{"cell_outputs", dirct},
                                        cell_outputs,
                                        seq_lens);
        }
        // replace every rnn_last_cell_output reader with the runtime selector
        // fed by the shifted cell outputs
        last_cell_output_it = ins->outputs().begin();
        while(last_cell_output_it != ins->outputs().end())
        {
            last_cell_output_it =
                std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
                    return i->name() == "rnn_last_cell_output";
                });
            if(last_cell_output_it != ins->outputs().end())
            {
                // rebind the first input to the shifted cell outputs before
                // swapping in the variable-length-aware selector
                auto inputs = (*last_cell_output_it)->inputs();
                inputs[0] = cell_outputs;
                prog.replace_instruction(*last_cell_output_it,
                                         op::rnn_var_sl_last_output{dirct},
                                         inputs.front(),
                                         seq_lens);
                last_cell_output_it++;
            }
        }
    }
    // replace the rnn_last_cell_output with the last_cell_output. The while
    // loop is to handle the case of multiple rnn_last_cell_output operators
    else
    {
        auto last_cell_output_it = ins->outputs().begin();
        while(last_cell_output_it != ins->outputs().end())
        {
            last_cell_output_it =
                std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
                    return i->name() == "rnn_last_cell_output";
                });
            if(last_cell_output_it != ins->outputs().end())
            {
                prog.replace_instruction(*last_cell_output_it, last_cell_output);
                last_cell_output_it++;
            }
        }
    }
}
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
......
......@@ -18,6 +18,7 @@
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
......@@ -710,6 +711,52 @@ struct cpu_softmax
}
};
// CPU implementation of rnn_var_sl_last_output: for every (direction, batch)
// pair, picks the hidden state at that batch's last valid time step.
struct cpu_rnn_var_sl_last_output
{
    op::rnn_var_sl_last_output op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::rnn_var_sl_last_output"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        return op.compute_shape(std::move(inputs));
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};

        // index into the input through its own layout with the time dimension
        // collapsed to 1, so output element i maps onto (dir, batch, hidden)
        auto comp_lens = args[0].get_shape().lens();
        comp_lens[0]   = 1;
        shape comp_s{output_shape.type(), comp_lens};

        visit_all(result, args[0])([&](auto output, auto input) {
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto i) {
                    auto idx   = comp_s.multi(i);
                    auto batch = idx[2];
                    // reverse outputs keep their last step at t = 0; forward
                    // outputs at t = seq_len - 1
                    bool reversed = (op.direction == op::rnn_direction::reverse or idx[1] == 1);
                    if(reversed)
                        idx[0] = 0;
                    else
                        idx[0] = seq_lens[batch] - 1;
                    output[i] = input(idx.begin(), idx.end());
                });
            });
        });
        return result;
    }
};
struct cpu_apply
{
program* prog;
......@@ -745,6 +792,8 @@ struct cpu_apply
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["softmax"] = extend_op<cpu_softmax<op::softmax>, op::softmax>();
apply_map["rnn_var_sl_last_output"] =
extend_op<cpu_rnn_var_sl_last_output, op::rnn_var_sl_last_output>();
}
void apply()
......
......@@ -67,6 +67,7 @@ add_library(migraphx_device
device/sub.cpp
device/tan.cpp
device/tanh.cpp
device/rnn_variable_seq_lens.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
......@@ -117,6 +118,7 @@ add_library(migraphx_gpu
int8_conv_pack.cpp
gemm_impl.cpp
preallocate_param.cpp
rnn_variable_seq_lens.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
......
#include <migraphx/gpu/device/rnn_variable_seq_lens.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Shifts each batch's sequence toward the end of the time axis so the
// valid elements occupy the last `l` time steps; earlier positions are
// zero-filled. Shapes are rank 3 (time is dim 0, batch is dim 1 — it
// indexes the sequence-length vector); result and arg_hs share a shape.
//   stream : HIP stream to launch on
//   result : output buffer
//   arg_hs : input sequences
//   arg_sl : per-batch sequence lengths
void rnn_var_sl_shift_sequence(hipStream_t stream,
                               const argument& result,
                               const argument& arg_hs,
                               const argument& arg_sl)
{
    auto output_shape = result.get_shape();
    int64_t max_len   = output_shape.lens()[0];
    visit_all(result, arg_hs)([&](auto output, auto input) {
        const auto* in_data = device_cast(input.data());
        auto* out_data      = device_cast(output.data());
        auto out_s          = make_hip_shape<3>(output_shape);
        arg_sl.visit([&](auto sl) {
            const auto* sl_data = device_cast(sl.data());
            gs_launch(stream, output_shape.elements(), 256)([=](auto i) __device__ {
                auto idx = out_s.multi(i);
                auto t   = idx[0]; // time step of this output element
                auto b   = idx[1]; // batch index
                auto l   = sl_data[b]; // valid length for this batch
                // Read-then-zero gives `val` the element type without
                // naming it; elements before the shifted data stay 0.
                auto val = in_data[0];
                val      = 0;
                if(t >= max_len - l)
                {
                    // Source element sits (max_len - l) steps earlier on
                    // the time axis.
                    auto in_idx = idx;
                    in_idx[0] -= (max_len - l);
                    val = in_data[out_s.index(in_idx)];
                }
                out_data[i] = val;
            });
        });
    });
}
// Inverse of the sequence shift for RNN outputs: moves each batch's
// valid hidden states back to the front of the time axis and zeroes
// everything at or past t = l. Shapes are rank 4 with time as dim 0,
// direction as dim 1 and batch as dim 2 (dim 2 indexes the
// sequence-length vector); result and arg_hs share a shape.
//   is_reverse : true when the RNN ran in reverse direction; reverse
//                (or direction index 1) outputs are read from the tail
//                of the time axis, i.e. offset by (max_len - l).
void rnn_var_sl_shift_output(hipStream_t stream,
                             const argument& result,
                             const argument& arg_hs,
                             const argument& arg_sl,
                             bool is_reverse)
{
    auto output_shape = result.get_shape();
    int64_t max_len   = output_shape.lens()[0];
    visit_all(result, arg_hs)([&](auto output, auto input) {
        const auto* in_data = device_cast(input.data());
        auto* out_data      = device_cast(output.data());
        auto out_s          = make_hip_shape<4>(output_shape);
        arg_sl.visit([&](auto sl) {
            const auto* sl_data = device_cast(sl.data());
            gs_launch(stream, output_shape.elements(), 256)([=](auto i) __device__ {
                auto idx = out_s.multi(i);
                auto t   = idx[0]; // time step
                auto d   = idx[1]; // direction (0 or 1)
                auto b   = idx[2]; // batch index
                auto l   = sl_data[b]; // valid length for this batch
                // Read-then-zero gives `val` the element type without
                // naming it; positions t >= l are zero-padded.
                auto val = in_data[0];
                val      = 0;
                if(t < l)
                {
                    // Reverse (or second-direction) states live at the
                    // end of the padded time axis; shift them forward.
                    int offset  = (d == 1 or is_reverse) ? 1 : 0;
                    auto in_idx = idx;
                    in_idx[0] += offset * (max_len - l);
                    val = in_data[out_s.index(in_idx)];
                }
                out_data[i] = val;
            });
        });
    });
}
// Extracts the last valid hidden state per direction and batch from the
// full RNN output. Input is rank 4 with time as dim 0, direction as
// dim 1 and batch as dim 2 (dim 2 indexes the sequence-length vector);
// the result has the same layout with the time dim collapsed to 1.
//   is_reverse : reverse (or direction index 1) runs keep their final
//                state at t = 0; forward runs at t = l - 1.
void rnn_var_sl_last_output(hipStream_t stream,
                            const argument& result,
                            const argument& arg_hs,
                            const argument& arg_sl,
                            bool is_reverse)
{
    auto input_shape   = arg_hs.get_shape();
    auto out_comp_lens = input_shape.lens();
    // Collapse the time dim so a flat output index decodes directly to
    // (1, direction, batch, hidden) coordinates.
    out_comp_lens[0] = 1;
    shape out_comp_shape{input_shape.type(), out_comp_lens};
    visit_all(result, arg_hs)([&](auto output, auto input) {
        const auto* in_data = device_cast(input.data());
        auto* out_data      = device_cast(output.data());
        arg_sl.visit([&](auto sl) {
            const auto* sl_data = device_cast(sl.data());
            auto in_s           = make_hip_shape<4>(input_shape);
            auto out_s          = make_hip_shape<4>(out_comp_shape);
            gs_launch(stream, result.get_shape().elements(), 256)([=](auto i) __device__ {
                auto idx = out_s.multi(i);
                auto d   = idx[1]; // direction (0 or 1)
                auto b   = idx[2]; // batch index
                auto l   = sl_data[b]; // valid length for this batch
                if(is_reverse or d == 1)
                {
                    idx[0] = 0;
                }
                else
                {
                    idx[0] = l - 1;
                }
                out_data[i] = in_data[in_s.index(idx)];
            });
        });
    });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RNN_VARIABLE_SEQ_LENS_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RNN_VARIABLE_SEQ_LENS_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Device kernels supporting RNN/LSTM/GRU with per-batch variable
// sequence lengths (arg_sl holds one length per batch; result is the
// preallocated output buffer in every case).

// Shift each batch's input sequence to the end of the time axis,
// zero-filling the positions before the shifted data.
void rnn_var_sl_shift_sequence(hipStream_t stream,
                               const argument& result,
                               const argument& arg_hs,
                               const argument& arg_sl);

// Shift each batch's RNN output back to the front of the time axis,
// zero-padding time steps past the batch's sequence length.
void rnn_var_sl_shift_output(hipStream_t stream,
                             const argument& result,
                             const argument& arg_hs,
                             const argument& arg_sl,
                             bool is_reverse);

// Extract the hidden state at each batch's last valid time step.
void rnn_var_sl_last_output(hipStream_t stream,
                            const argument& result,
                            const argument& arg_hs,
                            const argument& arg_sl,
                            bool is_reverse);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_RNN_VARIABLE_SEQ_LENS_HPP
#define MIGRAPHX_GUARD_RTGLIB_RNN_VARIABLE_SEQ_LENS_HPP
#include <migraphx/shape.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/gpu/device/rnn_variable_seq_lens.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_rnn_var_sl_shift_sequence
{
op::rnn_var_sl_shift_sequence op;
std::string name() const { return "gpu::rnn_var_sl_shift_sequence"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
// GPU op wrapping op::rnn_var_sl_shift_output; the trailing argument is
// the preallocated output buffer (see output_alias).
struct hip_rnn_var_sl_shift_output
{
    op::rnn_var_sl_shift_output op;

    // Expose the wrapped op's fields to migraphx reflection.
    template <class T, class F>
    static auto reflect(T& x, F f)
    {
        return migraphx::reflect(x.op, f);
    }

    std::string name() const { return "gpu::rnn_var_sl_shift_output"; }

    shape compute_shape(std::vector<shape> inputs) const;

    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;

    // The result aliases the trailing (preallocated) input argument.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
// GPU op wrapping op::rnn_var_sl_last_output; the trailing argument is
// the preallocated output buffer (see output_alias).
struct hip_rnn_var_sl_last_output
{
    op::rnn_var_sl_last_output op;

    // Expose the wrapped op's fields to migraphx reflection.
    template <class T, class F>
    static auto reflect(T& x, F f)
    {
        return migraphx::reflect(x.op, f);
    }

    std::string name() const
    {
        // Derive the gpu name from the wrapped op's name.
        return "gpu::" + op.name();
    }

    shape compute_shape(std::vector<shape> inputs) const;

    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;

    // The result aliases the trailing (preallocated) input argument.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -72,6 +72,7 @@
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/prelu.hpp>
#include <migraphx/gpu/recip.hpp>
#include <migraphx/gpu/rnn_variable_seq_lens.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -184,9 +185,14 @@ struct miopen_apply
add_extend_op<hip_reduce_min, op::reduce_min>("reduce_min");
add_extend_op<hip_reduce_prod, op::reduce_prod>("reduce_prod");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_extend_op<hip_rnn_var_sl_shift_output, op::rnn_var_sl_shift_output>(
"rnn_var_sl_shift_output");
add_extend_op<hip_rnn_var_sl_shift_sequence, op::rnn_var_sl_shift_sequence>(
"rnn_var_sl_shift_sequence");
add_extend_op<hip_rnn_var_sl_last_output, op::rnn_var_sl_last_output>(
"rnn_var_sl_last_output");
add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot");
add_lrn_op();
add_convolution_op();
add_deconvolution_op();
......
#include <migraphx/gpu/rnn_variable_seq_lens.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/rnn_variable_seq_lens.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_rnn_var_sl_shift_output::compute_shape(std::vector<shape> inputs) const
{
    // Drop the trailing shape (the preallocated output) before letting
    // the wrapped op validate the real inputs.
    inputs.erase(inputs.end() - 1);
    return op.compute_shape(std::move(inputs));
}
argument hip_rnn_var_sl_shift_output::compute(context& ctx,
                                              const shape&,
                                              const std::vector<argument>& args) const
{
    // args.back() is the preallocated output buffer the kernel fills.
    const bool reverse = op.direction == op::rnn_direction::reverse;
    const auto& output = args.back();
    device::rnn_var_sl_shift_output(
        ctx.get_stream().get(), output, args.at(0), args.at(1), reverse);
    return output;
}
shape hip_rnn_var_sl_shift_sequence::compute_shape(std::vector<shape> inputs) const
{
    // Drop the trailing shape (the preallocated output) before letting
    // the wrapped op validate the real inputs.
    inputs.erase(inputs.end() - 1);
    return op.compute_shape(std::move(inputs));
}
argument hip_rnn_var_sl_shift_sequence::compute(context& ctx,
                                                const shape&,
                                                const std::vector<argument>& args) const
{
    // args.back() is the preallocated output buffer the kernel fills.
    const auto& output = args.back();
    device::rnn_var_sl_shift_sequence(ctx.get_stream().get(), output, args.at(0), args.at(1));
    return output;
}
shape hip_rnn_var_sl_last_output::compute_shape(std::vector<shape> inputs) const
{
    // Drop the trailing shape (the preallocated output) before letting
    // the wrapped op validate the real inputs.
    inputs.erase(inputs.end() - 1);
    return op.compute_shape(std::move(inputs));
}
argument hip_rnn_var_sl_last_output::compute(context& ctx,
                                             const shape&,
                                             const std::vector<argument>& args) const
{
    // args.back() is the preallocated output buffer the kernel fills.
    const bool reverse = op.direction == op::rnn_direction::reverse;
    const auto& output = args.back();
    device::rnn_var_sl_last_output(
        ctx.get_stream().get(), output, args.at(0), args.at(1), reverse);
    return output;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -51,6 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
rewrite_batchnorm{},
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
rewrite_pooling{},
dead_code_elimination{},
eliminate_common_subexpression{},
......
......@@ -4,9 +4,10 @@
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
......@@ -125,7 +126,7 @@ TEST_CASE(rnn_forward)
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
......@@ -143,7 +144,7 @@ TEST_CASE(rnn_forward)
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// multiple rnn_last_output operators
// multiple rnn_last_hs_output operators
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
......@@ -161,8 +162,7 @@ TEST_CASE(rnn_forward)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
......@@ -192,7 +192,7 @@ TEST_CASE(rnn_forward)
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
......@@ -353,7 +353,7 @@ TEST_CASE(rnn_reverse)
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
......@@ -472,7 +472,7 @@ TEST_CASE(rnn_bidirectional)
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
......@@ -517,7 +517,7 @@ TEST_CASE(rnn_bidirectional)
r,
bias);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
......@@ -734,7 +734,7 @@ TEST_CASE(gru_forward)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -754,7 +754,7 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// two rnn_last_output operators after gru
// two rnn_last_hs_output operators after gru
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
......@@ -775,8 +775,8 @@ TEST_CASE(gru_forward)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -817,7 +817,7 @@ TEST_CASE(gru_forward)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -1067,7 +1067,7 @@ TEST_CASE(gru_forward_actv_funcs)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -1140,7 +1140,7 @@ TEST_CASE(gru_forward_actv_funcs)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -1319,7 +1319,7 @@ TEST_CASE(gru_reverse)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -1360,7 +1360,7 @@ TEST_CASE(gru_reverse)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -1593,7 +1593,7 @@ TEST_CASE(gru_bidirectional)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -1628,7 +1628,7 @@ TEST_CASE(gru_bidirectional)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -1919,7 +1919,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -2030,7 +2030,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
......@@ -2277,7 +2277,7 @@ TEST_CASE(lstm_forward)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({}).back();
......@@ -2325,7 +2325,7 @@ TEST_CASE(lstm_forward)
ih,
ic,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({}).back();
......@@ -2508,6 +2508,63 @@ TEST_CASE(lstm_forward_more)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// forward, last_output as program output, sequence length shorter
// than max_seq_len
{
migraphx::program p;
auto seq_orig = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = p.add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = p.add_literal(seq_len_s, len_data);
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
sql,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.0847427,
0.0874114,
0.304256,
-0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// seq_len = 1
{
seq_len = 1;
......@@ -2539,7 +2596,7 @@ TEST_CASE(lstm_forward_more)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
......@@ -2685,6 +2742,105 @@ TEST_CASE(lstm_reverse)
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, sequence lengths are the same, but less than max_seq_lens
{
migraphx::program p;
auto seq_orig = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = p.add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = p.add_literal(seq_len_s, len_data);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// variable sequence lengths
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{3, 2, 1};
auto sql = p.add_literal(seq_len_s, len_data);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449,
0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761,
0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0,
0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
{
migraphx::program p;
......@@ -2701,7 +2857,7 @@ TEST_CASE(lstm_reverse)
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
......@@ -2733,7 +2889,7 @@ TEST_CASE(lstm_reverse)
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
......@@ -2882,7 +3038,7 @@ TEST_CASE(lstm_reverse_actv)
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
......@@ -3091,7 +3247,7 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
......@@ -3129,7 +3285,7 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
......@@ -3212,6 +3368,207 @@ TEST_CASE(lstm_bidirectional)
}
}
// Bidirectional LSTM with per-batch variable sequence lengths
// (sl_data = {1, 2, 3}, max seq_len = 4). Verifies three program
// outputs against precomputed gold values: the full hidden-state
// sequence (zero-padded past each batch's length), the last hidden
// state (rnn_last_hs_output), and the last cell state
// (rnn_last_cell_output).
TEST_CASE(lstm_bidirectional_var_seq_lens)
{
    std::size_t batch_size  = 3;
    std::size_t seq_len     = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size  = 3;
    std::size_t num_dirct   = 2;
    // Input weights W: {num_dirct, 4 * hidden_size, input_size}
    std::vector<float> w_data{
        0.1236,  -0.3942, 0.4149,  0.0795,  0.4934,  -0.2858, 0.2602,  -0.3098, 0.0567,  0.3344,
        0.3607,  -0.0551, 0.4952,  0.3799,  0.0630,  -0.3532, 0.0023,  -0.0592, 0.4267,  0.2382,
        -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837,  0.0827,  0.0123,  -0.1203, -0.0279,
        -0.0049, 0.4721,  -0.3564, -0.1286, 0.4090,  -0.0504, 0.0575,  -0.2138, 0.1071,  0.1976,
        -0.0758, 0.0139,  -0.0761, 0.3991,  -0.2965, -0.4845, -0.1496, 0.3285,  -0.2763, -0.4715,
        -0.3010, -0.2306, -0.2283, -0.2656, 0.2035,  0.3570,  -0.1499, 0.4390,  -0.1843, 0.2351,
        0.3357,  0.1217,  0.1401,  0.3300,  -0.0429, 0.3266,  0.4834,  -0.3914, -0.1480, 0.3734,
        -0.0372, -0.1746, 0.0550,  0.4177,  -0.1332, 0.4391,  -0.3287, -0.4401, 0.1486,  0.1346,
        0.1048,  -0.4361, 0.0886,  -0.3840, -0.2730, -0.1710, 0.3274,  0.0169,  -0.4462, 0.0729,
        0.3983,  -0.0669, 0.0756,  0.4150,  -0.4684, -0.2522};
    // Recurrence weights R: {num_dirct, 4 * hidden_size, hidden_size}
    std::vector<float> r_data{
        0.1237,  0.1229,  -0.0766, -0.1144, -0.1186, 0.2922,  0.2478,  0.3159,  -0.0522, 0.1685,
        -0.4621, 0.1728,  0.0670,  -0.2458, -0.3835, -0.4589, -0.3109, 0.4908,  -0.0133, -0.1858,
        -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406,  0.1033,  -0.0265,
        -0.3902, 0.0755,  0.3733,  0.4383,  -0.3140, 0.2537,  -0.1818, -0.4127, 0.3506,  0.2562,
        0.2926,  0.1620,  -0.4849, -0.4861, 0.4426,  0.2106,  -0.0005, 0.4418,  -0.2926, -0.3100,
        0.1500,  -0.0362, -0.3801, -0.0065, -0.0631, 0.1277,  0.2315,  0.4087,  -0.3963, -0.4161,
        -0.2169, -0.1344, 0.3468,  -0.2260, -0.4564, -0.4432, 0.1605,  0.4387,  0.0034,  0.4116,
        0.2824,  0.4775,  -0.2729, -0.4707, 0.1363,  0.2218,  0.0559,  0.2828,  0.2093,  0.4687,
        0.3794,  -0.1069, -0.3049, 0.1430,  -0.2506, 0.4644,  0.2755,  -0.3645, -0.3155, 0.1425,
        0.2891,  0.1786,  -0.3274, 0.2365,  0.2522,  -0.4312, -0.0562, -0.2748, 0.0776,  -0.3154,
        0.2851,  -0.3930, -0.1174, 0.4360,  0.2436,  0.0164,  -0.0680, 0.3403,  -0.2857, -0.0459,
        -0.2991, -0.2624, 0.4194,  -0.3291, -0.4659, 0.3300,  0.0454,  0.4981,  -0.4706, -0.4584,
        0.2596,  0.2871,  -0.3509, -0.1910, 0.3987,  -0.1687, -0.0032, -0.1038};
    // Bias: {num_dirct, 8 * hidden_size}
    std::vector<float> bias_data{
        0.0088,  0.1183,  0.1642,  -0.2631, -0.1330, -0.4008, 0.3881,  -0.4407, -0.2760, 0.1274,
        -0.0083, -0.2885, 0.3949,  -0.0182, 0.4445,  0.3477,  0.2266,  0.3423,  -0.0674, -0.4067,
        0.0807,  0.1109,  -0.2036, 0.1782,  -0.2467, -0.0730, -0.4216, 0.0316,  -0.3025, 0.3637,
        -0.3181, -0.4655, -0.0258, 0.0073,  -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,  -0.1823,
        0.1479,  0.1677,  -0.2603, 0.0381,  0.1575,  0.1896,  0.4755,  -0.4794, 0.2167,  -0.4474,
        -0.3139, 0.1018,  0.4470,  -0.4232, 0.3247,  -0.1636, -0.1582, -0.1703, 0.3920,  0.2055,
        -0.4386, 0.4208,  0.0717,  0.3789};
    // Input sequences: {seq_len, batch_size, input_size}
    std::vector<float> input_data{
        -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930,  -0.5221, -0.1331,
        -0.0910, 1.2122, -0.1952, 0.4661,  0.6494,  2.1332,  -1.0972, 0.9816,  0.1122,
        0.3577,  1.3508, -0.5366, 1.7449,  0.5483,  -0.0701, -0.4100, -2.2344, 0.3685,
        0.4583,  2.3794, 1.0372,  -0.8887, 0.7892,  -0.4012, -0.2818, -2.3374, 1.5310};
    // Initial hidden state: {num_dirct, batch_size, hidden_size}
    std::vector<float> ih_data{1.9104,  -1.9004, 0.3337,  0.5741, 0.5671, 0.0458,
                               0.4514,  -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
                               1.5289,  1.0986,  0.6091,  1.6462, 0.8720, 0.5349,
                               -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
    // Initial cell state: {num_dirct, batch_size, hidden_size}
    std::vector<float> ic_data{0.9569,  -0.5981, 1.1312,  1.0945,  1.1055,  -0.1212,
                               -0.9097, 0.7831,  -1.6991, -1.9498, -1.2567, -0.4114,
                               -0.8323, 0.3998,  0.1831,  0.5938,  2.7096,  -0.1790,
                               0.0022,  -0.8040, 0.1578,  0.0567,  0.8069,  -0.5141};
    // Peephole weights: {num_dirct, 3 * hidden_size}
    std::vector<float> pph_data{1.84369764,  0.68413646, -0.44892886, -1.50904413, 0.3860796,
                                -0.52186625, 1.08474445, -1.80867321, 1.32594529,  0.4336262,
                                -0.83699064, 0.49162736, -0.8271,     -0.5683,     0.4562,
                                -1.2545,     1.2729,     -0.4082,     -0.4392,     -0.9406,
                                0.7794,      1.8194,     -0.5811,     0.2166};
    // Per-batch valid sequence lengths (all <= seq_len).
    std::vector<int> sl_data{1, 2, 3};
    float clip = 0.0f;

    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type,
                            {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};

    // concatenation of hidden states as program output
    {
        migraphx::program p;
        auto seq  = p.add_literal(migraphx::literal{in_shape, input_data});
        auto ih   = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic   = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto w    = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r    = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto pph  = p.add_literal(migraphx::literal{pph_shape, pph_data});
        auto sql  = p.add_literal(migraphx::literal{sl_shape, sl_data});
        p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::bidirectional,
                clip,
                0},
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            pph);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        // Gold output: time steps past each batch's length are zero.
        std::vector<float> output_data_gold{
            0.079753,   -0.289854,  0.160043,   0.115056,   0.294074,   -0.0319677, -0.0955337,
            0.104168,   0.022618,   -0.121195,  -0.4065,    -0.252054,  -0.141643,  0.0451978,
            0.140804,   0.0745128,  0.911307,   0.11468,    0.114449,   0.0196755,  -0.262807,
            0.275286,   0.358395,   0.266267,   0,          0,          0,          0,
            0.421857,   0.0459771,  -0.144955,  0.0720673,  -0.0300906, -0.0890598, -0.135266,
            -0.0413375, 0,          0,          0,          0,          0.96657,    0.0755112,
            0.0620917,  -0.264845,  -0.128254,  0.125398,   0.0665142,  -0.163651,  0,
            0,          0,          0,          0,          0,          0,          0,
            0.103489,   0.0142918,  -0.123408,  0.0401075,  0,          0,          0,
            0,          0,          0,          0,          0,          -0.0644683, 0.371512,
            0.212431,   -0.116131,  0,          0,          0,          0,          0,
            0,          0,          0,          0,          0,          0,          0,
            0,          0,          0,          0,          0,          0,          0,
            0,          0,          0,          0,          0};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }

    // last hidden state as program output
    {
        migraphx::program p;
        auto seq  = p.add_literal(migraphx::literal{in_shape, input_data});
        auto ih   = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic   = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto w    = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r    = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto pph  = p.add_literal(migraphx::literal{pph_shape, pph_data});
        auto sql  = p.add_literal(migraphx::literal{sl_shape, sl_data});
        auto hs   = p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::bidirectional,
                clip,
                0},
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            pph);
        // Extract the last valid hidden state per direction and batch.
        p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            0.079753, -0.289854, 0.160043,  0.115056,  0.421857,  0.0459771, -0.144955, 0.0720673,
            0.103489, 0.0142918, -0.123408, 0.0401075, -0.141643, 0.0451978, 0.140804,  0.0745128,
            0.911307, 0.11468,   0.114449,  0.0196755, -0.262807, 0.275286,  0.358395,  0.266267};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }

    // last cell output as program output
    {
        migraphx::program p;
        auto seq  = p.add_literal(migraphx::literal{in_shape, input_data});
        auto ih   = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic   = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto w    = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r    = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto pph  = p.add_literal(migraphx::literal{pph_shape, pph_data});
        auto sql  = p.add_literal(migraphx::literal{sl_shape, sl_data});
        auto hs   = p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::bidirectional,
                clip,
                0},
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            pph);
        // Last cell state; note rnn_last_cell_output also takes the
        // sequence-length literal here.
        p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs, sql);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            0.600582, -0.601197, 0.353558,  0.789097,  0.737121,  0.134902, -0.303595, 0.241948,
            0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523,  0.415242,
            2.08242,  0.442513,  0.187127,  0.0577626, -0.611307, 0.55454,  0.4364,    0.509436};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
}
TEST_CASE(lstm_bidirectional_actv_func)
{
std::size_t batch_size = 3;
......@@ -3339,7 +3696,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
......@@ -3369,7 +3726,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
......@@ -3400,7 +3757,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
......@@ -3432,7 +3789,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
......
......@@ -2580,7 +2580,7 @@ struct test_rnn_forward : verify_program<test_rnn_forward>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2622,7 +2622,7 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2663,7 +2663,7 @@ struct test_rnn_two_outputs : verify_program<test_rnn_two_outputs>
bias,
und,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
......@@ -2850,7 +2850,7 @@ struct test_rnn_5args : verify_program<test_rnn_5args>
r,
bias,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2892,7 +2892,7 @@ struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2933,7 +2933,7 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2968,7 +2968,7 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3012,7 +3012,7 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3215,7 +3215,7 @@ struct test_gru_two_outputs : verify_program<test_gru_two_outputs>
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
......@@ -3301,7 +3301,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3377,7 +3377,7 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3616,6 +3616,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape l_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
......@@ -3624,9 +3625,9 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto len = p.add_literal(migraphx::literal(l_shape, {1, 2}));
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::lstm{
......@@ -3638,11 +3639,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
w,
r,
bias,
und,
len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output, len);
return p;
}
......@@ -3801,7 +3802,7 @@ struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
......@@ -3837,8 +3838,8 @@ struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_cell = p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto last_cell = p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
p.add_return({hs, last_hs, last_cell});
return p;
......@@ -3995,7 +3996,7 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -4064,7 +4065,7 @@ struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3a
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
return p;
}
......@@ -4115,7 +4116,7 @@ struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -4140,13 +4141,15 @@ struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data{3, 2};
auto sql = p.add_literal(migraphx::literal{migraphx::literal{sl_shape, sl_data}});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
......@@ -4156,7 +4159,7 @@ struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
w,
r,
bias,
und,
sql,
ih);
return p;
......@@ -4315,13 +4318,15 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data(batch_size, 2);
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
......@@ -4331,7 +4336,7 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
w,
r,
bias,
und,
sql,
ih);
return p;
......
......@@ -61,7 +61,7 @@ TEST_CASE(rnn_test_bidirectional)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
......@@ -103,7 +103,7 @@ TEST_CASE(rnn_test_one_direction)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
......@@ -129,7 +129,7 @@ TEST_CASE(rnn_test_one_direction)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
......@@ -153,7 +153,7 @@ TEST_CASE(rnn_test_one_direction)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
......@@ -181,7 +181,7 @@ TEST_CASE(rnn_test_one_direction)
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
......@@ -225,7 +225,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
......@@ -259,7 +259,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
......@@ -296,7 +296,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
......@@ -335,7 +335,7 @@ TEST_CASE(gru_test_args)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
......@@ -367,7 +367,7 @@ TEST_CASE(gru_test_args)
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
......@@ -404,7 +404,7 @@ TEST_CASE(gru_test_args)
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
......@@ -450,7 +450,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
......@@ -487,7 +487,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
......@@ -524,7 +524,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
......@@ -561,7 +561,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
......@@ -595,7 +595,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
......@@ -629,7 +629,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
......@@ -678,7 +678,7 @@ TEST_CASE(lstm_forward)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_forward.onnx");
EXPECT(p == prog);
......@@ -707,7 +707,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f3args.onnx");
EXPECT(p == prog);
......@@ -764,7 +764,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_last.onnx");
EXPECT(p == prog);
......@@ -793,7 +793,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_cell.onnx");
EXPECT(p == prog);
......@@ -823,7 +823,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f4args.onnx");
EXPECT(p == prog);
......@@ -854,8 +854,8 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
EXPECT(p == prog);
......@@ -887,8 +887,8 @@ TEST_CASE(lstm_forward)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
EXPECT(p == prog);
......@@ -921,8 +921,8 @@ TEST_CASE(lstm_forward)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
EXPECT(p == prog);
......@@ -950,6 +950,7 @@ TEST_CASE(lstm_forward_actv_func)
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
// auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
......@@ -967,7 +968,7 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f0af.onnx");
EXPECT(p == prog);
......@@ -997,7 +998,7 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f1af.onnx");
EXPECT(p == prog);
......@@ -1028,8 +1029,8 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f2af.onnx");
EXPECT(p == prog);
......@@ -1078,7 +1079,7 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_reverse.onnx");
EXPECT(p == prog);
......@@ -1109,8 +1110,8 @@ TEST_CASE(lstm_reverse)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_r5args.onnx");
EXPECT(p == prog);
......@@ -1139,7 +1140,7 @@ TEST_CASE(lstm_reverse)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_r0af.onnx");
EXPECT(p == prog);
......@@ -1192,7 +1193,7 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi.onnx");
EXPECT(p == prog);
......@@ -1225,7 +1226,7 @@ TEST_CASE(lstm_bidirectional)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi3args.onnx");
EXPECT(p == prog);
......@@ -1259,7 +1260,7 @@ TEST_CASE(lstm_bidirectional)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi4args.onnx");
EXPECT(p == prog);
......@@ -1294,7 +1295,7 @@ TEST_CASE(lstm_bidirectional)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi5args.onnx");
EXPECT(p == prog);
......@@ -1330,7 +1331,7 @@ TEST_CASE(lstm_bidirectional)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi6args.onnx");
EXPECT(p == prog);
......@@ -1367,7 +1368,7 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi7args.onnx");
EXPECT(p == prog);
......@@ -1417,7 +1418,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi0af.onnx");
EXPECT(p == prog);
......@@ -1451,7 +1452,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi1af.onnx");
EXPECT(p == prog);
......@@ -1486,7 +1487,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi2af.onnx");
EXPECT(p == prog);
......@@ -1522,7 +1523,7 @@ TEST_CASE(lstm_bi_actv_funcs)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi4af.onnx");
EXPECT(p == prog);
......@@ -1559,7 +1560,7 @@ TEST_CASE(lstm_bi_actv_funcs)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi5af.onnx");
EXPECT(p == prog);
......@@ -1592,7 +1593,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi6af.onnx");
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment