Commit e238ad93 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 97568d53
......@@ -682,7 +682,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
std::vector<float> pph_data(pph_shape.elements(), 0.0);
auto actv_funcs = lstm_actv_funcs(ins);
auto actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::lstm::lstm_direction_t dirct = lstm_op.direction;
......@@ -802,14 +802,14 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// bias
instruction_ref bias = prog.end();
if (args.size() >= 4 && args[3]->name() != "undefined")
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// initial hidden state
instruction_ref ih{};
if (args.size() >= 6 && args[5]->name() != "undefined")
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
......@@ -820,7 +820,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// initial cell value
instruction_ref ic{};
if (args.size() >= 7 && args[6]->name() != "undefined")
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic = args[6];
}
......@@ -840,8 +840,8 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
pph = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret = lstm_cell(is_forward,
auto ret = lstm_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih, ic, pph},
......@@ -850,9 +850,9 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
actv_funcs.at(1),
actv_funcs.at(2));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[2]);
if (ret[0] == prog.end())
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
......@@ -866,7 +866,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// replace the corresponding lstm_last_output instruction
// with the last_output, and the lstm_last_cell_output with
// the last_cell_output. The while loop is to handle the case
// the last_cell_output. The while loop is to handle the case
// of multiple lstm_last_output and lstm_last_cell_output
// operators
auto last_output_it = ins->outputs().begin();
......@@ -909,17 +909,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{
// must have 7 args in the input vector
assert(inputs.size() == 7);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
auto ic = inputs.at(5);
auto pph = inputs.at(6);
auto ih = inputs.at(4);
auto ic = inputs.at(5);
auto pph = inputs.at(6);
instruction_ref
instruction_ref
return {};
return {};
}
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment