Commit 56daf147 authored by Shucai Xiao

fix a bug related to activation functions

parent 9ac6a4a8
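
Background sketch, not part of the commit: the GRU change below reads the ONNX `activations` attribute into a list of names and expands it to the count the ONNX GRU operator expects, two activation functions per direction, so four for a bidirectional GRU. The standalone function below mirrors that expansion; the name `expand_gru_activations` is illustrative only and does not exist in the parser.

    #include <stdexcept>
    #include <string>
    #include <vector>

    // Sketch only: expand an activation-name list the way the new GRU parsing
    // code does. ONNX GRU takes 2 activation functions per direction, so a
    // bidirectional GRU needs 4 names in total.
    std::vector<std::string> expand_gru_activations(std::vector<std::string> names,
                                                    bool bidirectional)
    {
        if(bidirectional)
        {
            if(names.size() == 1)
            {
                // a single name is reused for all four slots
                names.resize(4, names.at(0));
            }
            else if(names.size() == 2)
            {
                // repeat the forward pair for the reverse direction
                // (copy first to avoid inserting a range of the vector into itself)
                std::vector<std::string> fwd(names);
                names.insert(names.end(), fwd.begin(), fwd.end());
            }
            else if(names.size() == 3)
            {
                throw std::runtime_error("bidirectional GRU cannot take 3 activation functions");
            }
        }
        else if(names.size() == 1)
        {
            // a unidirectional GRU still needs two activation functions
            names.push_back(names.at(0));
        }
        return names;
    }

For example, {"sigmoid"} becomes {"sigmoid", "sigmoid"} for a unidirectional GRU and {"sigmoid", "sigmoid", "sigmoid", "sigmoid"} for a bidirectional one, while {"sigmoid", "tanh"} on a bidirectional GRU becomes {"sigmoid", "tanh", "sigmoid", "tanh"}.
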
@@ -31,7 +31,7 @@ struct onnx_parser
     bool is_pytorch = false;
     std::unordered_map<std::string, op_func> ops;
-    std::unordered_map<std::string, operation> actv_funcs;
+    std::unordered_map<std::string, operation> map_actv_funcs;
     onnx_parser()
     {
@@ -94,11 +94,11 @@ struct onnx_parser
     void init_actv_func()
     {
-        actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
-        actv_funcs.insert(std::make_pair("relu", op::relu{}));
-        actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
-        actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
-        actv_funcs.insert(std::make_pair("elu", op::elu{}));
+        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
+        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
+        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
+        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
+        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
     }
     template <class F>
@@ -669,7 +669,7 @@ struct onnx_parser
             activation_func = attributes.at("activations").strings(0);
         }
-        if(actv_funcs.count(activation_func) == 0)
+        if(map_actv_funcs.count(activation_func) == 0)
         {
             MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
         }
@@ -698,7 +698,7 @@ struct onnx_parser
             clip = parse_value(attributes.at("clip")).at<float>();
         }
-        return prog.add_instruction(op::rnn{hidden_size, actv_funcs[activation_func], dirct, clip},
+        return prog.add_instruction(op::rnn{hidden_size, map_actv_funcs[activation_func], dirct, clip},
                                     std::move(args));
     }
@@ -734,25 +734,62 @@ struct onnx_parser
             dirct = op::gru::reverse;
         }
-        std::vector<std::string> act_funcs = {"sigmoid", "tanh"};
+        std::vector<std::string> actv_func_names = {"sigmoid", "tanh"};
         if(contains(attributes, "activations"))
         {
-            act_funcs[0] = attributes.at("activations").strings(0);
-            act_funcs[1] = attributes.at("activations").strings(1);
+            auto names = attributes.at("activations").strings();
+            actv_func_names.clear();
+            for(auto& fn : names)
+            {
+                actv_func_names.push_back(fn);
+            }
         }
-        if(act_funcs.size() != 2)
+        if(actv_func_names.size() != 2)
         {
             MIGRAPHX_THROW("GRU: wrong activation function attribute");
         }
-        for(std::size_t i = 0; i < act_funcs.size(); ++i)
-        {
-            if(actv_funcs.count(act_funcs.at(i)) == 0)
-            {
-                MIGRAPHX_THROW("GRU: activation function " + act_funcs.at(i) + " not supported");
-            }
-        }
+        // need 4 activation functions
+        if(dirct == op::gru::bidirectional)
+        {
+            // one name is provided, need to repeat the function 3 times
+            if(actv_func_names.size() == 1)
+            {
+                actv_func_names.resize(4, actv_func_names.at(0));
+            }
+            else if(actv_func_names.size() == 2)
+            {
+                actv_func_names.insert(
+                    actv_func_names.end(), actv_func_names.begin(), actv_func_names.end());
+            }
+            else if(actv_func_names.size() == 3)
+            {
+                MIGRAPHX_THROW("GRU: bidirectional network cannot have 3 activation functions in attribute");
+            }
+        }
+        else
+        {
+            if(actv_func_names.size() == 1)
+            {
+                actv_func_names.push_back(actv_func_names.at(0));
+            }
+        }
+        for_each(actv_func_names.begin(), actv_func_names.end(), [&](auto& name) {
+            if(map_actv_funcs.count(name) == 0)
+            {
+                MIGRAPHX_THROW("GRU: activation function " + name + " not supported");
+            }
+        });
+        std::vector<operation> vec_actv_funcs;
+        for_each(actv_func_names.begin(), actv_func_names.end(),
+                 [&](auto& name) { vec_actv_funcs.push_back(map_actv_funcs[name]); });
         // To be added later
         float clip = 0.0;
@@ -769,7 +806,7 @@ struct onnx_parser
         return prog.add_instruction(
             op::gru{hidden_size,
-                    {actv_funcs[act_funcs.at(0)], actv_funcs[act_funcs.at(1)]},
+                    vec_actv_funcs,
                     dirct,
                     clip,
                     linear_before_reset},