Commit 3f35b208 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format.

parent 56daf147
...@@ -698,8 +698,8 @@ struct onnx_parser ...@@ -698,8 +698,8 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>(); clip = parse_value(attributes.at("clip")).at<float>();
} }
return prog.add_instruction(op::rnn{hidden_size, map_actv_funcs[activation_func], dirct, clip}, return prog.add_instruction(
std::move(args)); op::rnn{hidden_size, map_actv_funcs[activation_func], dirct, clip}, std::move(args));
} }
instruction_ref instruction_ref
...@@ -739,7 +739,7 @@ struct onnx_parser ...@@ -739,7 +739,7 @@ struct onnx_parser
{ {
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
actv_func_names.clear(); actv_func_names.clear();
for (auto &fn : names) for(auto& fn : names)
{ {
actv_func_names.push_back(fn); actv_func_names.push_back(fn);
} }
...@@ -751,45 +751,43 @@ struct onnx_parser ...@@ -751,45 +751,43 @@ struct onnx_parser
} }
// need 4 activation functions // need 4 activation functions
if (dirct == op::gru::bidirectional) if(dirct == op::gru::bidirectional)
{ {
// one name is provided, need to repeat the function 3 times // one name is provided, need to repeat the function 3 times
if (actv_func_names.size() == 1) if(actv_func_names.size() == 1)
{ {
actv_func_names.resize(4, actv_func_names.at(0)); actv_func_names.resize(4, actv_func_names.at(0));
} }
else if (actv_func_names.size() == 2) else if(actv_func_names.size() == 2)
{ {
actv_func_names.insert(actv_func_names.end(), actv_func_names.begin(), actv_func_names.end()); actv_func_names.insert(
actv_func_names.end(), actv_func_names.begin(), actv_func_names.end());
} }
else if (actv_func_names.size() == 3) else if(actv_func_names.size() == 3)
{ {
MIGRAPHX_THROW("GRU: birectional network cannot have 3 activation functions in attribute"); MIGRAPHX_THROW(
"GRU: birectional network cannot have 3 activation functions in attribute");
} }
} }
else else
{ {
if (actv_func_names.size() == 1) if(actv_func_names.size() == 1)
{ {
actv_func_names.push_back(actv_func_names.at(0)); actv_func_names.push_back(actv_func_names.at(0));
} }
} }
for_each(actv_func_names.begin(), actv_func_names.end(), for_each(actv_func_names.begin(), actv_func_names.end(), [&](auto& name) {
[&](auto &name) if(map_actv_funcs.count(name) == 0)
{ {
if (map_actv_funcs.count(name) == 0) MIGRAPHX_THROW("GRU: activation function " + name + " not supported");
{ }
MIGRAPHX_THROW("GRU: activation function " + name + " not supported"); });
}
});
std::vector<operation> vec_actv_funcs; std::vector<operation> vec_actv_funcs;
for_each(actv_func_names.begin(), actv_func_names.end(), for_each(actv_func_names.begin(), actv_func_names.end(), [&](auto& name) {
[&](auto &name) vec_actv_funcs.push_back(map_actv_funcs[name]);
{ });
vec_actv_funcs.push_back(map_actv_funcs[name]);
});
// To be added later // To be added later
float clip = 0.0; float clip = 0.0;
...@@ -805,11 +803,7 @@ struct onnx_parser ...@@ -805,11 +803,7 @@ struct onnx_parser
} }
return prog.add_instruction( return prog.add_instruction(
op::gru{hidden_size, op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
vec_actv_funcs,
dirct,
clip,
linear_before_reset},
std::move(args)); std::move(args));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment