Commit 22f8a479 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Handle the cases in which not enough activation functions are provided.

parent 0cc5b80e
...@@ -1140,7 +1140,7 @@ struct rnn ...@@ -1140,7 +1140,7 @@ struct rnn
}; };
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}}; std::vector<operation> actv_funcs{tanh{}, tanh{}};
rnn_direction_t direction = forward; rnn_direction_t direction = forward;
float clip = 0.0f; float clip = 0.0f;
......
...@@ -30,6 +30,8 @@ struct rewrite_rnn ...@@ -30,6 +30,8 @@ struct rewrite_rnn
instruction_ref bias, instruction_ref bias,
instruction_ref ih, instruction_ref ih,
operation& actv_func) const; operation& actv_func) const;
std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -29,9 +29,10 @@ void rewrite_rnn::apply(program& prog) const ...@@ -29,9 +29,10 @@ void rewrite_rnn::apply(program& prog) const
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0); std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = compute_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction; op::rnn::rnn_direction_t dicrt = rnn_op.direction;
if(dicrt == op::rnn::rnn_direction_t::bidirectional) if(dicrt == op::rnn::bidirectional)
{ {
// input weight matrix // input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
...@@ -72,7 +73,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -72,7 +73,7 @@ void rewrite_rnn::apply(program& prog) const
r_forward, r_forward,
bias_forward, bias_forward,
ih_forward, ih_forward,
rnn_op.actv_funcs.at(0)); actv_funcs.at(0));
auto ret_reverse = rnn_cell(false, auto ret_reverse = rnn_cell(false,
prog, prog,
ins, ins,
...@@ -81,7 +82,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -81,7 +82,7 @@ void rewrite_rnn::apply(program& prog) const
r_reverse, r_reverse,
bias_reverse, bias_reverse,
ih_reverse, ih_reverse,
rnn_op.actv_funcs.at(1)); actv_funcs.at(1));
auto concat_output = auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
...@@ -109,7 +110,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -109,7 +110,7 @@ void rewrite_rnn::apply(program& prog) const
} }
else else
{ {
bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward); bool is_forward = (dicrt == op::rnn::forward);
// input weight matrix // input weight matrix
auto w = args[1]; auto w = args[1];
...@@ -135,7 +136,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -135,7 +136,7 @@ void rewrite_rnn::apply(program& prog) const
} }
auto ret = rnn_cell( auto ret = rnn_cell(
is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0)); is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
...@@ -263,5 +264,42 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -263,5 +264,42 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
return {hidden_out, last_out}; return {hidden_out, last_out};
} }
std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
{
    // Normalize the activation-function list of an rnn instruction before
    // rewriting it: a bidirectional rnn needs exactly two entries, any
    // other direction needs one. Missing entries are filled in using the
    // same scheme as parse_rnn (default activation is tanh; a single
    // provided function is replicated for both directions).
    const auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    const auto& funcs = rnn_op.actv_funcs;

    if(rnn_op.direction == op::rnn::bidirectional)
    {
        switch(funcs.size())
        {
        case 0:
            // nothing supplied: fall back to the default for both directions
            return {op::tanh{}, op::tanh{}};
        case 1:
            // one supplied: use it for forward and reverse alike
            return {funcs.at(0), funcs.at(0)};
        default:
            return funcs;
        }
    }

    // unidirectional: default to tanh only when no function was supplied
    return funcs.empty() ? std::vector<operation>{op::tanh{}} : funcs;
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -1459,7 +1459,7 @@ TEST_CASE(rnn_forward) ...@@ -1459,7 +1459,7 @@ TEST_CASE(rnn_forward)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {},
migraphx::op::rnn::forward, migraphx::op::rnn::forward,
clip}, clip},
seq, seq,
...@@ -1599,7 +1599,7 @@ TEST_CASE(rnn_reverse) ...@@ -1599,7 +1599,7 @@ TEST_CASE(rnn_reverse)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {},
migraphx::op::rnn::reverse, migraphx::op::rnn::reverse,
clip}, clip},
seq, seq,
...@@ -1724,7 +1724,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1724,7 +1724,7 @@ TEST_CASE(rnn_bidirectional)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {},
migraphx::op::rnn::bidirectional, migraphx::op::rnn::bidirectional,
clip}, clip},
seq, seq,
...@@ -1776,7 +1776,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1776,7 +1776,7 @@ TEST_CASE(rnn_bidirectional)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn::bidirectional,
clip}, clip},
seq, seq,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment