Commit 28f8b1e8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Fixed issues for the RNN operator: ensure the last instruction produced by the rewrite is a concat, handle the sequence-length-1 case, and replace rnn_last_output with the correct instruction.

parent c2b69817
......@@ -395,7 +395,6 @@ struct concat
}
return result;
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct slice
......@@ -698,8 +697,6 @@ struct gather
return result;
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct dot
......
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
......@@ -134,6 +135,17 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
if (ins == std::prev(this->end()))
{
// additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output
if (ins->name() == "rnn_last_output")
{
return replace_instruction(ins, op::identity{}, rep);
}
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
{
......
......@@ -85,16 +85,23 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse,
rnn_op.actv_funcs.at(1));
last_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
auto concat_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
// concat the forward and reverse output
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if (ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] = prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward);
......@@ -125,10 +132,21 @@ void rewrite_rnn::apply(program& prog) const
auto ret = rnn_cell(
is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
last_output = ret[1];
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if (ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
}
......@@ -141,9 +159,13 @@ void rewrite_rnn::apply(program& prog) const
// if rnn operator is executed, the last_output != prog.end()
if(last_output != prog.end())
{
prog.replace_instruction(ins, op::identity{}, last_output);
prog.replace_instruction(ins, last_output);
last_output = prog.end();
}
else
{
MIGRAPHX_THROW("RNN_LAST_OUTPUT: must put after rnn operator");
}
}
}
}
......@@ -181,7 +203,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
}
instruction_ref hidden_out, last_out;
instruction_ref hidden_out = prog.end(), last_out;
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++)
{
......@@ -205,9 +227,15 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
ht = prog.insert_instruction(ins, actv_func, ht);
sih = ht;
// add the dimension of sequence length
last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ht);
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
// output inserted
if (i < seq_len - 1)
{
if(is_forward)
{
hidden_out = (seq_index == 0)
......@@ -221,6 +249,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
}
}
}
std::vector<instruction_ref> out_args;
out_args.push_back(hidden_out);
......
......@@ -54,6 +54,7 @@ struct miopen_apply
program* prog = nullptr;
context ctx{};
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{};
void check_shape(shape x, instruction_ref i)
{
......@@ -64,6 +65,7 @@ struct miopen_apply
void init()
{
this->last = instruction::get_output_alias(std::prev(prog->end()));
add_miopen_simple_op<miopen_relu>("relu", make_relu);
add_miopen_simple_op<miopen_sigmoid>("sigmoid", make_sigmoid);
add_miopen_simple_op<miopen_abs>("abs", make_abs);
......@@ -112,7 +114,7 @@ struct miopen_apply
instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
{
if(ins == --prog->end() and tag.empty())
if(ins == last and tag.empty())
{
return prog->add_parameter("output", s);
}
......
......@@ -1623,7 +1623,7 @@ TEST_CASE(rnn_reverse)
TEST_CASE(rnn_bidirectional)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
......
......@@ -1084,7 +1084,6 @@ struct test_rnn_forward
bias,
ih);
auto last = p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::add{}, last, last);
return p;
}
......@@ -1124,8 +1123,6 @@ struct test_rnn_reverse
r,
bias,
ih);
auto last = p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::add{}, last, last);
return p;
}
......@@ -1166,7 +1163,6 @@ struct test_rnn_bidirectional
bias,
ih);
auto last = p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::add{}, last, last);
return p;
}
......@@ -1232,4 +1228,6 @@ int main()
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
verify_program<test_rnn_forward>();
verify_program<test_rnn_reverse>();
verify_program<test_rnn_bidirectional>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment