Unverified Commit d7b8164c authored by Shucai Xiao, committed by GitHub

Add support for variable seq lens for the RNN and GRU operators (#535)
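
This change lets the sequence_lens input (the fifth argument of the rnn and gru
operators) carry per-batch sequence lengths instead of requiring an undefined
placeholder. A minimal usage sketch (illustrative; p is a migraphx::program and
seq, w, r, bias, ih, hidden_size, batch_size, and clip are set up as in the unit
tests added below):

    migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
    std::vector<int32_t> len_data{2, 1}; // per-batch lengths, may differ
    auto sql    = p.add_literal(seq_len_s, len_data); // replaces op::undefined{}
    auto out_hs = p.add_instruction(
        migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
        seq, w, r, bias, sql, ih);
    auto last   = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
    p.add_return({out_hs, last});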



* code backup

* clang format

* fix compiling errors

* clang format

* rename a few files

* rename a few files

* fix variable bugs

* clang format

* add an operator to shift input sequences

* clang format

* fixed a bug

* clang format

* fixed a bug

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* refine code related lstm operator optimization

* clang format

* fix various bugs

* clang format

* fixed a bug in rewrite_lstm

* clang format

* fixed another bug

* refine two operator names

* clang format

* refine file names

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* fixed review comments

* clang format

* add unit tests

* clang format

* add unit tests

* clang format

* refine unit tests for better coverage

* clang format

* fixed a bug

* fix cppcheck error

* fix review comments

* clang format

* rename two operators according to review comments

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix review comments

* fix a cppcheck error

* clang format

* fix review comments

* clang format

* add an operator to simplify code

* clang format

* clang format

* fixed a bug and add unit tests

* clang format

* add more unit tests

* clang format

* add more unit tests

* clang format

* add more unit tests

* clang format

* refine a unit test

* clang format

* refine a unit test

* add more unit tests and refine some existing tests for the rnn operator improvements

* clang format

* additional changes to simplify code further

* clang format

* refine a test case to fix a cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests

* clang format
Co-authored-by: Shucai Xiao <scxiao@prj47-rack-99.local.lan>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent c41c3501
......@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm>
#include <vector>
#include <initializer_list>
#include <migraphx/rank.hpp>
#include <migraphx/config.hpp>
......@@ -129,6 +130,17 @@ void replace(Range&& r, const T& old, const T& new_x)
std::replace(r.begin(), r.end(), old, new_x);
}
template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
template <class Range, class Predicate>
std::vector<range_value<Range>> find_all(Range&& r, Predicate p)
{
std::vector<range_value<Range>> result;
std::copy_if(r.begin(), r.end(), std::back_inserter(result), p);
return result;
}
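// Example usage (sketch, mirroring how rewrite_rnn applies this helper below):
// collect every output instruction with a given operator name, then process each.
//   auto hs_outputs = find_all(ins->outputs(),
//                              [](auto i) { return i->name() == "rnn_last_hs_output"; });
//   for(auto& hs_out : hs_outputs)
//       prog.replace_instruction(hs_out, last_hs_output);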
template <class Iterator>
struct iterator_range
{
......
......@@ -27,11 +27,7 @@ struct rewrite_rnn
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
std::vector<instruction_ref> inputs,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
......@@ -75,6 +71,11 @@ struct rewrite_rnn
std::size_t
get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(program& prog,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,8 +15,12 @@
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
......@@ -61,9 +65,19 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn_direction dicrt = rnn_op.direction;
op::rnn_direction dirct = rnn_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
......@@ -97,24 +111,25 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = vanilla_rnn_cell(true,
prog,
ins,
args[0],
w_forward,
r_forward,
bias_forward,
ih_forward,
actv_funcs.at(0));
auto ret_reverse = vanilla_rnn_cell(false,
prog,
ins,
args[0],
w_reverse,
r_reverse,
bias_reverse,
ih_reverse,
actv_funcs.at(1));
auto ret_forward =
vanilla_rnn_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
actv_funcs.at(0));
if(variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
}
auto ret_reverse =
vanilla_rnn_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
actv_funcs.at(1));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
......@@ -138,7 +153,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
bool is_forward = (dirct == op::rnn_direction::forward);
// input weight matrix
auto w = args[1];
......@@ -163,8 +178,14 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret =
vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
if(!is_forward and variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
}
auto ret = vanilla_rnn_cell(
is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
......@@ -182,33 +203,26 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
}
// search its output to find if there are rnn_last_hs_output operator
// while loop to handle case of multiple rnn_last_hs_output operators
auto last_hs_output_it = ins->outputs().begin();
while(last_hs_output_it != ins->outputs().end())
{
last_hs_output_it = std::find_if(last_hs_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_hs_output";
});
if(last_hs_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_hs_output_it, last_output);
last_hs_output_it++;
}
}
// in case all sequences are of the same length but shorter than the
// max sequence length, we need to pad 0's at the end of the output hidden states
ins = pad_hidden_states(prog, args[0], seq_lens, ins);
replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
std::vector<instruction_ref> inputs,
operation& actv_func) const
{
assert(inputs.size() == 6);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto seq_lens = inputs.at(4);
auto ih = inputs.at(5);
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
......@@ -236,12 +250,12 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
instruction_ref hidden_out = prog.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++)
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
......@@ -340,9 +354,19 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
op::rnn_direction dicrt = gru_op.direction;
op::rnn_direction dirct = gru_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
......@@ -375,21 +399,29 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = gru_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
auto ret_forward =
gru_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
if(variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
}
auto ret_reverse = gru_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto ret_reverse =
gru_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
......@@ -412,7 +444,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
bool is_forward = (dirct == op::rnn_direction::forward);
// weight matrix
auto w = args[1];
auto r = args[2];
......@@ -435,10 +467,16 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
}
auto ret = gru_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih},
{args[0], w, r, bias, seq_lens, ih},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
......@@ -457,22 +495,10 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
}
// replace the corresponding rnn_last_hs_output instruction
// with the last_output, if rnn_last_hs_output exists
// while loop to handle case of multiple rnn_last_hs_output operators
auto last_hs_output_it = ins->outputs().begin();
while(last_hs_output_it != ins->outputs().end())
{
last_hs_output_it = std::find_if(last_hs_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_hs_output";
});
if(last_hs_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_hs_output_it, last_output);
last_hs_output_it++;
}
}
// in case all sequences are of the same length but shorter than the
// max sequence length, we need to pad 0's at the end of the output hidden states
ins = pad_hidden_states(prog, args[0], seq_lens, ins);
replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
}
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
......@@ -483,23 +509,23 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
const operation& actv_func1,
const operation& actv_func2) const
{
assert(inputs.size() == 5);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
assert(inputs.size() == 6);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto seq_lens = inputs.at(4);
auto ih = inputs.at(5);
instruction_ref hidden_states = prog.end();
instruction_ref last_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<float> data(s.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{s, data});
migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<float> data(ss.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{ss, data});
// w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0};
......@@ -535,6 +561,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
......@@ -888,7 +915,14 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
}
// in case all sequences are of the same length but shorter than the
// max sequence length, we need to pad 0's at the end of the output hidden states
hidden_state = pad_hidden_states(prog, args[0], seq_lens, hidden_state);
// replace last hidden states with corresponding instructions
ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
// replace last cell outputs with corresponding instructions
replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}
......@@ -917,11 +951,9 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref last_hs_output{};
instruction_ref last_cell_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long max_seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
auto bs = ih->get_shape().lens()[1];
migraphx::shape r_shape = r->get_shape();
long hs = static_cast<long>(r_shape.lens()[2]);
auto bs = ih->get_shape().lens()[1];
std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose
......@@ -1049,21 +1081,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
}
}
// if all sequences are of the same length but shorter than
// max_seq_len, we need to append zeros to the hs outputs.
// In this case, the cell_output is not used at all, so
// there is no need to extend it to the variable length
if(seq_len < max_seq_len)
{
auto s = last_hs_output->get_shape();
auto pad_lens = s.lens();
pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len);
shape pad_s{s.type(), pad_lens};
std::vector<float> data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, data.begin(), data.end());
hidden_states = prog.insert_instruction(ins, op::concat{0}, hidden_states, pl);
}
return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
}
......@@ -1200,42 +1217,26 @@ instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
result_ins = prog.insert_instruction(
std::next(ins), op::rnn_var_sl_shift_output{"hidden_states", dirct}, ins, seq_lens);
prog.replace_instruction(ins, result_ins);
auto hs_outputs = find_all(result_ins->outputs(),
[&](auto i) { return i->name() == "rnn_last_hs_output"; });
// correct the direction used for the operator
auto last_hs_output_it = result_ins->outputs().begin();
while(last_hs_output_it != result_ins->outputs().end())
for(auto& hs_out : hs_outputs)
{
last_hs_output_it =
std::find_if(last_hs_output_it, result_ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_hs_output";
});
if(last_hs_output_it != result_ins->outputs().end())
{
auto inputs = (*last_hs_output_it)->inputs();
prog.replace_instruction(*last_hs_output_it,
op::rnn_var_sl_last_output{dirct},
inputs.front(),
seq_lens);
last_hs_output_it++;
}
auto inputs = hs_out->inputs();
prog.replace_instruction(
hs_out, op::rnn_var_sl_last_output{dirct}, inputs.front(), seq_lens);
}
}
else
{
auto last_hs_output_it = ins->outputs().begin();
while(last_hs_output_it != ins->outputs().end())
{
last_hs_output_it = std::find_if(last_hs_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_hs_output";
});
auto hs_outputs =
find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_hs_output"; });
if(last_hs_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_hs_output_it, last_hs_output);
last_hs_output_it++;
}
for(auto& hs_out : hs_outputs)
{
prog.replace_instruction(hs_out, last_hs_output);
}
result_ins = ins;
}
......@@ -1250,14 +1251,12 @@ void rewrite_rnn::replace_last_cell_output(program& prog,
op::rnn_direction dirct) const
{
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
auto ins_outputs =
find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });
if(variable_seq_len)
{
auto last_cell_output_it =
std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_cell_output";
});
if(last_cell_output_it != ins->outputs().end())
if(!ins_outputs.empty())
{
cell_outputs =
prog.insert_instruction(std::next(ins),
......@@ -1266,47 +1265,48 @@ void rewrite_rnn::replace_last_cell_output(program& prog,
seq_lens);
}
last_cell_output_it = ins->outputs().begin();
while(last_cell_output_it != ins->outputs().end())
for(auto co : ins_outputs)
{
last_cell_output_it =
std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_cell_output";
});
if(last_cell_output_it != ins->outputs().end())
{
auto inputs = (*last_cell_output_it)->inputs();
inputs[0] = cell_outputs;
prog.replace_instruction(*last_cell_output_it,
op::rnn_var_sl_last_output{dirct},
inputs.front(),
seq_lens);
last_cell_output_it++;
}
prog.replace_instruction(co, op::rnn_var_sl_last_output{dirct}, cell_outputs, seq_lens);
}
}
// replace the rnn_last_cell_output with the last_cell_output. The loop
// handles the case of multiple rnn_last_cell_output operators
else
{
auto last_cell_output_it = ins->outputs().begin();
while(last_cell_output_it != ins->outputs().end())
for(auto co : ins_outputs)
{
last_cell_output_it =
std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_cell_output";
});
if(last_cell_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_cell_output_it, last_cell_output);
last_cell_output_it++;
}
prog.replace_instruction(co, last_cell_output);
}
}
}
instruction_ref rewrite_rnn::pad_hidden_states(program& prog,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const
{
auto max_seq_len = seq->get_shape().lens()[0];
auto seq_len = get_seq_len(prog, seq, seq_lens);
// if all sequences are of the same length but shorter than
// max_seq_len, we need to pad the hs outputs with zeros
auto hs_padded = hs;
if(seq_len < max_seq_len)
{
auto s = hs->get_shape();
auto pad_lens = s.lens();
pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len);
shape pad_s{s.type(), pad_lens};
std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded = prog.insert_instruction(std::next(hs), op::concat{0}, hs, pl);
prog.replace_instruction(hs, hs_padded);
}
return hs_padded;
}
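// Worked example (illustrative) for the padding above: with max_seq_len = 4 and
// get_seq_len(...) returning 2, the hidden-state output hs has shape
// {2, num_directions, batch_size, hidden_size}. pad_lens[0] becomes 4 - 2 = 2, so a
// zero literal of shape {2, num_directions, batch_size, hidden_size} is concatenated
// along axis 0, restoring the full {4, num_directions, batch_size, hidden_size}
// shape expected for the operator's hidden-state output.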
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
......
......@@ -73,20 +73,28 @@ TEST_CASE(rnn_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto res_hs = outputs.front();
auto res_lho = outputs.back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> lho_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
......@@ -104,18 +112,34 @@ TEST_CASE(rnn_forward)
0.44193283,
-0.16477929,
-0.11893477};
std::vector<float> lho_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(lho_data, lho_data_gold));
}
// rnn last output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto seq_orig = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = p.add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = p.add_literal(seq_len_s, len_data);
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
......@@ -123,15 +147,27 @@ TEST_CASE(rnn_forward)
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto last_out = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_return({out_hs, last_out});
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_last_output = outputs.back();
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data;
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
arg_last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209,
-0.18736027, 0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283,
-0.16477929, -0.11893477, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0};
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
......@@ -142,9 +178,9 @@ TEST_CASE(rnn_forward)
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// variable input sequence lengths
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
......@@ -152,7 +188,9 @@ TEST_CASE(rnn_forward)
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{2, 1};
auto sql = p.add_literal(seq_len_s, len_data);
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
......@@ -160,24 +198,41 @@ TEST_CASE(rnn_forward)
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto last_out = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_return({out_hs, last_out});
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
auto outputs = p.eval({});
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
auto arg_hs = outputs.front();
auto arg_last_output = outputs.back();
std::vector<float> last_output_data;
std::vector<float> hs_data;
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
arg_last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{0.377808,
0.610551,
0.551685,
-0.588848,
-0.371446,
0.317082,
0.131042,
-0.18736,
0.034457,
0.191679,
-0.394683,
-0.308897,
0,
0,
0,
0};
std::vector<float> last_output_data_gold{
0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 3 args
......@@ -370,6 +425,120 @@ TEST_CASE(rnn_reverse)
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs; equal sequence lengths shorter than the padded max
{
migraphx::program p;
auto seq_orig = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = p.add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = p.add_literal(seq_len_s, len_data);
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_return({out_hs, lho});
p.compile(migraphx::cpu::target{});
auto outputs = p.eval({});
std::vector<float> hs_data;
std::vector<float> last_output_data;
auto arg_hs = outputs.front();
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
auto arg_lho = outputs.back();
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635,
0.14803654, -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0};
std::vector<float> last_output_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs; variable sequence lengths
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{2, 1};
auto sql = p.add_literal(seq_len_s, len_data);
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_return({out_hs, lho});
p.compile(migraphx::cpu::target{});
auto outputs = p.eval({});
std::vector<float> hs_data;
std::vector<float> last_output_data;
auto arg_hs = outputs.front();
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
auto arg_lho = outputs.back();
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{-0.293853,
0.167968,
0.51076,
0.402587,
-0.0070999,
0.46251,
-0.206392,
0.374889,
-0.0070999,
0.46251,
-0.206392,
0.374889,
0,
0,
0,
0};
std::vector<float> last_output_data_gold{
-0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_bidirectional)
......@@ -417,17 +586,17 @@ TEST_CASE(rnn_bidirectional)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
float clip = 0.0f;
// concatenation of hidden state for program output
// concatenation of hidden state and last hs output for program outputs
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
......@@ -435,10 +604,18 @@ TEST_CASE(rnn_bidirectional)
bias,
und,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_return({out_hs, lho});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
arg_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> last_output_data;
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236,
......@@ -447,7 +624,26 @@ TEST_CASE(rnn_bidirectional)
-0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942};
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// last rnn output for program output
......@@ -458,7 +654,9 @@ TEST_CASE(rnn_bidirectional)
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{1, 2};
auto sql = p.add_literal(seq_len_s, len_data);
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size,
......@@ -469,33 +667,44 @@ TEST_CASE(rnn_bidirectional)
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_return({out_hs, lho});
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({}).back();
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.back();
std::vector<float> hs_data;
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
std::vector<float> hs_data_gold{
0.377808, 0.610551, 0.551685, -0.588848, -0.371446, 0.317082, 0.131042, -0.18736,
-0.169158, 0.193817, 0.206679, 0.586097, -0.138188, 0.441244, 0.143656, 0.148037,
0, 0, 0, 0, -0.222764, 0.441933, -0.164779, -0.118935,
0, 0, 0, 0, -0.0070999, 0.46251, -0.206392, 0.374889};
std::vector<float> last_output_data_gold{0.377808,
0.610551,
0.551685,
-0.588848,
-0.222764,
0.441933,
-0.164779,
-0.118935,
-0.169158,
0.193817,
0.206679,
0.586097,
-0.138188,
0.441244,
0.143656,
0.148037};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
......@@ -1262,7 +1471,7 @@ TEST_CASE(gru_reverse)
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// concatenation of hidden states for output
// concatenation of hidden states and last hs output for outputs
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
......@@ -1271,22 +1480,29 @@ TEST_CASE(gru_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto res_hs = outputs.back();
auto res_lho = outputs.front();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> lho_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125,
-0.600874, 0.542386, -0.0856531, 0.55703, 0.54711,
......@@ -1294,20 +1510,34 @@ TEST_CASE(gru_reverse)
-0.187861, 0.213553, -0.0708377, 0.545435, 0.654301,
-0.329512, 0.476095, 0.284044, 0.392077, -0.369226,
-0.3275, -0.027301, 0.143774, 0.655686, 0.782831};
std::vector<float> lho_data_gold{-0.263403,
0.317655,
-0.00634162,
0.200443,
-0.349125,
-0.600874,
0.542386,
-0.0856531,
0.55703,
0.54711};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(lho_data, lho_data_gold));
}
// last output for output
// variable input sequence length
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{1, 2};
auto sql = p.add_literal(seq_len_s, len_data);
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
......@@ -1317,26 +1547,38 @@ TEST_CASE(gru_reverse)
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto res_hs = outputs.back();
auto res_lho = outputs.front();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> lho_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.263403,
0.317655,
-0.00634162,
0.200443,
-0.349125,
-0.600874,
0.542386,
-0.0856531,
0.55703,
0.54711};
std::vector<float> hs_data_gold{
-0.272984, 0.423637, -0.0936878, 0.482307, -0.0218324, -0.630874, 0.401448, 0.0488417,
0.558397, 0.664423, 0, 0, 0, 0, 0, -0.238202,
-0.0752721, 0.0919409, 0.669654, 0.782363, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
std::vector<float> lho_data_gold{-0.272984,
0.423637,
-0.0936878,
0.482307,
-0.0218324,
-0.630874,
0.401448,
0.0488417,
0.558397,
0.664423};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(lho_data, lho_data_gold));
}
// last output for output, linear_before_reset = 0
......@@ -1532,7 +1774,7 @@ TEST_CASE(gru_bidirectional)
float clip = 0.0f;
// concatenation of hidden states for output
// concatenation of hidden states and last hs output for outputs
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
......@@ -1541,22 +1783,28 @@ TEST_CASE(gru_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto hs_concat = outputs.front();
auto res_lho = outputs.back();
std::vector<float> hs_data;
std::vector<float> lho_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273,
......@@ -1568,19 +1816,91 @@ TEST_CASE(gru_bidirectional)
0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849,
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775};
std::vector<float> lho_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533,
-0.311839, -0.12802, -0.16643, -0.393849, 0.648851,
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(lho_data, lho_data_gold));
}
// last output for output
// same input sequence length, but shorter than max sequence length
{
migraphx::program p;
auto seq_orig = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = p.add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = p.add_literal(seq_len_s, len_data);
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.add_return({concat_hs, lho});
p.compile(migraphx::cpu::target{});
auto outputs = p.eval({});
auto hs_concat = outputs.front();
auto res_lho = outputs.back();
std::vector<float> hs_data;
std::vector<float> lho_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0352244, 0.0146756, 0.00570924, 0.152446, 0.208683, 0.214342, -0.0454273,
-0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531,
-0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435,
0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906,
-0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945,
0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681,
0.241526, 0.321104, 0.00693531, -0.311839, -0.12802, -0.16643, -0.393849,
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0};
std::vector<float> lho_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693531,
-0.311839, -0.12802, -0.16643, -0.393849, 0.648851,
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(lho_data, lho_data_gold));
}
// variable input sequence lengths
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{1, 2};
auto sql = p.add_literal(seq_len_s, len_data);
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
......@@ -1591,20 +1911,36 @@ TEST_CASE(gru_bidirectional)
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, concat_hs);
p.add_return({concat_hs, lho});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto hs_concat = outputs.front();
auto res_lho = outputs.back();
std::vector<float> hs_data;
std::vector<float> lho_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533,
-0.311839, -0.12802, -0.16643, -0.393849, 0.648851,
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
std::vector<float> hs_data_gold{
0.0352244, 0.0146756, 0.00570924, 0.152446, 0.208683, 0.214342, -0.0454273,
-0.135177, -0.0800739, 0.903659, -0.0271321, 0.624762, -0.117084, 0.509115,
-0.0175078, 0.182457, 0.304506, 0.313825, 0.397697, 0.300873, 0,
0, 0, 0, 0, -0.221983, -0.139656, -0.0938906,
-0.247681, 0.69647, 0, 0, 0, 0, 0,
-0.059911, 0.0552807, 0.306764, 0.794409, 0.194492, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0};
std::vector<float> lho_data_gold{0.0352244, 0.0146756, 0.00570924, 0.152446, 0.208683,
-0.221983, -0.139656, -0.0938906, -0.247681, 0.69647,
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
0.182457, 0.304506, 0.313825, 0.397697, 0.300873};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(lho_data, lho_data_gold));
}
// last output for output, linear_before_reset = 0
......@@ -3447,15 +3783,15 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
// concatenation of hidden states, last hs output, and last cell output as program outputs
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
p.add_instruction(
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -3470,10 +3806,23 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
ih,
ic,
pph);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto lco = p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
p.add_return({out_hs, lho, lco});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.at(1);
auto arg_lco = outputs.at(2);
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> last_output_data;
std::vector<float> last_cell_data;
arg_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
arg_lho.visit([&](auto output) { last_output_data.assign(output.begin(), output.end()); });
arg_lco.visit([&](auto output) { last_cell_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.141643, 0.0451978,
......@@ -3489,59 +3838,39 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// last hidden state as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
std::vector<float> last_output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.421857, 0.0459771, -0.144955, 0.0720673,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.141643, 0.0451978, 0.140804, 0.0745128,
0.911307, 0.11468, 0.114449, 0.0196755, -0.262807, 0.275286, 0.358395, 0.266267};
std::vector<float> last_cell_data_gold{
0.600582, -0.601197, 0.353558, 0.789097, 0.737121, 0.134902, -0.303595, 0.241948,
0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242,
2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify_range(last_cell_data, last_cell_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto hs = p.add_instruction(
auto seq_orig = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = p.add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = p.add_literal(seq_len_s, len_data);
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -3556,16 +3885,53 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs, sql);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto lco = p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
p.add_return({hs, lho, lco});
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.600582, -0.601197, 0.353558, 0.789097, 0.737121, 0.134902, -0.303595, 0.241948,
0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242,
2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
auto outputs = p.eval({});
auto res_hs = outputs.at(0);
auto res_lho = outputs.at(1);
auto res_lco = outputs.at(2);
std::vector<float> hs_data;
std::vector<float> lho_data;
std::vector<float> lco_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
res_lco.visit([&](auto output) { lco_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459033,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0};
std::vector<float> lho_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
std::vector<float> lco_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify_range(lco_data, lco_data_gold));
}
}
......
......@@ -2611,18 +2611,110 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
};
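// Forward RNN driven by a sequence_lens input: the two batch entries use lengths
// 5 and 7 out of a padded seq_len of 10, and both the full hidden states and the
// last hidden state are returned.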
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
std::vector<int> sl_data{5, 7};
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto ih = p.add_parameter("ih", ih_shape);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
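// Forward RNN where the input is zero-padded with two extra time steps while
// sequence_lens still reports the original length of 10 for every batch entry,
// so the padded tail should not change the results.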
struct test_rnn_sql_2 : verify_program<test_rnn_sql_2>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq_orig = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
migraphx::shape pad_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto seq_pad = p.add_literal(migraphx::literal{pad_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_pad);
std::vector<int> sl_data(batch_size, static_cast<int>(seq_len));
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto ih = p.add_parameter("ih", ih_shape);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
......@@ -2914,6 +3006,7 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -2921,19 +3014,20 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
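// per-batch sequence lengths {5, 9} for the bidirectional RNN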
std::vector<int> sl_data{5, 9};
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
......@@ -2974,7 +3068,7 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
}
};
struct test_gru_forward_last : verify_program<test_gru_forward_last>
struct test_gru_forward : verify_program<test_gru_forward>
{
migraphx::program create_program() const
{
......@@ -3001,7 +3095,7 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
......@@ -3012,17 +3106,18 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
return p;
}
};
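// Forward GRU over a batch of 3 with per-batch sequence lengths {3, 2, 1}
// provided through the sequence_lens input.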
struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t batch_size = 3;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
......@@ -3036,6 +3131,7 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -3043,18 +3139,22 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data{3, 2, 1};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
return p;
}
......@@ -3267,7 +3367,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t batch_size = 3;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
......@@ -3281,6 +3381,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -3288,9 +3389,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
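// per-batch sequence lengths {2, 1, 3} for the reverse-direction GRU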
std::vector<int> sl_data{2, 1, 3};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto output =
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
......@@ -3299,9 +3401,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
......@@ -3339,7 +3442,7 @@ struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
}
};
struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
struct test_gru_bidirct : verify_program<test_gru_bidirct>
{
migraphx::program create_program() const
{
......@@ -3366,7 +3469,7 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
......@@ -3377,17 +3480,18 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
};
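// Bidirectional GRU over a batch of 3 with per-batch sequence lengths {2, 1, 3}
// provided through the sequence_lens input.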
struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t batch_size = 3;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
......@@ -3401,6 +3505,7 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -3408,18 +3513,22 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data{2, 1, 3};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
......