Commit e41908ad authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format.

parent 22f8a479
...@@ -135,8 +135,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -135,8 +135,8 @@ void rewrite_rnn::apply(program& prog) const
ih = prog.add_literal(migraphx::literal{ih_shape, data}); ih = prog.add_literal(migraphx::literal{ih_shape, data});
} }
auto ret = rnn_cell( auto ret =
is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0)); rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
...@@ -264,21 +264,20 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -264,21 +264,20 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
return {hidden_out, last_out}; return {hidden_out, last_out};
} }
std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
{ {
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
// before rewrite the rnn operator, need to ensure // before rewrite the rnn operator, need to ensure
// we have 2 actv funcs. If less than 2, use the // we have 2 actv funcs. If less than 2, use the
// algorithm in parse_rnn to make 2 actv functions // algorithm in parse_rnn to make 2 actv functions
if (rnn_op.direction == op::rnn::bidirectional) if(rnn_op.direction == op::rnn::bidirectional)
{ {
if (rnn_op.actv_funcs.size() == 0) if(rnn_op.actv_funcs.size() == 0)
{ {
// default is tanh // default is tanh
return {op::tanh{}, op::tanh{}}; return {op::tanh{}, op::tanh{}};
} }
else if (rnn_op.actv_funcs.size() == 1) else if(rnn_op.actv_funcs.size() == 1)
{ {
return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)}; return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
} }
...@@ -289,7 +288,7 @@ std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) cons ...@@ -289,7 +288,7 @@ std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) cons
} }
else else
{ {
if (rnn_op.actv_funcs.size() == 0) if(rnn_op.actv_funcs.size() == 0)
{ {
// default is tanh // default is tanh
return {op::tanh{}}; return {op::tanh{}};
......
...@@ -1458,10 +1458,7 @@ TEST_CASE(rnn_forward) ...@@ -1458,10 +1458,7 @@ TEST_CASE(rnn_forward)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
{},
migraphx::op::rnn::forward,
clip},
seq, seq,
w, w,
r, r,
...@@ -1598,10 +1595,7 @@ TEST_CASE(rnn_reverse) ...@@ -1598,10 +1595,7 @@ TEST_CASE(rnn_reverse)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::reverse, clip},
{},
migraphx::op::rnn::reverse,
clip},
seq, seq,
w, w,
r, r,
...@@ -1723,10 +1717,8 @@ TEST_CASE(rnn_bidirectional) ...@@ -1723,10 +1717,8 @@ TEST_CASE(rnn_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(
{}, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::bidirectional, clip},
migraphx::op::rnn::bidirectional,
clip},
seq, seq,
w, w,
r, r,
...@@ -1774,11 +1766,9 @@ TEST_CASE(rnn_bidirectional) ...@@ -1774,11 +1766,9 @@ TEST_CASE(rnn_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::rnn{hidden_size, migraphx::op::rnn{
{migraphx::op::tanh{}}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
migraphx::op::rnn::bidirectional,
clip},
seq, seq,
w, w,
r, r,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment