Commit 483c4508 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent f8c319e3
......@@ -931,7 +931,7 @@ struct onnx_parser
{
dirct = op::lstm::reverse;
}
else if (direction == "forward")
else if(direction == "forward")
{
dirct = op::lstm::forward;
}
......@@ -963,9 +963,7 @@ struct onnx_parser
// provided. This may need change later
switch(vec_names.size())
{
case 1:
vec_names.insert(vec_names.end(), 5, vec_names.back());
break;
case 1: vec_names.insert(vec_names.end(), 5, vec_names.back()); break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
......@@ -978,33 +976,25 @@ struct onnx_parser
vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end());
break;
case 4:
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 4: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
case 5:
vec_names.push_back(vec_names.back());
break;
case 5: vec_names.push_back(vec_names.back()); break;
default:
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1:
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 1: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names.push_back(vec_names.back());
break;
default:
break;
default: break;
}
}
......@@ -1041,8 +1031,7 @@ struct onnx_parser
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget},
std::move(args));
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
......
......@@ -694,55 +694,69 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// we have 6 actv funcs, even though a user does not
// specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions
const auto &actv_funcs = lstm_op.actv_funcs;
const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::lstm::bidirectional)
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{},
op::sigmoid{}, op::tanh{}, op::tanh{}};
return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
return {actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1)};
case 3:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2)};
case 4:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(3), actv_funcs.at(3), actv_funcs.at(3)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(3),
actv_funcs.at(3)};
case 5:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(3), actv_funcs.at(4), actv_funcs.at(4)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(4)};
default:
return actv_funcs;
default: return actv_funcs;
}
}
else
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default:
return actv_funcs;
default: return actv_funcs;
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment