Commit 6d8fcb3d authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 233f3bcc
...@@ -685,13 +685,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -685,13 +685,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
migraphx::shape pph_shape{type, {1, 3 * hidden_size}}; migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
std::vector<float> ppl_data(pph_shape.elements(), 0.0); std::vector<float> ppl_data(pph_shape.elements(), 0.0);
auto &actv_funcs = lstm_actv_funcs(ins); auto& actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator()); auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::lstm::lstm_direction_t dirct = lstm_op.direction; op::lstm::lstm_direction_t dirct = lstm_op.direction;
instruction_ref last_output{}; instruction_ref last_output{};
instruction_ref last_cell_output{}; instruction_ref last_cell_output{};
if (dirct == op::lstm::bidirectional) if(dirct == op::lstm::bidirectional)
{ {
// input weight matrix // input weight matrix
// input weight matrix // input weight matrix
...@@ -705,7 +705,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -705,7 +705,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// process bias // process bias
instruction_ref bias_forward = prog.end(); instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if (args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
...@@ -728,7 +728,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -728,7 +728,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// process initial cell value // process initial cell value
instruction_ref ic_forward{}; instruction_ref ic_forward{};
instruction_ref ic_reverse{}; instruction_ref ic_reverse{};
if (args.size() >= 7 && args[6]->name() != "undefined") if(args.size() >= 7 && args[6]->name() != "undefined")
{ {
ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]); ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]); ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
...@@ -742,7 +742,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -742,7 +742,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// process weight of the peephole // process weight of the peephole
instruction_ref pph_forward{}; instruction_ref pph_forward{};
instruction_ref pph_reverse{}; instruction_ref pph_reverse{};
if (args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]); pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]); pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
...@@ -753,43 +753,46 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -753,43 +753,46 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
pph_reverse = prog.add_literal(migraphx::literal{pph_shape, pph_data}); pph_reverse = prog.add_literal(migraphx::literal{pph_shape, pph_data});
} }
auto ret_forward = lstm_cell(true, prog, ins, auto ret_forward = lstm_cell(
{args[0], w_forward, r_forward, bias_forward, true,
ih_forward, ic_forward, pph_forward}, prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
lstm_op.input_forget, lstm_op.input_forget,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
auto ret_reverse = lstm_cell(false, prog, ins, auto ret_reverse = lstm_cell(
{args[0], w_reverse, r_reverse, bias_reverse, false,
ih_reverse, ic_reverse, pph_reverse}, prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
lstm_op.input_forget, lstm_op.input_forget,
actv_funcs.at(3), actv_funcs.at(3),
actv_funcs.at(4), actv_funcs.at(4),
actv_funcs.at(5)); actv_funcs.at(5));
auto concat_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// last cell output // last cell output
auto concat_cell_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[2], ret_reverse[2]); auto concat_cell_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[2], ret_reverse[2]);
last_cell_output = prog.insert_instruction(ins, squeeze{{0}}, concat_cell_output); last_cell_output = prog.insert_instruction(ins, squeeze{{0}}, concat_cell_output);
// the following logic is to ensure the last instruction is a concat // the following logic is to ensure the last instruction is a concat
if (ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
} }
else else
{ {
} }
} }
else else
{ {
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment