Commit 28f8b1e8 authored by Shucai Xiao

Fixed issues for the rnn operator: ensure the graph rewritten from an RNN ends in a concat instruction, allow rnn_last_output to be the program's final instruction, and enable the reverse and bidirectional RNN verification tests.

parent c2b69817
@@ -395,7 +395,6 @@ struct concat
         }
         return result;
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
 };

 struct slice
@@ -698,8 +697,6 @@ struct gather
         return result;
     }
-    int output_alias(const std::vector<shape>&) const { return 0; }
 };

 struct dot
...
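These two hunks drop the output_alias overrides from concat and gather, so neither op claims any longer that its output aliases input 0, and each now gets its own output allocation. A toy sketch of the convention (stand-in types, not the MIGraphX classes; the default of -1 meaning "no alias" is an assumption based on how the override is used):

#include <iostream>
#include <vector>

struct shape{}; // stand-in for migraphx::shape

// Toy op that declares its output to alias input 0:
// no separate output buffer is needed for it.
struct aliasing_op
{
    int output_alias(const std::vector<shape>&) const { return 0; }
};

// Toy op using the assumed default: -1 means "no alias",
// so the output requires its own allocation.
struct non_aliasing_op
{
    int output_alias(const std::vector<shape>&) const { return -1; }
};

int main()
{
    std::vector<shape> inputs(2);
    std::cout << aliasing_op{}.output_alias(inputs) << '\n';     // prints 0
    std::cout << non_aliasing_op{}.output_alias(inputs) << '\n'; // prints -1
}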
 #include <migraphx/program.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/operators.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/time.hpp>
@@ -134,6 +135,17 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep)
     assert(has_instruction(ins));
     assert(has_instruction(rep));
     assert(ins != rep);
+    if(ins == std::prev(this->end()))
+    {
+        // additional check to ensure the ins to be replaced is either
+        // the rnn_last_output, gru_last_output, or lstm_last_output
+        if(ins->name() == "rnn_last_output")
+        {
+            return replace_instruction(ins, op::identity{}, rep);
+        }
+    }
     // TODO: Should it be an error if the output is empty?
     if(ins->outputs().empty())
     {
...
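The new guard matters because rep already exists earlier in the instruction stream: redirecting the uses of a final rnn_last_output and deleting it would leave the program without a trailing instruction that computes the output value. Rewriting the final slot as identity(rep) keeps the program output at the end. A toy model of the idea, using std::list in place of the real instruction list (instruction names are illustrative only):

#include <iostream>
#include <list>
#include <string>

int main()
{
    // "rnn" produces the full sequence; "rep" is the value that should become
    // the program output; "rnn_last_output" is the final instruction.
    std::list<std::string> program{"rnn", "rep", "rnn_last_output"};

    auto ins = std::prev(program.end()); // the instruction being replaced
    // Instead of erasing `ins` (leaving "rep" as an interior instruction),
    // rewrite the final slot so it forwards rep's value.
    *ins = "identity(rep)";

    for(const auto& name : program)
        std::cout << name << '\n'; // rnn, rep, identity(rep)
}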
@@ -85,15 +85,22 @@ void rewrite_rnn::apply(program& prog) const
                                        ih_reverse,
                                        rnn_op.actv_funcs.at(1));
-            last_output =
-                prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
-            // add the dimension of num_direction
-            ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
-            ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
-
-            // concat the forward and reverse output
-            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            auto concat_output =
+                prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
+
+            // The following logic ensures that the last instruction rewritten
+            // from the rnn operator is a concat instruction.
+            if(ret_forward[0] == prog.end())
+            {
+                // sequence length is 1
+                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            }
+            else
+            {
+                ret_forward[0] =
+                    prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
+                ret_reverse[0] =
+                    prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            }
         }
         else
         {
@@ -125,10 +132,21 @@ void rewrite_rnn::apply(program& prog) const
             auto ret = rnn_cell(
                 is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
-            last_output = ret[1];
-            // add the dimension of num_direction
-            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
+            last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
+
+            // The following logic ensures that the last instruction is a
+            // concat instruction.
+            if(ret[0] == prog.end())
+            {
+                // sequence length is 1
+                prog.replace_instruction(ins, op::concat{0}, ret[1]);
+            }
+            else
+            {
+                auto concat_arg0 = is_forward ? ret[0] : ret[1];
+                auto concat_arg1 = is_forward ? ret[1] : ret[0];
+                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            }
         }
     }
@@ -141,9 +159,13 @@ void rewrite_rnn::apply(program& prog) const
             // if an rnn operator has been rewritten, last_output != prog.end()
             if(last_output != prog.end())
             {
-                prog.replace_instruction(ins, op::identity{}, last_output);
+                prog.replace_instruction(ins, last_output);
                 last_output = prog.end();
             }
+            else
+            {
+                MIGRAPHX_THROW("RNN_LAST_OUTPUT: must be placed after an rnn operator");
+            }
         }
     }
 }
@@ -181,7 +203,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
         bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
     }
-    instruction_ref hidden_out, last_out;
+    instruction_ref hidden_out = prog.end(), last_out;
     std::size_t seq_len = input->get_shape().lens()[0];
     for(std::size_t i = 0; i < seq_len; i++)
     {
@@ -205,20 +227,27 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
         ht = prog.insert_instruction(ins, actv_func, ht);
         sih = ht;
-        // add the dimension of sequence length
-        last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ht);
-
-        if(is_forward)
-        {
-            hidden_out = (seq_index == 0)
-                             ? last_out
-                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
-        }
-        else
-        {
-            hidden_out = (seq_index == seq_len - 1)
-                             ? last_out
-                             : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
-        }
+        // add the dimensions of sequence length (axis 0) and num_directions (axis 1)
+        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
+
+        // the concat for the final last_out is performed in apply(), so that
+        // the last instruction rewritten from the rnn operator is a concat
+        // and the output allocation is inserted for it
+        if(i < seq_len - 1)
+        {
+            if(is_forward)
+            {
+                hidden_out =
+                    (seq_index == 0)
+                        ? last_out
+                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
+            }
+            else
+            {
+                hidden_out =
+                    (seq_index == seq_len - 1)
+                        ? last_out
+                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
+            }
+        }
     }
...
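To see what the squeeze/unsqueeze/concat shuffling above does to shapes, here is a shape-only walkthrough, assuming the ONNX RNN layout these tensors follow (Y is [seq_len, num_directions, batch, hidden] and Y_h is [num_directions, batch, hidden]); the helpers are illustrative stand-ins, not MIGraphX calls:

#include <cstddef>
#include <iostream>
#include <vector>

using dims = std::vector<std::size_t>;

// unsqueeze{{0, 1}}: prepend the seq_len and num_directions axes
dims unsqueeze01(dims d)
{
    d.insert(d.begin(), {1, 1});
    return d;
}

// concat along `axis`: dims match except the concat axis, which adds up
dims concat(dims a, const dims& b, std::size_t axis)
{
    a[axis] += b[axis];
    return a;
}

// squeeze{{0}}: drop the leading axis (which must be 1)
dims squeeze0(dims d)
{
    d.erase(d.begin());
    return d;
}

void print(const char* name, const dims& d)
{
    std::cout << name << ": [ ";
    for(auto x : d)
        std::cout << x << ' ';
    std::cout << "]\n";
}

int main()
{
    std::size_t batch = 2, hidden = 4, seq_len = 3;

    dims ht{batch, hidden};      // one rnn_cell step
    dims step = unsqueeze01(ht); // [1, 1, batch, hidden]

    // rnn_cell: stack the time steps along axis 0 (one direction)
    dims fwd = step;
    for(std::size_t i = 1; i < seq_len; ++i)
        fwd = concat(fwd, step, 0); // [seq_len, 1, batch, hidden]
    dims rev = fwd;

    // apply(): join the two directions along axis 1
    print("Y", concat(fwd, rev, 1)); // [seq_len, 2, batch, hidden]

    // last_output: concat the two final steps on axis 1, then squeeze axis 0
    print("Y_h", squeeze0(concat(step, step, 1))); // [2, batch, hidden]
}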
@@ -54,6 +54,7 @@ struct miopen_apply
     program* prog = nullptr;
     context ctx{};
     std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
+    instruction_ref last{};

     void check_shape(shape x, instruction_ref i)
     {
@@ -64,6 +65,7 @@ struct miopen_apply
     void init()
     {
+        this->last = instruction::get_output_alias(std::prev(prog->end()));
         add_miopen_simple_op<miopen_relu>("relu", make_relu);
         add_miopen_simple_op<miopen_sigmoid>("sigmoid", make_sigmoid);
         add_miopen_simple_op<miopen_abs>("abs", make_abs);
@@ -112,7 +114,7 @@ struct miopen_apply
     instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
     {
-        if(ins == --prog->end() and tag.empty())
+        if(ins == last and tag.empty())
         {
             return prog->add_parameter("output", s);
         }
...
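Caching `last` matters because, after the rnn rewrite, the program can end in an identity whose value comes from an earlier instruction; the allocation that must become the "output" parameter belongs to the aliased instruction, not to the identity itself. A toy sketch of the alias chase (assuming get_output_alias follows the alias chain to its root, which is how the call above is used):

#include <iostream>

struct instr
{
    const char* name;
    instr* alias; // instruction whose buffer this op's output reuses, or nullptr
};

// Follow the alias chain to the instruction that really owns the output buffer.
instr* get_output_alias(instr* i)
{
    while(i->alias != nullptr)
        i = i->alias;
    return i;
}

int main()
{
    instr concat_ins{"concat", nullptr};
    instr identity_ins{"identity", &concat_ins}; // identity aliases its input

    // The final program instruction is the identity, but the allocation that
    // should become the "output" parameter belongs to the concat.
    std::cout << get_output_alias(&identity_ins)->name << '\n'; // prints: concat
}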
@@ -1623,7 +1623,7 @@ TEST_CASE(rnn_reverse)
 TEST_CASE(rnn_bidirectional)
 {
     std::size_t batch_size = 2;
-    std::size_t seq_len = 2;
+    std::size_t seq_len = 1;
     std::size_t hidden_size = 4;
     std::size_t input_size = 3;
     std::size_t num_dirct = 2;
...
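Note: reducing seq_len to 1 presumably exercises the new single-step path in rewrite_rnn (the ret_forward[0] == prog.end() branch), which is taken only when the sequence has exactly one step and rnn_cell therefore never assigns hidden_out.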
@@ -1084,7 +1084,6 @@ struct test_rnn_forward
                                      bias,
                                      ih);
         auto last = p.add_instruction(migraphx::op::rnn_last_output{}, output);
-        p.add_instruction(migraphx::op::add{}, last, last);
         return p;
     }
@@ -1124,8 +1123,6 @@ struct test_rnn_reverse
                                      r,
                                      bias,
                                      ih);
-        auto last = p.add_instruction(migraphx::op::rnn_last_output{}, output);
-        p.add_instruction(migraphx::op::add{}, last, last);
         return p;
     }
@@ -1166,7 +1163,6 @@ struct test_rnn_bidirectional
                                      bias,
                                      ih);
         auto last = p.add_instruction(migraphx::op::rnn_last_output{}, output);
-        p.add_instruction(migraphx::op::add{}, last, last);
         return p;
     }
@@ -1232,4 +1228,6 @@ int main()
     verify_program<test_gather>();
     verify_program<test_gather_neg_axis>();
     verify_program<test_rnn_forward>();
+    verify_program<test_rnn_reverse>();
+    verify_program<test_rnn_bidirectional>();
 }
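For context, a sketch of what one of the re-enabled verification programs looks like, reconstructed from the fragments above; the shapes follow the ONNX RNN convention, and the field layout of migraphx::op::rnn{hidden_size, activations, direction} is an assumption, so treat this as illustrative rather than the exact test body:

#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>

migraphx::program create_rnn_bidirectional_program()
{
    std::size_t batch_size = 2, seq_len = 1, hidden_size = 4, input_size = 3, num_dirct = 2;

    migraphx::program p;
    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};

    auto seq  = p.add_parameter("seq", in_shape);
    auto w    = p.add_parameter("w", w_shape);
    auto r    = p.add_parameter("r", r_shape);
    auto bias = p.add_parameter("bias", b_shape);
    auto ih   = p.add_parameter("ih", ih_shape);

    auto output = p.add_instruction(
        migraphx::op::rnn{hidden_size,
                          {migraphx::op::tanh{}, migraphx::op::tanh{}},
                          migraphx::op::rnn_direction::bidirectional},
        seq, w, r, bias, ih);
    // rnn_last_output may now be the program's final instruction
    p.add_instruction(migraphx::op::rnn_last_output{}, output);
    return p;
}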