Commit 56daf147 authored by Shucai Xiao

Fix a bug in RNN/GRU activation function handling

parent 9ac6a4a8
@@ -31,7 +31,7 @@ struct onnx_parser
     bool is_pytorch = false;
     std::unordered_map<std::string, op_func> ops;
-    std::unordered_map<std::string, operation> actv_funcs;
+    std::unordered_map<std::string, operation> map_actv_funcs;
     onnx_parser()
     {
@@ -94,11 +94,11 @@ struct onnx_parser
     void init_actv_func()
     {
-        actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
-        actv_funcs.insert(std::make_pair("relu", op::relu{}));
-        actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
-        actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
-        actv_funcs.insert(std::make_pair("elu", op::elu{}));
+        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
+        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
+        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
+        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
+        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
     }
     template <class F>
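The member map is renamed from actv_funcs to map_actv_funcs, which keeps the parser-wide registry visually distinct from the local activation-name lists the RNN/GRU parsers build below. As a minimal standalone sketch of the registry pattern (stand-in types only, not MIGraphX APIs; the real map stores migraphx::operation values keyed by the lowercase ONNX activation name):

#include <cmath>
#include <functional>
#include <string>
#include <unordered_map>

// Stand-in for migraphx::operation: a scalar activation function.
using activation = std::function<double(double)>;

// Name -> operation registry, mirroring init_actv_func() above.
std::unordered_map<std::string, activation> make_activation_registry()
{
    return {{"tanh", [](double x) { return std::tanh(x); }},
            {"relu", [](double x) { return x > 0.0 ? x : 0.0; }},
            {"sigmoid", [](double x) { return 1.0 / (1.0 + std::exp(-x)); }}};
}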
@@ -669,7 +669,7 @@ struct onnx_parser
             activation_func = attributes.at("activations").strings(0);
         }
-        if(actv_funcs.count(activation_func) == 0)
+        if(map_actv_funcs.count(activation_func) == 0)
         {
             MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
         }
@@ -698,7 +698,7 @@ struct onnx_parser
             clip = parse_value(attributes.at("clip")).at<float>();
         }
-        return prog.add_instruction(op::rnn{hidden_size, actv_funcs[activation_func], dirct, clip},
+        return prog.add_instruction(op::rnn{hidden_size, map_actv_funcs[activation_func], dirct, clip},
                                     std::move(args));
     }
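The RNN path validates the requested name against the registry before constructing the operator. A hedged sketch of that guard-then-lookup step, continuing the stand-in registry above (the name find_activation is illustrative, not a MIGraphX function):

#include <stdexcept>

// Illustrative only: validate the ONNX-supplied name, then return the
// registered operation. A single find() replaces the count() check
// followed by a second operator[] lookup.
activation find_activation(const std::unordered_map<std::string, activation>& reg,
                           const std::string& name)
{
    auto it = reg.find(name);
    if(it == reg.end())
    {
        throw std::runtime_error("RNN: activation function " + name + " not supported");
    }
    return it->second;
}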
@@ -734,25 +734,62 @@ struct onnx_parser
             dirct = op::gru::reverse;
         }
-        std::vector<std::string> act_funcs = {"sigmoid", "tanh"};
+        std::vector<std::string> actv_func_names = {"sigmoid", "tanh"};
         if(contains(attributes, "activations"))
         {
-            act_funcs[0] = attributes.at("activations").strings(0);
-            act_funcs[1] = attributes.at("activations").strings(1);
+            const auto& names = attributes.at("activations").strings();
+            actv_func_names.assign(names.begin(), names.end());
         }
-        if(act_funcs.size() != 2)
+        // between 1 and 4 names may be supplied; the code below expands
+        // or rejects them per direction
+        if(actv_func_names.empty() or actv_func_names.size() > 4)
         {
             MIGRAPHX_THROW("GRU: wrong activation function attribute");
         }
-        for(std::size_t i = 0; i < act_funcs.size(); ++i)
+        // a bidirectional GRU needs 4 activation functions (2 per direction)
+        if(dirct == op::gru::bidirectional)
         {
-            if(actv_funcs.count(act_funcs.at(i)) == 0)
+            if(actv_func_names.size() == 1)
             {
-                MIGRAPHX_THROW("GRU: activation function " + act_funcs.at(i) + " not supported");
+                // only one name provided: use it for all 4 slots
+                actv_func_names.resize(4, actv_func_names.at(0));
             }
+            else if(actv_func_names.size() == 2)
+            {
+                // repeat the pair for the reverse direction (copy first:
+                // inserting a vector's own range into itself is undefined)
+                std::vector<std::string> fwd(actv_func_names);
+                actv_func_names.insert(actv_func_names.end(), fwd.begin(), fwd.end());
+            }
+            else if(actv_func_names.size() == 3)
+            {
+                MIGRAPHX_THROW("GRU: bidirectional network cannot have 3 activation functions in attribute");
+            }
         }
+        else
+        {
+            if(actv_func_names.size() == 1)
+            {
+                actv_func_names.push_back(actv_func_names.at(0));
+            }
+            else if(actv_func_names.size() > 2)
+            {
+                MIGRAPHX_THROW("GRU: single-direction network cannot have more than 2 activation functions");
+            }
+        }
+        for(const auto& name : actv_func_names)
+        {
+            if(map_actv_funcs.count(name) == 0)
+            {
+                MIGRAPHX_THROW("GRU: activation function " + name + " not supported");
+            }
+        }
+        std::vector<operation> vec_actv_funcs;
+        for(const auto& name : actv_func_names)
+        {
+            vec_actv_funcs.push_back(map_actv_funcs.at(name));
+        }
         // To be added later
         float clip = 0.0;
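The bulk of the fix is the normalization above: however many names the ONNX activations attribute supplies, the GRU parser ends up with exactly two per direction. A minimal standalone sketch of that expansion rule (an illustration of the logic above, not the parser itself):

#include <stdexcept>
#include <string>
#include <vector>

// Expand an activation-name list to 2 names per direction: 4 for a
// bidirectional GRU, 2 otherwise, repeating names the same way the
// parser above does and rejecting the ambiguous 3-name case.
std::vector<std::string> expand_gru_activations(std::vector<std::string> names,
                                                bool bidirectional)
{
    if(bidirectional)
    {
        if(names.size() == 1)
        {
            std::string first = names.front(); // copy to avoid aliasing
            names.resize(4, first);            // one name fills all 4 slots
        }
        else if(names.size() == 2)
        {
            // copy first: inserting a vector's own range into itself
            // is undefined behavior
            std::vector<std::string> fwd(names);
            names.insert(names.end(), fwd.begin(), fwd.end());
        }
        else if(names.size() == 3)
        {
            throw std::runtime_error("bidirectional GRU cannot take 3 names");
        }
    }
    else if(names.size() == 1)
    {
        names.push_back(names.front()); // reuse the one name for both f and g
    }
    return names;
}

For example, expand_gru_activations({"sigmoid", "tanh"}, true) yields {"sigmoid", "tanh", "sigmoid", "tanh"}, one {f, g} pair per direction.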
@@ -769,7 +806,7 @@ struct onnx_parser
         return prog.add_instruction(
             op::gru{hidden_size,
-                    {actv_funcs[act_funcs.at(0)], actv_funcs[act_funcs.at(1)]},
+                    vec_actv_funcs,
                     dirct,
                     clip,
                     linear_before_reset},