Commit 977f032b authored by Shucai Xiao

clang format

parent 3c7b6d27
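
This is a whitespace-only reformat of the RNN rewrite pass; no logic changes. For reference, below is a minimal sketch of the kind of .clang-format options that would produce the style visible in the hunks. The values are assumptions for illustration, not the repository's actual configuration file:

    # Hypothetical .clang-format excerpt (assumed values, shown for illustration)
    Language:              Cpp
    ColumnLimit:           100   # long calls reflow, e.g. the replace_instruction lines
    PointerAlignment:      Left  # 'program& prog' rather than 'program &prog'
    AlignAfterOpenBracket: Align # gru_cell(...) arguments align under the open paren

With a config along these lines, running clang-format -i over the file yields the kind of changes shown: '&' binds to the type, long assignments break after '=' instead of inside the call, and call arguments line up under the opening parenthesis.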
@@ -8,7 +8,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-void rewrite_rnn::apply(program &prog) const
+void rewrite_rnn::apply(program& prog) const
 {
     for(auto ins : iterator_for(prog))
     {
@@ -109,8 +109,8 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
         instruction_ref hidden_output{};
         if(ret_forward[0] == prog.end())
         {
-            hidden_output = prog.replace_instruction(
-                ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            hidden_output =
+                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
         }
         else
         {
@@ -118,8 +118,8 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
                 prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
             ret_reverse[0] =
                 prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-            hidden_output = prog.replace_instruction(
-                ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            hidden_output =
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
     }
     else
@@ -149,8 +149,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
             ih = prog.add_literal(migraphx::literal{ih_shape, data});
         }
-        auto ret =
-            rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
+        auto ret = rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
         last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
         // following logic is to ensure the last instruction is a
@@ -165,8 +164,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_output =
-                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
         }
     }
@@ -365,23 +363,21 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
             ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
         }
-        auto ret_forward =
-            gru_cell(true,
-                     prog,
-                     ins,
-                     {args[0], w_forward, r_forward, bias_forward, ih_forward},
-                     gru_op.linear_before_reset,
-                     actv_funcs.at(0),
-                     actv_funcs.at(1));
-        auto ret_reverse =
-            gru_cell(false,
-                     prog,
-                     ins,
-                     {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
-                     gru_op.linear_before_reset,
-                     actv_funcs.at(2),
-                     actv_funcs.at(3));
+        auto ret_forward = gru_cell(true,
+                                    prog,
+                                    ins,
+                                    {args[0], w_forward, r_forward, bias_forward, ih_forward},
+                                    gru_op.linear_before_reset,
+                                    actv_funcs.at(0),
+                                    actv_funcs.at(1));
+        auto ret_reverse = gru_cell(false,
+                                    prog,
+                                    ins,
+                                    {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
+                                    gru_op.linear_before_reset,
+                                    actv_funcs.at(2),
+                                    actv_funcs.at(3));
         auto concat_output =
             prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
@@ -392,8 +388,8 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
         instruction_ref hidden_state{};
         if(ret_forward[0] == prog.end())
         {
-            hidden_state = prog.replace_instruction(
-                ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            hidden_state =
+                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
         }
         else
         {
@@ -401,8 +397,8 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
                 prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
             ret_reverse[0] =
                 prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-            hidden_state = prog.replace_instruction(
-                ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            hidden_state =
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
     }
     else
@@ -449,8 +445,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_state =
-                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
         }
     }
@@ -475,12 +470,12 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
 }
 std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
-                                                  program& prog,
-                                                  instruction_ref ins,
-                                                  std::vector<instruction_ref> inputs,
-                                                  int linear_before_reset,
-                                                  const operation& actv_func1,
-                                                  const operation& actv_func2) const
+                                                   program& prog,
+                                                   instruction_ref ins,
+                                                   std::vector<instruction_ref> inputs,
+                                                   int linear_before_reset,
+                                                   const operation& actv_func1,
+                                                   const operation& actv_func2) const
 {
     assert(inputs.size() == 5);
     auto seq = inputs.at(0);