Commit f2436c3d authored by Shucai Xiao

Refine the processing of the activation function attribute for RNN operators.

parent e7e82505
......@@ -1068,7 +1068,7 @@ struct rnn
};
std::size_t hidden_size = 1;
- operation actv_func{tanh{}};
+ std::vector<operation> actv_funcs{tanh{}};
rnn_direction_t direction = forward;
float clip = 0.0f;
......
......@@ -31,7 +31,7 @@ struct onnx_parser
bool is_pytorch = false;
std::unordered_map<std::string, op_func> ops;
- std::unordered_map<std::string, operation> actv_funcs;
+ std::unordered_map<std::string, operation> map_actv_funcs;
onnx_parser()
{
......@@ -93,11 +93,11 @@ struct onnx_parser
void init_actv_func()
{
- actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
- actv_funcs.insert(std::make_pair("relu", op::relu{}));
- actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
- actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
- actv_funcs.insert(std::make_pair("elu", op::elu{}));
+ map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
+ map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
+ map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
+ map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
+ map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
}
template <class F>
......@@ -663,17 +663,6 @@ struct onnx_parser
MIGRAPHX_THROW("RNN: hidden size attribute missing");
}
- std::string activation_func = {"tanh"};
- if(contains(attributes, "activations"))
- {
- activation_func = attributes.at("activations").strings(0);
- }
- if(actv_funcs.count(activation_func) == 0)
- {
- MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
- }
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
......@@ -691,6 +680,37 @@ struct onnx_parser
dirct = op::rnn::reverse;
}
+ std::vector<std::string> vec_names{"tanh"};
+ if(contains(attributes, "activations"))
+ {
+ auto names = attributes.at("activations").strings();
+ vec_names.clear();
+ for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
+ }
+ for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
+ if(map_actv_funcs.count(fn) == 0)
+ {
+ MIGRAPHX_THROW("RNN: activation function " + fn + " not supported");
+ }
+ });
+ // bidirectional should have two activation functions;
+ // if only one actv function is provided, we use it in both
+ // forward and reverse directions
+ if(dirct == op::rnn::bidirectional)
+ {
+ if(vec_names.size() == 1)
+ {
+ vec_names.push_back(vec_names.at(0));
+ }
+ }
+ std::vector<operation> vec_actv_funcs;
+ for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
+ vec_actv_funcs.push_back(map_actv_funcs[fn]);
+ });
// To be added later
float clip = 0.0;
if(contains(attributes, "clip"))
......@@ -698,7 +718,7 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>();
}
- return prog.add_instruction(op::rnn{hidden_size, actv_funcs[activation_func], dirct, clip},
+ return prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args));
}
......
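For readers skimming the parser hunk above: the new logic defaults the activation list to {"tanh"}, overrides it from the ONNX "activations" attribute, validates every name against the supported map, duplicates a single entry for bidirectional RNNs, and only then converts the names to operations. Below is a standalone sketch of that flow, using plain strings in place of migraphx::operation and a boolean in place of the direction enum; the function and variable names are illustrative, not part of the commit.

    #include <map>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Simplified model of the activation handling added to parse_rnn above:
    // default to "tanh", override from the "activations" attribute, validate
    // every name, duplicate a single entry for bidirectional RNNs, then map
    // each name to its operation (modeled here as a string).
    std::vector<std::string> resolve_activations(std::vector<std::string> attr_names,
                                                 bool bidirectional)
    {
        static const std::map<std::string, std::string> supported = {
            {"tanh", "tanh"}, {"relu", "relu"}, {"sigmoid", "sigmoid"},
            {"leakyrelu", "leaky_relu"}, {"elu", "elu"}};

        std::vector<std::string> names{"tanh"}; // default when the attribute is absent
        if(!attr_names.empty())
            names = std::move(attr_names);

        for(const auto& n : names)
            if(supported.count(n) == 0)
                throw std::runtime_error("RNN: activation function " + n + " not supported");

        // bidirectional needs one activation per direction; reuse the single one
        if(bidirectional && names.size() == 1)
            names.push_back(names.front());

        std::vector<std::string> ops;
        for(const auto& n : names)
            ops.push_back(supported.at(n));
        return ops;
    }

With this model, resolve_activations({}, true) yields {"tanh", "tanh"}, matching the bidirectional fallback in the parser.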
......@@ -106,7 +106,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_forward,
ih_forward,
bias_forward,
- rnn_op.actv_func);
+ rnn_op.actv_funcs.at(0));
auto ret_reverse = rnn_oper(false,
prog,
ins,
......@@ -115,7 +115,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_reverse,
ih_reverse,
bias_reverse,
- rnn_op.actv_func);
+ rnn_op.actv_funcs.at(1));
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
......@@ -161,7 +161,7 @@ void rewrite_rnn::apply(program& prog) const
ih = prog.add_literal(migraphx::literal{s, data});
}
auto ret = rnn_oper(
- is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
+ is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_funcs.at(0));
// add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
......
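In the rewrite pass above, the bidirectional case now reads index 0 for the forward sweep and index 1 for the reverse sweep, while the single-direction case always reads index 0. A hedged sketch of that selection rule, again with strings standing in for operations; the helper name is invented for illustration.

    #include <string>
    #include <vector>

    // Picks the activation for one sweep of the RNN: slot 0 drives the forward
    // pass; slot 1 drives the reverse pass (the parser guarantees at least two
    // entries for bidirectional ops).
    std::string pick_activation(const std::vector<std::string>& actv_funcs, bool reverse_sweep)
    {
        return reverse_sweep ? actv_funcs.at(1) : actv_funcs.at(0);
    }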