Unverified Commit d7b8164c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Add support variable seq lens for the RNN and GRU operators (#535)



* code backup

* clang format

* fix compiling errors

* clang format

* rename a few files

* rename a few files

* fix variable bugs

* clang format

* add an operator to shift input sequences

* clang format

* fixed a bug

* clang format

* fixed a bug

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* refine code related lstm operator optimization

* clang format

* fix various bugs

* clang format

* fixed a bug in rewrite_lstm

* clang format

* fixed another bug

* refine two operator names

* clang format

* refine file names

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* fixed review comments

* clang format

* add unit tests

* clang format

* add unit tests

* clang format

* refine unit tests for better coverage

* clang format

* fixed a bug

* fix cppcheck error

* fix review comments

* clang format

* rename two operators according to review comments

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix review comments

* fix a cppcheck error

* clang format

* fix review comments

* clang format

* add an operator to simplify code

* clang format

* clang format

* fixed a bug and add unit tests

* clang format

* add more unit tests

* clang format

* add more unit tests

* clang format

* add more unit tests

* clang format

* refine a unit test

* clang format

* refine a unit test

* add more unit tests and refine some existing tests for the rnn operator improvements

* clang format

* additional changes to simplify code further

* clang format

* refine a test case to refine cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests

* clang format
Co-authored-by: default avatarShucai Xiao <scxiao@prj47-rack-99.local.lan>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent c41c3501
......@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm>
#include <vector>
#include <initializer_list>
#include <migraphx/rank.hpp>
#include <migraphx/config.hpp>
......@@ -129,6 +130,17 @@ void replace(Range&& r, const T& old, const T& new_x)
std::replace(r.begin(), r.end(), old, new_x);
}
template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
template <class Range, class Predicate>
std::vector<range_value<Range>> find_all(Range&& r, Predicate p)
{
std::vector<range_value<Range>> result;
std::copy_if(r.begin(), r.end(), std::back_inserter(result), p);
return result;
}
template <class Iterator>
struct iterator_range
{
......
......@@ -27,11 +27,7 @@ struct rewrite_rnn
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
std::vector<instruction_ref> inputs,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
......@@ -75,6 +71,11 @@ struct rewrite_rnn
std::size_t
get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(program& prog,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
This diff is collapsed.
This diff is collapsed.
......@@ -2611,8 +2611,7 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
......@@ -2622,7 +2621,100 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
};
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
std::vector<int> sl_data{5, 7};
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto ih = p.add_parameter("ih", ih_shape);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_rnn_sql_2 : verify_program<test_rnn_sql_2>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq_orig = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
migraphx::shape pad_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto seq_pad = p.add_literal(migraphx::literal{pad_s, pad_data});
auto seq = p.add_instruction(migraphx::op::concat{0}, seq_orig, seq_pad);
std::vector<int> sl_data(batch_size, static_cast<int>(seq_len));
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto ih = p.add_parameter("ih", ih_shape);
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
sql,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
......@@ -2914,6 +3006,7 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -2921,9 +3014,9 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
std::vector<int> sl_data{5, 9};
auto sql = p.add_literal(migraphx::literal{s_shape, sl_data});
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
......@@ -2931,9 +3024,10 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
......@@ -2974,7 +3068,7 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
}
};
struct test_gru_forward_last : verify_program<test_gru_forward_last>
struct test_gru_forward : verify_program<test_gru_forward>
{
migraphx::program create_program() const
{
......@@ -3001,7 +3095,7 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
......@@ -3012,17 +3106,18 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
return p;
}
};
struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
struct test_var_sl_gru_forward : verify_program<test_var_sl_gru_forward>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t batch_size = 3;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
......@@ -3036,6 +3131,7 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -3043,8 +3139,10 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data{3, 2, 1};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
......@@ -3053,8 +3151,10 @@ struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
w,
r,
bias,
und,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({lho, hs});
return p;
}
......@@ -3267,7 +3367,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t batch_size = 3;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
......@@ -3281,6 +3381,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -3288,9 +3389,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data{2, 1, 3};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto output =
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
......@@ -3299,9 +3401,10 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
w,
r,
bias,
und,
sql,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
......@@ -3339,7 +3442,7 @@ struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
}
};
struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
struct test_gru_bidirct : verify_program<test_gru_bidirct>
{
migraphx::program create_program() const
{
......@@ -3366,7 +3469,7 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
......@@ -3377,17 +3480,18 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
};
struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
struct test_var_sl_gru_bidirct : verify_program<test_var_sl_gru_bidirct>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t batch_size = 3;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
......@@ -3401,6 +3505,7 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
......@@ -3408,8 +3513,10 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data{2, 1, 3};
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
auto hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
......@@ -3418,8 +3525,10 @@ struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
w,
r,
bias,
und,
sql,
ih);
auto lho = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, lho});
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment