"results/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "379f31a459906d8857969fadf707be4b8f4e4629"
Commit 977f032b authored by Shucai Xiao

clang format

parent 3c7b6d27
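Note: this commit is a pure clang-format reflow of the RNN rewrite pass; no logic changes. As a rough sketch, a .clang-format along these lines would produce the layout seen in the new version (left-bound references, wrapped call arguments aligned under the opening parenthesis, a roughly 100-column limit); the repository's actual configuration may differ:

    # .clang-format (illustrative sketch, not the repository's actual file)
    BasedOnStyle: LLVM
    IndentWidth: 4
    ColumnLimit: 100
    PointerAlignment: Left        # "program& prog" rather than "program &prog"
    AlignAfterOpenBracket: Align  # wrapped arguments line up under the open paren

A commit like this is typically produced by running clang-format -i on the touched file, or git clang-format against the parent commit.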
@@ -8,7 +8,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-void rewrite_rnn::apply(program &prog) const
+void rewrite_rnn::apply(program& prog) const
 {
     for(auto ins : iterator_for(prog))
     {
@@ -109,8 +109,8 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
         instruction_ref hidden_output{};
         if(ret_forward[0] == prog.end())
         {
-            hidden_output = prog.replace_instruction(
-                ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            hidden_output =
+                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
         }
         else
         {
@@ -118,8 +118,8 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
                 prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
             ret_reverse[0] =
                 prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-            hidden_output = prog.replace_instruction(
-                ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            hidden_output =
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
     }
     else
@@ -149,8 +149,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
             ih = prog.add_literal(migraphx::literal{ih_shape, data});
         }
-        auto ret =
-            rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
+        auto ret = rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
         last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
         // following logic is to ensure the last instruction is a
@@ -165,8 +164,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_output =
-                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
         }
     }
@@ -365,23 +363,21 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
             ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
         }
-        auto ret_forward =
-            gru_cell(true,
-                     prog,
-                     ins,
-                     {args[0], w_forward, r_forward, bias_forward, ih_forward},
-                     gru_op.linear_before_reset,
-                     actv_funcs.at(0),
-                     actv_funcs.at(1));
-
-        auto ret_reverse =
-            gru_cell(false,
-                     prog,
-                     ins,
-                     {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
-                     gru_op.linear_before_reset,
-                     actv_funcs.at(2),
-                     actv_funcs.at(3));
+        auto ret_forward = gru_cell(true,
+                                    prog,
+                                    ins,
+                                    {args[0], w_forward, r_forward, bias_forward, ih_forward},
+                                    gru_op.linear_before_reset,
+                                    actv_funcs.at(0),
+                                    actv_funcs.at(1));
+        auto ret_reverse = gru_cell(false,
+                                    prog,
+                                    ins,
+                                    {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
+                                    gru_op.linear_before_reset,
+                                    actv_funcs.at(2),
+                                    actv_funcs.at(3));
         auto concat_output =
             prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
@@ -392,8 +388,8 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
         instruction_ref hidden_state{};
         if(ret_forward[0] == prog.end())
         {
-            hidden_state = prog.replace_instruction(
-                ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            hidden_state =
+                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
         }
         else
         {
@@ -401,8 +397,8 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
                 prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
             ret_reverse[0] =
                 prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-            hidden_state = prog.replace_instruction(
-                ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            hidden_state =
+                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
     }
     else
@@ -449,8 +445,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_state =
-                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
         }
     }
@@ -475,12 +470,12 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
 }
 std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                    program& prog,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    int linear_before_reset,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2) const
 {
     assert(inputs.size() == 5);
     auto seq = inputs.at(0);
......