"src/targets/gpu/quant_gemm.cpp" did not exist on "61566cb13691a60a3f59c714ec37d942f98f36bf"
Commit b8478c6f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 9025504f
......@@ -24,14 +24,14 @@ struct rewrite_rnn
// for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const;
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
......
......@@ -79,23 +79,23 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
auto ret_forward = vanilla_rnn_cell(true,
prog,
ins,
args[0],
w_forward,
r_forward,
bias_forward,
ih_forward,
actv_funcs.at(0));
prog,
ins,
args[0],
w_forward,
r_forward,
bias_forward,
ih_forward,
actv_funcs.at(0));
auto ret_reverse = vanilla_rnn_cell(false,
prog,
ins,
args[0],
w_reverse,
r_reverse,
bias_reverse,
ih_reverse,
actv_funcs.at(1));
prog,
ins,
args[0],
w_reverse,
r_reverse,
bias_reverse,
ih_reverse,
actv_funcs.at(1));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
......@@ -147,7 +147,8 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret = vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
auto ret =
vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
......@@ -184,14 +185,14 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const
{
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment