Commit e81dc0e2 authored by Shucai Xiao

Fixed bugs in the GRU implementation.

parent 3f35b208
@@ -734,50 +734,53 @@ struct onnx_parser
             dirct = op::gru::reverse;
         }
-        std::vector<std::string> actv_func_names = {"sigmoid", "tanh"};
+        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
         if(contains(attributes, "activations"))
         {
             auto names = attributes.at("activations").strings();
-            actv_func_names.clear();
+            vec_names.clear();
             for(auto& fn : names)
             {
-                actv_func_names.push_back(fn);
+                vec_names.push_back(fn);
             }
         }
-        if(actv_func_names.size() != 2)
-        {
-            MIGRAPHX_THROW("GRU: wrong activation function attribute");
-        }
         // need 4 activation functions
         if(dirct == op::gru::bidirectional)
         {
-            // one name is provided, need to repeat the function 3 times
-            if(actv_func_names.size() == 1)
+            // 4 activation functions are used in the bidirectional
+            // scenario. No spec is provided by the ONNX operator, so we
+            // use the following algorithm: if 1 actv function is provided,
+            // repeat it four times. If 2 actv functions are provided,
+            // assume forward and reverse use the same pair of actv
+            // functions. For the case of 3 actv functions provided,
+            // assume the 3rd one is repeated once and used by the
+            // reverse direction.
+            // This may need to change later
+            if(vec_names.size() == 1)
             {
-                actv_func_names.resize(4, actv_func_names.at(0));
+                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
             }
-            else if(actv_func_names.size() == 2)
+            else if(vec_names.size() == 2)
             {
-                actv_func_names.insert(
-                    actv_func_names.end(), actv_func_names.begin(), actv_func_names.end());
+                // repeat the activation functions
+                vec_names.push_back(vec_names.at(0));
+                vec_names.push_back(vec_names.at(1));
             }
-            else if(actv_func_names.size() == 3)
+            else if(vec_names.size() == 3)
             {
-                MIGRAPHX_THROW(
-                    "GRU: birectional network cannot have 3 activation functions in attribute");
+                vec_names.push_back(vec_names.at(2));
             }
         }
         else
         {
-            if(actv_func_names.size() == 1)
+            if(vec_names.size() == 1)
             {
-                actv_func_names.push_back(actv_func_names.at(0));
+                vec_names.push_back(vec_names.at(0));
             }
         }
-        for_each(actv_func_names.begin(), actv_func_names.end(), [&](auto& name) {
+        for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
             if(map_actv_funcs.count(name) == 0)
             {
                 MIGRAPHX_THROW("GRU: activation function " + name + " not supported");
@@ -785,7 +788,7 @@ struct onnx_parser
         });
         std::vector<operation> vec_actv_funcs;
-        for_each(actv_func_names.begin(), actv_func_names.end(), [&](auto& name) {
+        for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
             vec_actv_funcs.push_back(map_actv_funcs[name]);
         });
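The comment block added above describes how the parser expands the `activations` attribute to the four functions a bidirectional GRU needs (two per direction). Below is a minimal standalone sketch of that fill rule; the helper name `expand_gru_activations` is hypothetical and not part of MIGraphX, but the branching mirrors the logic in the committed hunk.

```cpp
// Standalone illustration of the activation-name fill rule described above.
// expand_gru_activations is a hypothetical helper, not a MIGraphX API.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> expand_gru_activations(std::vector<std::string> names,
                                                bool bidirectional)
{
    if(bidirectional)
    {
        // a bidirectional GRU needs 4 activation functions (2 per direction)
        if(names.size() == 1)
            names.insert(names.end(), 3, names.at(0)); // repeat the single name 4 times
        else if(names.size() == 2)
        {
            // forward and reverse directions share the same pair
            names.push_back(names.at(0));
            names.push_back(names.at(1));
        }
        else if(names.size() == 3)
            names.push_back(names.at(2)); // repeat the 3rd for the reverse direction
    }
    else
    {
        // a single direction needs 2 activation functions
        if(names.size() == 1)
            names.push_back(names.at(0));
    }
    return names;
}

int main()
{
    for(const auto& n : expand_gru_activations({"sigmoid", "tanh"}, true))
        std::cout << n << " ";
    std::cout << "\n"; // prints: sigmoid tanh sigmoid tanh
}
```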
@@ -103,17 +103,15 @@ void rewrite_gru::apply(program& prog) const
                                                  gru_op.actv_funcs.at(2),
                                                  gru_op.actv_funcs.at(3));
-            auto final_output =
-                prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
+            // auto final_output =
+            //     prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
             // add the dimension of num_direction
             ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
             ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
             // concat the forward and reverse output
-            auto replaced_arg =
-                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
-            replaced_arg->add_output(final_output);
+            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
         else
         {
@@ -153,8 +151,7 @@ void rewrite_gru::apply(program& prog) const
                                    gru_op.actv_funcs.at(1));
             // add the dimension of num_direction
-            auto replaced_arg = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
-            replaced_arg->add_output(ret[1]);
+            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
         }
     }
 }
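For context on the rewrite above: each direction of the GRU produces a hidden-state sequence of shape (seq_len, batch, hidden_size); unsqueezing at axis 1 and concatenating along that axis assembles the ONNX GRU output layout (seq_len, num_directions, batch, hidden_size). The sketch below only tracks dimensions; `shape_dims`, `unsqueeze_axis1`, and `concat_axis1` are illustrative stand-ins, not MIGraphX APIs, and the dimension values are examples.

```cpp
// Shape bookkeeping behind the unsqueeze + concat rewrite in the hunk above.
#include <cassert>
#include <cstddef>
#include <vector>

using shape_dims = std::vector<std::size_t>;

// insert the num_directions axis at position 1
shape_dims unsqueeze_axis1(shape_dims s)
{
    s.insert(s.begin() + 1, 1);
    return s;
}

// concatenate two shapes along axis 1 (all other dims assumed equal)
shape_dims concat_axis1(const shape_dims& a, const shape_dims& b)
{
    shape_dims out = a;
    out[1] += b[1];
    return out;
}

int main()
{
    shape_dims forward = {5, 3, 20}; // (seq_len, batch, hidden_size), example sizes
    shape_dims reverse = {5, 3, 20};
    auto out = concat_axis1(unsqueeze_axis1(forward), unsqueeze_axis1(reverse));
    assert((out == shape_dims{5, 2, 3, 20})); // (seq_len, num_directions, batch, hidden_size)
}
```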