"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "bd0bd7ef2cdd81ad60022607af723b257ee4f7a4"
Unverified Commit d7b8164c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Add support variable seq lens for the RNN and GRU operators (#535)



* code backup

* clang format

* fix compiling errors

* clang format

* rename a few files

* rename a few files

* fix variable bugs

* clang format

* add an operator to shift input sequences

* clang format

* fixed a bug

* clang format

* fixed a bug

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* refine code related lstm operator optimization

* clang format

* fix various bugs

* clang format

* fixed a bug in rewrite_lstm

* clang format

* fixed another bug

* refine two operator names

* clang format

* refine file names

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* fixed review comments

* clang format

* add unit tests

* clang format

* add unit tests

* clang format

* refine unit tests for better coverage

* clang format

* fixed a bug

* fix cppcheck error

* fix review comments

* clang format

* rename two operators according to review comments

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix review comments

* fix a cppcheck error

* clang format

* fix review comments

* clang format

* add an operator to simplify code

* clang format

* clang format

* fixed a bug and add unit tests

* clang format

* add more unit tests

* clang format

* add more unit tests

* clang format

* add more unit tests

* clang format

* refine a unit test

* clang format

* refine a unit test

* add more unit tests and refine some existing tests for the rnn operator improvements

* clang format

* additional changes to simplify code further

* clang format

* refine a test case to refine cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests

* clang format
Co-authored-by: default avatarShucai Xiao <scxiao@prj47-rack-99.local.lan>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent c41c3501
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm> #include <algorithm>
#include <vector>
#include <initializer_list> #include <initializer_list>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -129,6 +130,17 @@ void replace(Range&& r, const T& old, const T& new_x) ...@@ -129,6 +130,17 @@ void replace(Range&& r, const T& old, const T& new_x)
std::replace(r.begin(), r.end(), old, new_x); std::replace(r.begin(), r.end(), old, new_x);
} }
template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
template <class Range, class Predicate>
std::vector<range_value<Range>> find_all(Range&& r, Predicate p)
{
std::vector<range_value<Range>> result;
std::copy_if(r.begin(), r.end(), std::back_inserter(result), p);
return result;
}
template <class Iterator> template <class Iterator>
struct iterator_range struct iterator_range
{ {
......
...@@ -27,11 +27,7 @@ struct rewrite_rnn ...@@ -27,11 +27,7 @@ struct rewrite_rnn
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward, std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref input, std::vector<instruction_ref> inputs,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const; operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const; std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
...@@ -75,6 +71,11 @@ struct rewrite_rnn ...@@ -75,6 +71,11 @@ struct rewrite_rnn
std::size_t std::size_t
get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const; get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(program& prog,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
This diff is collapsed.
This diff is collapsed.
...@@ -2611,8 +2611,7 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10> ...@@ -2611,8 +2611,7 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
...@@ -2622,7 +2621,100 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10> ...@@ -2622,7 +2621,100 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
};
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
std::vector<int> sl_data{5, 7};
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto ih = p.add_parameter("ih", ih_shape);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_rnn_sql_2 : verify_program<test_rnn_sql_2>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq_orig = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
migraphx::shape pad_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto seq_pad = p.add_literal(migraphx::literal{pad_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_pad);
std::vector<int> sl_data(batch_size, static_cast<int>(seq_len));
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto ih = p.add_parameter("ih", ih_shape);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p; return p;
} }
...@@ -2914,6 +3006,7 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10> ...@@ -2914,6 +3006,7 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = p.add_parameter("seq", in_shape);
...@@ -2921,9 +3014,9 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10> ...@@ -2921,9 +3014,9 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); std::vector<int> sl_data{5, 9};
auto output = auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
p.add_instruction(migraphx::op::rnn{hidden_size, auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
...@@ -2931,9 +3024,10 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10> ...@@ -2931,9 +3024,10 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
w, w,
r, r,
bias, bias,
und, sql,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p; return p;
} }
...@@ -2974,7 +3068,7 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args> ...@@ -2974,7 +3068,7 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
} }
}; };
struct test_gru_forward_last : verify_program<test_gru_forward_last> struct test_gru_forward : verify_program<test_gru_forward>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -3001,7 +3095,7 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last> ...@@ -3001,7 +3095,7 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto hs =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
...@@ -3012,17 +3106,18 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last> ...@@ -3012,17 +3106,18 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
return p; return p;
} }
}; };
struct test_gru_forward_hs : verify_program<test_gru_forward_hs> struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
std::size_t batch_size = 2; std::size_t batch_size = 3;
std::size_t seq_len = 3; std::size_t seq_len = 3;
std::size_t hidden_size = 5; std::size_t hidden_size = 5;
std::size_t input_size = 8; std::size_t input_size = 8;
...@@ -3036,6 +3131,7 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs> ...@@ -3036,6 +3131,7 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = p.add_parameter("seq", in_shape);
...@@ -3043,8 +3139,10 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs> ...@@ -3043,8 +3139,10 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); std::vector<int> sl_data{3, 2, 1};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
...@@ -3053,8 +3151,10 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs> ...@@ -3053,8 +3151,10 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
w, w,
r, r,
bias, bias,
und, sql,
ih); ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
return p; return p;
} }
...@@ -3267,7 +3367,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last> ...@@ -3267,7 +3367,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
std::size_t batch_size = 2; std::size_t batch_size = 3;
std::size_t seq_len = 3; std::size_t seq_len = 3;
std::size_t hidden_size = 5; std::size_t hidden_size = 5;
std::size_t input_size = 8; std::size_t input_size = 8;
...@@ -3281,6 +3381,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last> ...@@ -3281,6 +3381,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = p.add_parameter("seq", in_shape);
...@@ -3288,9 +3389,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last> ...@@ -3288,9 +3389,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); std::vector<int> sl_data{2, 1, 3};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto output = auto hs =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
...@@ -3299,9 +3401,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last> ...@@ -3299,9 +3401,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
w, w,
r, r,
bias, bias,
und, sql,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p; return p;
} }
...@@ -3339,7 +3442,7 @@ struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args> ...@@ -3339,7 +3442,7 @@ struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
} }
}; };
struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last> struct test_gru_bidirct : verify_program<test_gru_bidirct>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -3366,7 +3469,7 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last> ...@@ -3366,7 +3469,7 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto hs =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
...@@ -3377,17 +3480,18 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last> ...@@ -3377,17 +3480,18 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output); auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p; return p;
} }
}; };
struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs> struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
std::size_t batch_size = 2; std::size_t batch_size = 3;
std::size_t seq_len = 3; std::size_t seq_len = 3;
std::size_t hidden_size = 5; std::size_t hidden_size = 5;
std::size_t input_size = 8; std::size_t input_size = 8;
...@@ -3401,6 +3505,7 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs> ...@@ -3401,6 +3505,7 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}}; {num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = p.add_parameter("seq", in_shape);
...@@ -3408,8 +3513,10 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs> ...@@ -3408,8 +3513,10 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); std::vector<int> sl_data{2, 1, 3};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
...@@ -3418,8 +3525,10 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs> ...@@ -3418,8 +3525,10 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
w, w,
r, r,
bias, bias,
und, sql,
ih); ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment