Commit b8478c6f authored by Shucai Xiao

clang format

parent 9025504f
@@ -24,14 +24,14 @@ struct rewrite_rnn
    // for vanilla rnn operators
    void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
                                                  program& prog,
                                                  instruction_ref ins,
                                                  instruction_ref input,
                                                  instruction_ref w,
                                                  instruction_ref r,
                                                  instruction_ref bias,
                                                  instruction_ref ih,
                                                  operation& actv_func) const;
    std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
    // for gru operators
...
@@ -79,23 +79,23 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
    }
    auto ret_forward = vanilla_rnn_cell(true,
                                        prog,
                                        ins,
                                        args[0],
                                        w_forward,
                                        r_forward,
                                        bias_forward,
                                        ih_forward,
                                        actv_funcs.at(0));
    auto ret_reverse = vanilla_rnn_cell(false,
                                        prog,
                                        ins,
                                        args[0],
                                        w_reverse,
                                        r_reverse,
                                        bias_reverse,
                                        ih_reverse,
                                        actv_funcs.at(1));
    auto concat_output =
        prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
@@ -147,7 +147,8 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
        ih = prog.add_literal(migraphx::literal{ih_shape, data});
    }
-   auto ret = vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
+   auto ret =
+       vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
    last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
    // following logic is to ensure the last instruction is a
@@ -184,14 +185,14 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
{
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
...