Commit 80016cff authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 846afb76
...@@ -62,29 +62,23 @@ void rewrite_gru::apply(program& prog) const ...@@ -62,29 +62,23 @@ void rewrite_gru::apply(program& prog) const
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data}); ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
} }
auto ret_forward = gru_cell(true, auto ret_forward =
prog, gru_cell(true,
ins, prog,
{args[0], ins,
w_forward, {args[0], w_forward, r_forward, bias_forward, ih_forward},
r_forward, gru_op.linear_before_reset,
bias_forward, actv_funcs.at(0),
ih_forward}, actv_funcs.at(1));
gru_op.linear_before_reset,
actv_funcs.at(0), auto ret_reverse =
actv_funcs.at(1)); gru_cell(false,
prog,
auto ret_reverse = gru_cell(false, ins,
prog, {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
ins, gru_op.linear_before_reset,
{args[0], actv_funcs.at(2),
w_reverse, actv_funcs.at(3));
r_reverse,
bias_reverse,
ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output = auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
...@@ -159,10 +153,11 @@ void rewrite_gru::apply(program& prog) const ...@@ -159,10 +153,11 @@ void rewrite_gru::apply(program& prog) const
// replace the corresponding gru_last_output instruction // replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists // with the last_output, if gru_last_output exists
auto last_output_it = std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) { auto last_output_it =
return i->name() == "gru_last_output"; std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
}); return i->name() == "gru_last_output";
if (last_output_it != ins->outputs().end()) });
if(last_output_it != ins->outputs().end())
{ {
prog.replace_instruction(*last_output_it, last_output); prog.replace_instruction(*last_output_it, last_output);
} }
...@@ -179,11 +174,11 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -179,11 +174,11 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
const operation& actv_func2) const const operation& actv_func2) const
{ {
assert(inputs.size() == 5); assert(inputs.size() == 5);
auto seq = inputs.at(0); auto seq = inputs.at(0);
auto w = inputs.at(1); auto w = inputs.at(1);
auto r = inputs.at(2); auto r = inputs.at(2);
auto bias = inputs.at(3); auto bias = inputs.at(3);
auto ih = inputs.at(4); auto ih = inputs.at(4);
instruction_ref hidden_states = prog.end(), last_output; instruction_ref hidden_states = prog.end(), last_output;
long seq_len = static_cast<long>(seq->get_shape().lens()[0]); long seq_len = static_cast<long>(seq->get_shape().lens()[0]);
......
...@@ -645,8 +645,10 @@ TEST_CASE(gru_test) ...@@ -645,8 +645,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -677,8 +679,10 @@ TEST_CASE(gru_test) ...@@ -677,8 +679,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -709,26 +713,29 @@ TEST_CASE(gru_test) ...@@ -709,26 +713,29 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(migraphx::op::gru{hs,
p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{},
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::tanh{}}, migraphx::op::relu{},
migraphx::op::gru::bidirectional, migraphx::op::tanh{}},
clip}, migraphx::op::gru::bidirectional,
seq, clip},
w, seq,
r, w,
bias, r,
seq_len, bias,
ih); seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
...@@ -742,8 +749,10 @@ TEST_CASE(gru_test) ...@@ -742,8 +749,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
...@@ -769,8 +778,10 @@ TEST_CASE(gru_test) ...@@ -769,8 +778,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
...@@ -799,8 +810,10 @@ TEST_CASE(gru_test) ...@@ -799,8 +810,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -831,8 +844,10 @@ TEST_CASE(gru_test) ...@@ -831,8 +844,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -840,10 +855,7 @@ TEST_CASE(gru_test) ...@@ -840,10 +855,7 @@ TEST_CASE(gru_test)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip},
{},
migraphx::op::gru::bidirectional,
clip},
seq, seq,
w, w,
r, r,
...@@ -863,25 +875,24 @@ TEST_CASE(gru_test) ...@@ -863,25 +875,24 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
{migraphx::op::tanh{}}, seq,
migraphx::op::gru::bidirectional, w,
clip}, r,
seq, bias,
w, seq_len,
r, ih);
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
...@@ -895,8 +906,10 @@ TEST_CASE(gru_test) ...@@ -895,8 +906,10 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
...@@ -927,25 +940,27 @@ TEST_CASE(gru_test) ...@@ -927,25 +940,27 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::gru::bidirectional,
clip}, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
...@@ -959,25 +974,23 @@ TEST_CASE(gru_test) ...@@ -959,25 +974,23 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip},
p.add_instruction(migraphx::op::gru{hs, seq,
{}, w,
migraphx::op::gru::forward, r,
clip}, bias,
seq, seq_len,
w, ih);
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
...@@ -991,25 +1004,24 @@ TEST_CASE(gru_test) ...@@ -991,25 +1004,24 @@ TEST_CASE(gru_test)
auto seq = auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto w =
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len = auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip},
{migraphx::op::relu{}}, seq,
migraphx::op::gru::reverse, w,
clip}, r,
seq, bias,
w, seq_len,
r, ih);
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment