Commit 5fe0c226 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 9124c4a1
...@@ -446,23 +446,33 @@ TEST_CASE(rnn_test) ...@@ -446,23 +446,33 @@ TEST_CASE(rnn_test)
std::size_t hs = 20; // hidden size std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size std::size_t is = 10; // input size
std::size_t nd = 2; // num directions std::size_t nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
// bidirectional // bidirectional
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}}); auto bias =
auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::rnn{hs, auto out_hs =
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, p.add_instruction(migraphx::op::rnn{hs,
migraphx::op::rnn::bidirectional, clip}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
seq, w, r, bias, seq_len, ih); migraphx::op::rnn::bidirectional,
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -473,18 +483,28 @@ TEST_CASE(rnn_test) ...@@ -473,18 +483,28 @@ TEST_CASE(rnn_test)
nd = 1; nd = 1;
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}}); auto bias =
auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::rnn{hs, auto out_hs =
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, p.add_instruction(migraphx::op::rnn{hs,
migraphx::op::rnn::forward, clip}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
seq, w, r, bias, seq_len, ih); migraphx::op::rnn::forward,
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -495,18 +515,28 @@ TEST_CASE(rnn_test) ...@@ -495,18 +515,28 @@ TEST_CASE(rnn_test)
nd = 1; nd = 1;
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}}); auto bias =
auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::rnn{hs, auto out_hs =
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, p.add_instruction(migraphx::op::rnn{hs,
migraphx::op::rnn::reverse, clip}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
seq, w, r, bias, seq_len, ih); migraphx::op::rnn::reverse,
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -517,15 +547,20 @@ TEST_CASE(rnn_test) ...@@ -517,15 +547,20 @@ TEST_CASE(rnn_test)
nd = 1; nd = 1;
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto out_hs = p.add_instruction(migraphx::op::rnn{hs, auto out_hs =
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, p.add_instruction(migraphx::op::rnn{hs,
migraphx::op::rnn::reverse, clip}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
seq, w, r); migraphx::op::rnn::reverse,
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -536,17 +571,25 @@ TEST_CASE(rnn_test) ...@@ -536,17 +571,25 @@ TEST_CASE(rnn_test)
nd = 1; nd = 1;
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}}); auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::rnn{hs, auto out_hs =
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, p.add_instruction(migraphx::op::rnn{hs,
migraphx::op::rnn::reverse, clip}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
seq, w, r, bias, ih); migraphx::op::rnn::reverse,
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); clip},
seq,
w,
r,
bias,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment