Commit e238ad93 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 97568d53
...@@ -682,7 +682,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -682,7 +682,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
migraphx::shape pph_shape{type, {1, 3 * hidden_size}}; migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
std::vector<float> pph_data(pph_shape.elements(), 0.0); std::vector<float> pph_data(pph_shape.elements(), 0.0);
auto actv_funcs = lstm_actv_funcs(ins); auto actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator()); auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::lstm::lstm_direction_t dirct = lstm_op.direction; op::lstm::lstm_direction_t dirct = lstm_op.direction;
...@@ -802,14 +802,14 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -802,14 +802,14 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// bias // bias
instruction_ref bias = prog.end(); instruction_ref bias = prog.end();
if (args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias = args[3]; bias = args[3];
} }
// initial hidden state // initial hidden state
instruction_ref ih{}; instruction_ref ih{};
if (args.size() >= 6 && args[5]->name() != "undefined") if(args.size() >= 6 && args[5]->name() != "undefined")
{ {
ih = args[5]; ih = args[5];
} }
...@@ -820,7 +820,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -820,7 +820,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// initial cell value // initial cell value
instruction_ref ic{}; instruction_ref ic{};
if (args.size() >= 7 && args[6]->name() != "undefined") if(args.size() >= 7 && args[6]->name() != "undefined")
{ {
ic = args[6]; ic = args[6];
} }
...@@ -840,8 +840,8 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -840,8 +840,8 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{ {
pph = prog.add_literal(migraphx::literal{pph_shape, pph_data}); pph = prog.add_literal(migraphx::literal{pph_shape, pph_data});
} }
auto ret = lstm_cell(is_forward, auto ret = lstm_cell(is_forward,
prog, prog,
ins, ins,
{args[0], w, r, bias, ih, ic, pph}, {args[0], w, r, bias, ih, ic, pph},
...@@ -850,9 +850,9 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -850,9 +850,9 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[2]); last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[2]);
if (ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{0}, ret[1]); prog.replace_instruction(ins, op::concat{0}, ret[1]);
} }
...@@ -866,7 +866,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -866,7 +866,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// replace the corresponding lstm_last_output instruction // replace the corresponding lstm_last_output instruction
// with the last_output, and the lstm_last_cell_output with // with the last_output, and the lstm_last_cell_output with
// the last_cell_output. The while loop is to handle the case // the last_cell_output. The while loop is to handle the case
// of multiple lstm_last_output and lstm_last_cell_output // of multiple lstm_last_output and lstm_last_cell_output
// operators // operators
auto last_output_it = ins->outputs().begin(); auto last_output_it = ins->outputs().begin();
...@@ -909,17 +909,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -909,17 +909,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{ {
// must have 7 args in the input vector // must have 7 args in the input vector
assert(inputs.size() == 7); assert(inputs.size() == 7);
auto seq = inputs.at(0); auto seq = inputs.at(0);
auto w = inputs.at(1); auto w = inputs.at(1);
auto r = inputs.at(2); auto r = inputs.at(2);
auto bias = inputs.at(3); auto bias = inputs.at(3);
auto ih = inputs.at(4); auto ih = inputs.at(4);
auto ic = inputs.at(5); auto ic = inputs.at(5);
auto pph = inputs.at(6); auto pph = inputs.at(6);
instruction_ref instruction_ref
return {}; return {};
} }
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment