Commit 80016cff authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 846afb76
...@@ -62,26 +62,20 @@ void rewrite_gru::apply(program& prog) const ...@@ -62,26 +62,20 @@ void rewrite_gru::apply(program& prog) const
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data}); ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
} }
auto ret_forward = gru_cell(true, auto ret_forward =
gru_cell(true,
prog, prog,
ins, ins,
{args[0], {args[0], w_forward, r_forward, bias_forward, ih_forward},
w_forward,
r_forward,
bias_forward,
ih_forward},
gru_op.linear_before_reset, gru_op.linear_before_reset,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1)); actv_funcs.at(1));
auto ret_reverse = gru_cell(false, auto ret_reverse =
gru_cell(false,
prog, prog,
ins, ins,
{args[0], {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
w_reverse,
r_reverse,
bias_reverse,
ih_reverse},
gru_op.linear_before_reset, gru_op.linear_before_reset,
actv_funcs.at(2), actv_funcs.at(2),
actv_funcs.at(3)); actv_funcs.at(3));
...@@ -159,10 +153,11 @@ void rewrite_gru::apply(program& prog) const ...@@ -159,10 +153,11 @@ void rewrite_gru::apply(program& prog) const
// replace the corresponding gru_last_output instruction // replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists // with the last_output, if gru_last_output exists
auto last_output_it = std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) { auto last_output_it =
std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
return i->name() == "gru_last_output"; return i->name() == "gru_last_output";
}); });
if (last_output_it != ins->outputs().end()) if(last_output_it != ins->outputs().end())
{ {
prog.replace_instruction(*last_output_it, last_output); prog.replace_instruction(*last_output_it, last_output);
} }
......
...@@ -645,8 +645,10 @@ TEST_CASE(gru_test) ...@@ -645,8 +645,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -677,8 +679,10 @@ TEST_CASE(gru_test) ...@@ -677,8 +679,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -709,18 +713,21 @@ TEST_CASE(gru_test) ...@@ -709,18 +713,21 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(migraphx::op::gru{hs,
p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{},
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::tanh{}}, migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::gru::bidirectional,
clip}, clip},
seq, seq,
...@@ -742,8 +749,10 @@ TEST_CASE(gru_test) ...@@ -742,8 +749,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
...@@ -769,8 +778,10 @@ TEST_CASE(gru_test) ...@@ -769,8 +778,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
...@@ -799,8 +810,10 @@ TEST_CASE(gru_test) ...@@ -799,8 +810,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -831,8 +844,10 @@ TEST_CASE(gru_test) ...@@ -831,8 +844,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -840,10 +855,7 @@ TEST_CASE(gru_test) ...@@ -840,10 +855,7 @@ TEST_CASE(gru_test)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip},
{},
migraphx::op::gru::bidirectional,
clip},
seq, seq,
w, w,
r, r,
...@@ -863,19 +875,18 @@ TEST_CASE(gru_test) ...@@ -863,19 +875,18 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
{migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
clip},
seq, seq,
w, w,
r, r,
...@@ -895,8 +906,10 @@ TEST_CASE(gru_test) ...@@ -895,8 +906,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -927,16 +940,18 @@ TEST_CASE(gru_test) ...@@ -927,16 +940,18 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::gru::bidirectional,
clip}, clip},
...@@ -959,19 +974,17 @@ TEST_CASE(gru_test) ...@@ -959,19 +974,17 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip},
p.add_instruction(migraphx::op::gru{hs,
{},
migraphx::op::gru::forward,
clip},
seq, seq,
w, w,
r, r,
...@@ -991,19 +1004,18 @@ TEST_CASE(gru_test) ...@@ -991,19 +1004,18 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip},
{migraphx::op::relu{}},
migraphx::op::gru::reverse,
clip},
seq, seq,
w, w,
r, r,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment