Commit 483c4508 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent f8c319e3
...@@ -931,7 +931,7 @@ struct onnx_parser ...@@ -931,7 +931,7 @@ struct onnx_parser
{ {
dirct = op::lstm::reverse; dirct = op::lstm::reverse;
} }
else if (direction == "forward") else if(direction == "forward")
{ {
dirct = op::lstm::forward; dirct = op::lstm::forward;
} }
...@@ -963,9 +963,7 @@ struct onnx_parser ...@@ -963,9 +963,7 @@ struct onnx_parser
// provided. This may need change later // provided. This may need change later
switch(vec_names.size()) switch(vec_names.size())
{ {
case 1: case 1: vec_names.insert(vec_names.end(), 5, vec_names.back()); break;
vec_names.insert(vec_names.end(), 5, vec_names.back());
break;
case 2: case 2:
// repeat the 2nd actv func once, then repeat all three another time // repeat the 2nd actv func once, then repeat all three another time
...@@ -978,33 +976,25 @@ struct onnx_parser ...@@ -978,33 +976,25 @@ struct onnx_parser
vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end()); vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end());
break; break;
case 4: case 4: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 5: case 5: vec_names.push_back(vec_names.back()); break;
vec_names.push_back(vec_names.back());
break;
default: default: break;
break;
} }
} }
else else
{ {
switch(vec_names.size()) switch(vec_names.size())
{ {
case 1: case 1: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 2: case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs // repeat the 2nd actv func once, so we have 3 actv funcs
vec_names.push_back(vec_names.back()); vec_names.push_back(vec_names.back());
break; break;
default: default: break;
break;
} }
} }
...@@ -1041,8 +1031,7 @@ struct onnx_parser ...@@ -1041,8 +1031,7 @@ struct onnx_parser
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = prog.add_instruction( auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
std::move(args));
// second output for last lstm output // second output for last lstm output
auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states); auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
......
...@@ -694,55 +694,69 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -694,55 +694,69 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// we have 6 actv funcs, even though a user does not // we have 6 actv funcs, even though a user does not
// specifiy any actv func. If less than 46, use the // specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions // algorithm in parse_lstm to make 6 actv functions
const auto &actv_funcs = lstm_op.actv_funcs; const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size(); std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::lstm::bidirectional) if(lstm_op.direction == op::lstm::bidirectional)
{ {
switch(num_actv_funcs) switch(num_actv_funcs)
{ {
case 0: case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}, return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1: case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0), return {actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)}; actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0)};
case 2: case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1), return {actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)}; actv_funcs.at(1),
actv_funcs.at(1),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1)};
case 3: case 3:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2), return {actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2)}; actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2)};
case 4: case 4:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2), return {actv_funcs.at(0),
actv_funcs.at(3), actv_funcs.at(3), actv_funcs.at(3)}; actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(3),
actv_funcs.at(3)};
case 5: case 5:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2), return {actv_funcs.at(0),
actv_funcs.at(3), actv_funcs.at(4), actv_funcs.at(4)}; actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(4)};
default: default: return actv_funcs;
return actv_funcs;
} }
} }
else else
{ {
switch(num_actv_funcs) switch(num_actv_funcs)
{ {
case 0: case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};
return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1: case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2: case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default: default: return actv_funcs;
return actv_funcs;
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment