Commit 977f032b authored by Shucai Xiao

clang format

parent 3c7b6d27
@@ -8,7 +8,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-void rewrite_rnn::apply(program &prog) const
+void rewrite_rnn::apply(program& prog) const
 {
     for(auto ins : iterator_for(prog))
     {
@@ -109,8 +109,8 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
         instruction_ref hidden_output{};
         if(ret_forward[0] == prog.end())
         {
-            hidden_output = prog.replace_instruction(
-                ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            hidden_output =
+                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
         }
         else
         {
@@ -118,8 +118,8 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
                 prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
             ret_reverse[0] =
                 prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-            hidden_output = prog.replace_instruction(
-                ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            hidden_output =
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
     }
     else
@@ -149,8 +149,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
             ih = prog.add_literal(migraphx::literal{ih_shape, data});
         }
-        auto ret =
-            rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
+        auto ret = rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
         last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
         // following logic is to ensure the last instruction is a
@@ -165,8 +164,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_output =
-                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
         }
     }
@@ -365,8 +363,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
             ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
         }
-        auto ret_forward =
-            gru_cell(true,
+        auto ret_forward = gru_cell(true,
                                     prog,
                                     ins,
                                     {args[0], w_forward, r_forward, bias_forward, ih_forward},
@@ -374,8 +371,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
                                     actv_funcs.at(0),
                                     actv_funcs.at(1));
-        auto ret_reverse =
-            gru_cell(false,
+        auto ret_reverse = gru_cell(false,
                                     prog,
                                     ins,
                                     {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
@@ -392,8 +388,8 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
         instruction_ref hidden_state{};
         if(ret_forward[0] == prog.end())
         {
-            hidden_state = prog.replace_instruction(
-                ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            hidden_state =
+                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
         }
         else
         {
@@ -401,8 +397,8 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
                 prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
             ret_reverse[0] =
                 prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-            hidden_state = prog.replace_instruction(
-                ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            hidden_state =
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
     }
     else
@@ -449,8 +445,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_state =
-                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
         }
     }