Commit 20f89fcc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 0fe4c56b
......@@ -733,8 +733,8 @@ struct onnx_parser
result.push_back(hidden_states);
// second out for the last hidden state
//auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//result.push_back(last_output);
// auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// result.push_back(last_output);
return result;
}
......
......@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward, ih_reverse;
if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
if(args.size() == 6 ||
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{
auto arg_ih = (args.size() == 6) ? args[5] : args[4];
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
......@@ -84,7 +85,8 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse,
rnn_op.actv_funcs.at(1));
last_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
last_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
......@@ -111,7 +113,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state
instruction_ref ih;
if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
if(args.size() == 6 ||
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{
ih = (args.size() == 6) ? args[5] : args[4];
}
......@@ -120,15 +123,8 @@ void rewrite_rnn::apply(program& prog) const
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret = rnn_cell(is_forward,
prog,
ins,
args[0],
w,
r,
bias,
ih,
rnn_op.actv_funcs.at(0));
auto ret = rnn_cell(
is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
last_output = ret[1];
// add the dimension of num_direction
......@@ -140,7 +136,7 @@ void rewrite_rnn::apply(program& prog) const
// operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
//if (ins->name() == "rnn_last_output")
// if (ins->name() == "rnn_last_output")
//{
// // if rnn operator is executed, the last_output != prog.end()
// if (last_output != prog.end())
......@@ -175,13 +171,12 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
if (bias != prog.end())
if(bias != prog.end())
{
long hs = r->get_shape().lens()[2];
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb =
prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment