Commit 483c4508 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent f8c319e3
......@@ -1268,8 +1268,8 @@ struct lstm
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
lstm_direction_t direction = forward;
float clip = 0.0f;
int input_forget = 0;
float clip = 0.0f;
int input_forget = 0;
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -49,13 +49,13 @@ struct rewrite_rnn
// for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const;
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
};
......
......@@ -931,7 +931,7 @@ struct onnx_parser
{
dirct = op::lstm::reverse;
}
else if (direction == "forward")
else if(direction == "forward")
{
dirct = op::lstm::forward;
}
......@@ -958,14 +958,12 @@ struct onnx_parser
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch(vec_names.size())
{
case 1:
vec_names.insert(vec_names.end(), 5, vec_names.back());
break;
case 1: vec_names.insert(vec_names.end(), 5, vec_names.back()); break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
......@@ -978,33 +976,25 @@ struct onnx_parser
vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end());
break;
case 4:
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 4: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
case 5:
vec_names.push_back(vec_names.back());
break;
case 5: vec_names.push_back(vec_names.back()); break;
default:
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1:
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 1: vec_names.insert(vec_names.end(), 2, vec_names.back()); break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names.push_back(vec_names.back());
break;
default:
break;
default: break;
}
}
......@@ -1041,8 +1031,7 @@ struct onnx_parser
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget},
std::move(args));
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
......
......@@ -676,13 +676,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const
{
return {};
}
......@@ -694,55 +694,69 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// we have 6 actv funcs, even though a user does not
// specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions
const auto &actv_funcs = lstm_op.actv_funcs;
const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::lstm::bidirectional)
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{},
op::sigmoid{}, op::tanh{}, op::tanh{}};
return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
return {actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1)};
case 3:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2)};
case 4:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(3), actv_funcs.at(3), actv_funcs.at(3)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(3),
actv_funcs.at(3)};
case 5:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(3), actv_funcs.at(4), actv_funcs.at(4)};
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(4)};
default:
return actv_funcs;
default: return actv_funcs;
}
}
else
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default:
return actv_funcs;
case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default: return actv_funcs;
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment