Commit 148e548d authored by Shucai Xiao

code refinement.

parent 42d2549d
@@ -25,13 +25,15 @@ struct rewrite_gru
                                           program& prog,
                                           instruction_ref ins,
                                           instruction_ref input,
-                                          instruction_ref wx,
-                                          instruction_ref wh,
+                                          instruction_ref w,
+                                          instruction_ref r,
                                           instruction_ref bias,
                                           instruction_ref ih,
                                           int linear_before_reset,
-                                          operation& actv_func1,
-                                          operation& actv_func2) const;
+                                          const operation& actv_func1,
+                                          const operation& actv_func2) const;
+    std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
@@ -818,10 +818,8 @@ struct onnx_parser
         {
             auto names = attributes.at("activations").strings();
             vec_names.clear();
-            for(auto& fn : names)
-            {
-                vec_names.push_back(fn);
-            }
+            vec_names.resize(names.size());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
         }
         // need 4 activation functions
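For reference, the replaced loop and the new std::transform call do the same thing: copy the attribute's string list into vec_names element by element. Below is a minimal stand-alone sketch of that idiom, using std::vector<std::string> as a stand-in for the protobuf repeated field returned by strings() (an assumption made for illustration, not the parser's actual types).

    // Minimal sketch of the container copy performed in parse_gru above.
    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    int main()
    {
        std::vector<std::string> names = {"sigmoid", "tanh"};
        std::vector<std::string> vec_names;

        vec_names.clear();
        vec_names.resize(names.size());
        // element-wise copy, equivalent to the replaced push_back loop
        std::transform(names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });

        for(auto& n : vec_names)
            std::cout << n << "\n"; // prints: sigmoid, tanh
    }

For a plain copy like this, vec_names.assign(names.begin(), names.end()) would be an equivalent one-liner; std::transform is the more general form when each element needs converting on the way in.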
@@ -15,6 +15,7 @@ void rewrite_gru::apply(program& prog) const
     {
         if(ins->name() == "gru")
         {
+            const auto actv_funcs = compute_actv_funcs(ins);
             // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
             // the 5th one is undefined and ignored by protobuf. so
             // we need to process up to 5 inputs
@@ -70,8 +71,8 @@ void rewrite_gru::apply(program& prog) const
                                         bias_forward,
                                         ih_forward,
                                         gru_op.linear_before_reset,
-                                        gru_op.actv_funcs.at(0),
-                                        gru_op.actv_funcs.at(1));
+                                        actv_funcs.at(0),
+                                        actv_funcs.at(1));
             auto ret_reverse = gru_cell(false,
                                         prog,
@@ -82,8 +83,8 @@ void rewrite_gru::apply(program& prog) const
                                         bias_reverse,
                                         ih_reverse,
                                         gru_op.linear_before_reset,
-                                        gru_op.actv_funcs.at(2),
-                                        gru_op.actv_funcs.at(3));
+                                        actv_funcs.at(2),
+                                        actv_funcs.at(3));
             auto concat_output =
                 prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
@@ -110,7 +111,7 @@ void rewrite_gru::apply(program& prog) const
             }
             else
             {
-                bool is_forward = (dicrt == op::gru::forward) ? true : false;
+                bool is_forward = (dicrt == op::gru::forward);
                 // weight matrix
                 auto w = args[1];
                 auto r = args[2];
@@ -142,8 +143,8 @@ void rewrite_gru::apply(program& prog) const
                                     bias,
                                     ih,
                                     gru_op.linear_before_reset,
-                                    gru_op.actv_funcs.at(0),
-                                    gru_op.actv_funcs.at(1));
+                                    actv_funcs.at(0),
+                                    actv_funcs.at(1));
                 auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
@@ -155,7 +156,7 @@ void rewrite_gru::apply(program& prog) const
                 else
                 {
                     auto concat_arg0 = is_forward ? ret[0] : ret[1];
-                    auto concat_arg1 = is_forward ? ret[1] : ret[1];
+                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
                     hidden_state =
                         prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
                 }
@@ -186,9 +187,10 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
                                                    instruction_ref bias,
                                                    instruction_ref ih,
                                                    int linear_before_reset,
-                                                   operation& actv_func1,
-                                                   operation& actv_func2) const
+                                                   const operation& actv_func1,
+                                                   const operation& actv_func2) const
 {
+    assert(actv_funcs.size() == 2);
     instruction_ref hidden_states = prog.end(), last_output;
     long seq_len = static_cast<long>(input->get_shape().lens()[0]);
     long hs = static_cast<long>(r->get_shape().lens()[2]);
@@ -334,5 +336,46 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
     return {hidden_states, last_output};
 }
 
+std::vector<operation> rewrite_gru::compute_actv_funcs(instruction_ref ins) const
+{
+    auto gru_op = any_cast<op::gru>(ins->get_operator());
+    // before rewriting the gru operator, we need to ensure
+    // there are 4 actv funcs, even if the user does not
+    // specify any actv func. If fewer than 4 are given, use the
+    // same algorithm as parse_gru to make 4 actv functions
+    if(gru_op.direction == op::gru::bidirectional)
+    {
+        if(gru_op.actv_funcs.empty())
+            return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
+        else if(gru_op.actv_funcs.size() == 1)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0)};
+        else if(gru_op.actv_funcs.size() == 2)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(1),
+                    gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(1)};
+        else if(gru_op.actv_funcs.size() == 3)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(1),
+                    gru_op.actv_funcs.at(2),
+                    gru_op.actv_funcs.at(0)};
+        else
+            return gru_op.actv_funcs;
+    }
+    else
+    {
+        if(gru_op.actv_funcs.empty())
+            return {op::sigmoid{}, op::tanh{}};
+        else if(gru_op.actv_funcs.size() == 1)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0)};
+        else
+            return gru_op.actv_funcs;
+    }
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
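To make the defaulting rule above easier to check at a glance, here is a stand-alone sketch of the same expansion applied to plain activation names instead of MIGraphX operation objects. expand_gru_activations is a hypothetical helper written only for illustration; it is not part of MIGraphX.

    // Sketch of the activation-defaulting rule used by compute_actv_funcs:
    // a bidirectional GRU needs 4 activation functions, a single-direction
    // GRU needs 2, and missing entries are filled from the ones given
    // (or from the ONNX defaults sigmoid/tanh when none are given).
    #include <iostream>
    #include <string>
    #include <vector>

    std::vector<std::string> expand_gru_activations(const std::vector<std::string>& f, bool bidirectional)
    {
        if(!bidirectional)
        {
            if(f.empty())     return {"sigmoid", "tanh"};
            if(f.size() == 1) return {f[0], f[0]};
            return f;
        }
        if(f.empty())     return {"sigmoid", "tanh", "sigmoid", "tanh"};
        if(f.size() == 1) return {f[0], f[0], f[0], f[0]};
        if(f.size() == 2) return {f[0], f[1], f[0], f[1]}; // reuse the pair for the reverse direction
        if(f.size() == 3) return {f[0], f[1], f[2], f[0]};
        return f;
    }

    int main()
    {
        for(auto& name : expand_gru_activations({"sigmoid", "relu"}, true))
            std::cout << name << " "; // prints: sigmoid relu sigmoid relu
        std::cout << "\n";
    }

With this defaulting done once up front, apply() can pass actv_funcs.at(0)/at(1) (and at(2)/at(3) for the reverse direction) to gru_cell without re-checking how many activations the user supplied.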