Commit 483c4508 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent f8c319e3
...@@ -1268,8 +1268,8 @@ struct lstm ...@@ -1268,8 +1268,8 @@ struct lstm
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}}; std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
lstm_direction_t direction = forward; lstm_direction_t direction = forward;
float clip = 0.0f; float clip = 0.0f;
int input_forget = 0; int input_forget = 0;
std::string name() const { return "lstm"; } std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -49,13 +49,13 @@ struct rewrite_rnn ...@@ -49,13 +49,13 @@ struct rewrite_rnn
// for lstm operators // for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const; void apply_lstm(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward, std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const; const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const; std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
}; };
......
...@@ -931,7 +931,7 @@ struct onnx_parser ...@@ -931,7 +931,7 @@ struct onnx_parser
{ {
dirct = op::lstm::reverse; dirct = op::lstm::reverse;
} }
else if (direction == "forward") else if(direction == "forward")
{ {
dirct = op::lstm::forward; dirct = op::lstm::forward;
} }
...@@ -958,14 +958,12 @@ struct onnx_parser ...@@ -958,14 +958,12 @@ struct onnx_parser
// use the algorithm that: if 1 actv function is provided, // use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided, // repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once // repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once. // if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions // the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later // provided. This may need change later
switch(vec_names.size()) switch(vec_names.size())
{ {
case 1: case 1: vec_names.insert(vec_names.end(), 5, vec_names.back()); break;
vec_names.insert(vec_names.end(), 5, vec_names.back());
break;
case 2: case 2:
// repeat the 2nd actv func once, then repeat all three another time // repeat the 2nd actv func once, then repeat all three another time
...@@ -978,33 +976,25 @@ struct onnx_parser ...@@ -978,33 +976,25 @@ struct onnx_parser
vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end()); vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end());
break; break;
case 4: case 4: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 5: case 5: vec_names.push_back(vec_names.back()); break;
vec_names.push_back(vec_names.back());
break;
default: default: break;
break;
} }
} }
else else
{ {
switch(vec_names.size()) switch(vec_names.size())
{ {
case 1: case 1: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 2: case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs // repeat the 2nd actv func once, so we have 3 actv funcs
vec_names.push_back(vec_names.back()); vec_names.push_back(vec_names.back());
break; break;
default: default: break;
break;
} }
} }
...@@ -1041,8 +1031,7 @@ struct onnx_parser ...@@ -1041,8 +1031,7 @@ struct onnx_parser
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = prog.add_instruction( auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
std::move(args));
// second output for last lstm output // second output for last lstm output
auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states); auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
......
...@@ -676,13 +676,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -676,13 +676,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const const operation& actv_func3) const
{ {
return {}; return {};
} }
...@@ -694,55 +694,69 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -694,55 +694,69 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// we have 6 actv funcs, even though a user does not // we have 6 actv funcs, even though a user does not
// specifiy any actv func. If less than 46, use the // specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions // algorithm in parse_lstm to make 6 actv functions
const auto &actv_funcs = lstm_op.actv_funcs; const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size(); std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::lstm::bidirectional) if(lstm_op.direction == op::lstm::bidirectional)
{ {
switch(num_actv_funcs) switch(num_actv_funcs)
{ {
case 0: case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}, return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1: case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0), return {actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)}; actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0)};
case 2: case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1), return {actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)}; actv_funcs.at(1),
actv_funcs.at(1),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1)};
case 3: case 3:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2), return {actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2)}; actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2)};
case 4: case 4:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2), return {actv_funcs.at(0),
actv_funcs.at(3), actv_funcs.at(3), actv_funcs.at(3)}; actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(3),
actv_funcs.at(3)};
case 5: case 5:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2), return {actv_funcs.at(0),
actv_funcs.at(3), actv_funcs.at(4), actv_funcs.at(4)}; actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(4)};
default: default: return actv_funcs;
return actv_funcs;
} }
} }
else else
{ {
switch(num_actv_funcs) switch(num_actv_funcs)
{ {
case 0: case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};
return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1: case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2: case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default: return actv_funcs;
default:
return actv_funcs;
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment