"...resnet50_tensorflow.git" did not exist on "42bb922cd1566c872c1797c46207a43424b98673"
Commit 148e548d authored by Shucai Xiao

code refinement.

parent 42d2549d
@@ -25,13 +25,15 @@ struct rewrite_gru
                 program& prog,
                 instruction_ref ins,
                 instruction_ref input,
-                instruction_ref wx,
-                instruction_ref wh,
+                instruction_ref w,
+                instruction_ref r,
                 instruction_ref bias,
                 instruction_ref ih,
                 int linear_before_reset,
-                operation& actv_func1,
-                operation& actv_func2) const;
+                const operation& actv_func1,
+                const operation& actv_func2) const;
+    std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
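A note on the signature change: the renamed parameters follow the ONNX GRU operand names, where W holds the input weights and R the recurrence weights, and the activation operations are now taken by const reference, signaling that gru_cell does not modify them. A minimal shape sketch for orientation (sizes are hypothetical; the layouts are per the ONNX GRU spec):

```cpp
// Hypothetical sizes, for illustration only.
constexpr long num_directions = 2; // 1 for forward/reverse, 2 for bidirectional
constexpr long input_size     = 8;
constexpr long hidden_size    = 4;

// Operand layouts per the ONNX GRU specification:
const long w_dims[] = {num_directions, 3 * hidden_size, input_size};  // W (input weights)
const long r_dims[] = {num_directions, 3 * hidden_size, hidden_size}; // R (recurrence weights)
const long b_dims[] = {num_directions, 6 * hidden_size};              // B (Wb and Rb concatenated)
```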
@@ -818,10 +818,8 @@ struct onnx_parser
         {
             auto names = attributes.at("activations").strings();
             vec_names.clear();
-            for(auto& fn : names)
-            {
-                vec_names.push_back(fn);
-            }
+            vec_names.resize(names.size());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
         }
         // need 4 activation functions
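The replacement std::transform uses an identity lambda, so it is a plain element-wise copy from the protobuf repeated field into vec_names. An equivalent and arguably more direct form, shown only as a sketch (here `names` stands in for attributes.at("activations").strings()):

```cpp
#include <string>
#include <vector>

// Sketch: one-call equivalent of the clear/resize/transform sequence.
void copy_activation_names(const std::vector<std::string>& names,
                           std::vector<std::string>& vec_names)
{
    vec_names.assign(names.begin(), names.end());
}
```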
@@ -15,6 +15,7 @@ void rewrite_gru::apply(program& prog) const
     {
         if(ins->name() == "gru")
         {
+            const auto actv_funcs = compute_actv_funcs(ins);
             // could be 3 to 5 inputs (though onnx::rnn declares 6 inputs,
             // the 5th one is undefined and ignored by protobuf, so we
             // need to process up to 5 inputs)
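Hoisting compute_actv_funcs(ins) to the top of the branch gives every direction case below a single resolved actv_funcs vector. As for the input-count comment, a hedged sketch of the optional-input guard it implies (names here are illustrative, not the file's exact code):

```cpp
#include <cstddef>
#include <vector>

// ONNX GRU declares 6 inputs, but the 5th (sequence_lens) may be unset,
// so the parsed argument list holds 3 to 5 entries. Only dereference an
// index after checking that it exists.
template <class T>
const T* optional_arg(const std::vector<T>& args, std::size_t idx)
{
    return idx < args.size() ? &args[idx] : nullptr;
}
```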
@@ -70,8 +71,8 @@ void rewrite_gru::apply(program& prog) const
                                         bias_forward,
                                         ih_forward,
                                         gru_op.linear_before_reset,
-                                        gru_op.actv_funcs.at(0),
-                                        gru_op.actv_funcs.at(1));
+                                        actv_funcs.at(0),
+                                        actv_funcs.at(1));
             auto ret_reverse = gru_cell(false,
                                         prog,
@@ -82,8 +83,8 @@ void rewrite_gru::apply(program& prog) const
                                         bias_reverse,
                                         ih_reverse,
                                         gru_op.linear_before_reset,
-                                        gru_op.actv_funcs.at(2),
-                                        gru_op.actv_funcs.at(3));
+                                        actv_funcs.at(2),
+                                        actv_funcs.at(3));
             auto concat_output =
                 prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
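For orientation on the concat axis: ONNX lays the full GRU output out as [seq_length, num_directions, batch_size, hidden_size], so the forward and reverse results are joined along axis 1 here, assuming each per-direction result keeps a singleton direction axis. A tiny shape sketch with hypothetical sizes:

```cpp
#include <array>

// ONNX GRU output layouts (per the spec); sizes below are hypothetical.
//   Y   : [seq_length, num_directions, batch_size, hidden_size]
//   Y_h : [num_directions, batch_size, hidden_size]
constexpr std::array<long, 4> y_dims{3, 2, 2, 4};
constexpr std::array<long, 3> yh_dims{2, 2, 4};
// With a singleton direction axis on each per-direction result,
// op::concat{1} stacks forward and reverse into num_directions = 2.
```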
@@ -110,7 +111,7 @@ void rewrite_gru::apply(program& prog) const
         }
         else
         {
-            bool is_forward = (dicrt == op::gru::forward) ? true : false;
+            bool is_forward = (dicrt == op::gru::forward);
             // weight matrix
             auto w = args[1];
             auto r = args[2];
@@ -142,8 +143,8 @@ void rewrite_gru::apply(program& prog) const
                                 bias,
                                 ih,
                                 gru_op.linear_before_reset,
-                                gru_op.actv_funcs.at(0),
-                                gru_op.actv_funcs.at(1));
+                                actv_funcs.at(0),
+                                actv_funcs.at(1));
             auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
@@ -155,7 +156,7 @@ void rewrite_gru::apply(program& prog) const
             else
             {
                 auto concat_arg0 = is_forward ? ret[0] : ret[1];
-                auto concat_arg1 = is_forward ? ret[1] : ret[1];
+                auto concat_arg1 = is_forward ? ret[1] : ret[0];
                 hidden_state =
                     prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
             }
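This hunk fixes what looks like a copy-paste bug: in the reverse-direction case both concat operands were ret[1], so one result was concatenated with itself. An equivalent way to state the corrected ordering, as a sketch:

```cpp
#include <utility>

// Pick the concat operand order: (first, second) when running forward,
// swapped when running in reverse. Equivalent to the two ternaries above.
template <class T>
std::pair<T, T> concat_order(const T& first, const T& second, bool is_forward)
{
    return is_forward ? std::pair<T, T>{first, second}
                      : std::pair<T, T>{second, first};
}
```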
@@ -186,9 +187,10 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
                                                    instruction_ref bias,
                                                    instruction_ref ih,
                                                    int linear_before_reset,
-                                                   operation& actv_func1,
-                                                   operation& actv_func2) const
+                                                   const operation& actv_func1,
+                                                   const operation& actv_func2) const
 {
+    // gru_cell always receives exactly two activation functions.
     instruction_ref hidden_states = prog.end(), last_output;
     long seq_len = static_cast<long>(input->get_shape().lens()[0]);
     long hs      = static_cast<long>(r->get_shape().lens()[2]);
@@ -334,5 +336,46 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
     return {hidden_states, last_output};
 }
 
+std::vector<operation> rewrite_gru::compute_actv_funcs(instruction_ref ins) const
+{
+    auto gru_op = any_cast<op::gru>(ins->get_operator());
+    // Before rewriting the gru operator, ensure it carries a full set of
+    // activation functions (4 for bidirectional, 2 otherwise), even if
+    // the user did not specify any. If fewer are given, expand them with
+    // the same algorithm parse_gru uses.
+    if(gru_op.direction == op::gru::bidirectional)
+    {
+        if(gru_op.actv_funcs.empty())
+            return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
+        else if(gru_op.actv_funcs.size() == 1)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0)};
+        else if(gru_op.actv_funcs.size() == 2)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(1),
+                    gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(1)};
+        else if(gru_op.actv_funcs.size() == 3)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(1),
+                    gru_op.actv_funcs.at(2),
+                    gru_op.actv_funcs.at(0)};
+        else
+            return gru_op.actv_funcs;
+    }
+    else
+    {
+        if(gru_op.actv_funcs.empty())
+            return {op::sigmoid{}, op::tanh{}};
+        else if(gru_op.actv_funcs.size() == 1)
+            return {gru_op.actv_funcs.at(0),
+                    gru_op.actv_funcs.at(0)};
+        else
+            return gru_op.actv_funcs;
+    }
+}
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
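The new helper appears to centralize the activation-function defaulting so the rewrite no longer assumes parse_gru has already filled the list in. Its behavior, summarized (rows follow the branches above; sigmoid/tanh are the ONNX GRU defaults):

```cpp
// Expansion performed by compute_actv_funcs, per the branches above:
//
//   direction        given actv_funcs   returned
//   bidirectional    {}                 {sigmoid, tanh, sigmoid, tanh}
//   bidirectional    {f}                {f, f, f, f}
//   bidirectional    {f, g}             {f, g, f, g}
//   bidirectional    {f, g, h}          {f, g, h, f}
//   forward/reverse  {}                 {sigmoid, tanh}
//   forward/reverse  {f}                {f, f}
//   otherwise        returned unchanged
```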