Commit e81dc0e2 authored by Shucai Xiao

Fixed bugs in the GRU implementation.

parent 3f35b208
@@ -734,50 +734,53 @@ struct onnx_parser
             dirct = op::gru::reverse;
         }
 
-        std::vector<std::string> actv_func_names = {"sigmoid", "tanh"};
+        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
         if(contains(attributes, "activations"))
         {
             auto names = attributes.at("activations").strings();
-            actv_func_names.clear();
+            vec_names.clear();
             for(auto& fn : names)
             {
-                actv_func_names.push_back(fn);
+                vec_names.push_back(fn);
             }
         }
-        if(actv_func_names.size() != 2)
-        {
-            MIGRAPHX_THROW("GRU: wrong activation function attribute");
-        }
 
         // need 4 activation functions
         if(dirct == op::gru::bidirectional)
         {
-            // one name is provided, need to repeat the function 3 times
-            if(actv_func_names.size() == 1)
+            // 4 activation functions are used in the bidirectional
+            // scenario. No spec is provided in the ONNX operator, so we
+            // use the following algorithm: if 1 activation function is
+            // provided, repeat it four times. If 2 activation functions
+            // are provided, assume forward and reverse use the same
+            // pair. If 3 activation functions are provided, assume the
+            // 3rd one is repeated once and used by the reverse
+            // direction.
+            // This may need to change later.
+            if(vec_names.size() == 1)
             {
-                actv_func_names.resize(4, actv_func_names.at(0));
+                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
             }
-            else if(actv_func_names.size() == 2)
+            else if(vec_names.size() == 2)
             {
-                actv_func_names.insert(
-                    actv_func_names.end(), actv_func_names.begin(), actv_func_names.end());
+                // repeat the pair of activation functions
+                vec_names.push_back(vec_names.at(0));
+                vec_names.push_back(vec_names.at(1));
             }
-            else if(actv_func_names.size() == 3)
+            else if(vec_names.size() == 3)
             {
-                MIGRAPHX_THROW(
-                    "GRU: birectional network cannot have 3 activation functions in attribute");
+                vec_names.push_back(vec_names.at(2));
             }
         }
         else
         {
-            if(actv_func_names.size() == 1)
+            if(vec_names.size() == 1)
             {
-                actv_func_names.push_back(actv_func_names.at(0));
+                vec_names.push_back(vec_names.at(0));
             }
         }
 
-        for_each(actv_func_names.begin(), actv_func_names.end(), [&](auto& name) {
+        for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
             if(map_actv_funcs.count(name) == 0)
             {
                 MIGRAPHX_THROW("GRU: activation function " + name + " not supported");
@@ -785,7 +788,7 @@ struct onnx_parser
         });
 
         std::vector<operation> vec_actv_funcs;
-        for_each(actv_func_names.begin(), actv_func_names.end(), [&](auto& name) {
+        for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
             vec_actv_funcs.push_back(map_actv_funcs[name]);
         });
...
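The new expansion logic above pads the ONNX activations attribute to the count the GRU actually needs: four functions for a bidirectional run (two per direction), two otherwise. A minimal standalone sketch of that rule, assuming a plain string vector; expand_gru_activations is a hypothetical helper for illustration, not part of the MIGraphX parser:

// Sketch of the activation-name expansion rule described in the diff comment.
// expand_gru_activations is a hypothetical name, not MIGraphX API.
#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> expand_gru_activations(std::vector<std::string> names,
                                                bool bidirectional)
{
    if(bidirectional)
    {
        // Bidirectional GRU needs 4 activation functions (2 per direction).
        if(names.size() == 1)
        {
            names.insert(names.end(), 3, names.at(0)); // repeat it four times
        }
        else if(names.size() == 2)
        {
            // assume forward and reverse use the same pair
            names.push_back(names.at(0));
            names.push_back(names.at(1));
        }
        else if(names.size() == 3)
        {
            names.push_back(names.at(2)); // repeat the 3rd for the reverse pass
        }
    }
    else if(names.size() == 1)
    {
        // a single direction needs 2 activation functions
        names.push_back(names.at(0));
    }
    return names;
}

int main()
{
    auto r = expand_gru_activations({"sigmoid", "tanh"}, true);
    assert((r == std::vector<std::string>{"sigmoid", "tanh", "sigmoid", "tanh"}));
    return 0;
}

Note that the old code threw for any count other than 2; the new code accepts 1 to 3 names and pads them, which matches how the ONNX attribute is commonly given.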
@@ -103,17 +103,15 @@ void rewrite_gru::apply(program& prog) const
                 gru_op.actv_funcs.at(2),
                 gru_op.actv_funcs.at(3));
 
-            auto final_output =
-                prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
+            //auto final_output =
+            //    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
 
             // add the dimension of num_direction
             ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
             ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
 
             // concat the forward and reverse output
-            auto replaced_arg =
-                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
-            replaced_arg->add_output(final_output);
+            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
         else
         {
@@ -153,8 +151,7 @@ void rewrite_gru::apply(program& prog) const
                 gru_op.actv_funcs.at(1));
 
             // add the dimension of num_direction
-            auto replaced_arg = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
-            replaced_arg->add_output(ret[1]);
+            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
         }
     }
 }
...
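The unsqueeze{{1}}/concat{1} pair in the first hunk assembles the ONNX GRU output layout (seq_length, num_directions, batch_size, hidden_size) from the two per-direction results; the fix lets replace_instruction stand on its own instead of wiring an extra add_output. A shape-only sketch of that assembly, assuming each direction yields (seq_len, batch, hidden); the helper names and sizes are illustrative, not MIGraphX API:

// Shape-only sketch of the bidirectional output assembly above.
#include <cstddef>
#include <iostream>
#include <vector>

using shape_t = std::vector<std::size_t>;

// unsqueeze{{1}}: insert a size-1 axis at position 1
shape_t unsqueeze_axis1(shape_t s)
{
    s.insert(s.begin() + 1, 1);
    return s;
}

// concat{1}: sum the sizes along axis 1 (other dims must match)
shape_t concat_axis1(const shape_t& a, const shape_t& b)
{
    shape_t out = a;
    out[1] += b[1];
    return out;
}

int main()
{
    shape_t forward{5, 3, 8}; // (seq_len, batch, hidden), illustrative sizes
    shape_t reverse{5, 3, 8};
    auto f   = unsqueeze_axis1(forward); // (5, 1, 3, 8)
    auto r   = unsqueeze_axis1(reverse); // (5, 1, 3, 8)
    auto out = concat_axis1(f, r);       // (5, 2, 3, 8)
    for(auto d : out)
        std::cout << d << ' '; // prints: 5 2 3 8
    std::cout << '\n';
    return 0;
}

The single-direction branch in the second hunk does the same with only the unsqueeze step, producing num_directions = 1.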