Commit 80016cff authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 846afb76
......@@ -62,29 +62,23 @@ void rewrite_gru::apply(program& prog) const
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = gru_cell(true,
prog,
ins,
{args[0],
w_forward,
r_forward,
bias_forward,
ih_forward},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
auto ret_reverse = gru_cell(false,
prog,
ins,
{args[0],
w_reverse,
r_reverse,
bias_reverse,
ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto ret_forward =
gru_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
auto ret_reverse =
gru_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
......@@ -159,10 +153,11 @@ void rewrite_gru::apply(program& prog) const
// replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists
auto last_output_it = std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
return i->name() == "gru_last_output";
});
if (last_output_it != ins->outputs().end())
auto last_output_it =
std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
return i->name() == "gru_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
}
......@@ -179,11 +174,11 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
const operation& actv_func2) const
{
assert(inputs.size() == 5);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
auto ih = inputs.at(4);
instruction_ref hidden_states = prog.end(), last_output;
long seq_len = static_cast<long>(seq->get_shape().lens()[0]);
......
......@@ -645,8 +645,10 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
......@@ -677,8 +679,10 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
......@@ -709,26 +713,29 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs = p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
......@@ -742,8 +749,10 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
......@@ -769,8 +778,10 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
......@@ -799,8 +810,10 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
......@@ -831,8 +844,10 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
......@@ -840,10 +855,7 @@ TEST_CASE(gru_test)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{},
migraphx::op::gru::bidirectional,
clip},
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip},
seq,
w,
r,
......@@ -863,25 +875,24 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
......@@ -895,8 +906,10 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
......@@ -927,25 +940,27 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
......@@ -959,25 +974,23 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{},
migraphx::op::gru::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
......@@ -991,25 +1004,24 @@ TEST_CASE(gru_test)
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}},
migraphx::op::gru::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment