Commit 4089bcb6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent dc57c9c1
...@@ -24,12 +24,12 @@ TEST_CASE(rnn_test_bidirectional) ...@@ -24,12 +24,12 @@ TEST_CASE(rnn_test_bidirectional)
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -66,12 +66,12 @@ TEST_CASE(rnn_test_one) ...@@ -66,12 +66,12 @@ TEST_CASE(rnn_test_one)
// forward // forward
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -93,12 +93,12 @@ TEST_CASE(rnn_test_one) ...@@ -93,12 +93,12 @@ TEST_CASE(rnn_test_one)
// reverse // reverse
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
...@@ -120,8 +120,8 @@ TEST_CASE(rnn_test_one) ...@@ -120,8 +120,8 @@ TEST_CASE(rnn_test_one)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -144,12 +144,12 @@ TEST_CASE(rnn_test_one) ...@@ -144,12 +144,12 @@ TEST_CASE(rnn_test_one)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -594,44 +594,46 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -594,44 +594,46 @@ TEST_CASE(gru_test_actv_funcs)
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size std::size_t is = 10; // input size
std::size_t nd = 1; // num directions std::size_t nd = 1; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4*hs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4*hs, hs}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3*hs}}; migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape); auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::lstm{hs, migraphx::op::lstm{
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, hs,
migraphx::op::rnn_direction::forward, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
clip, input_forget}, migraphx::op::rnn_direction::forward,
seq, clip,
w, input_forget},
r, seq,
bias, w,
seq_len, r,
ih, bias,
ic, seq_len,
pph); ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx");
...@@ -642,44 +644,46 @@ TEST_CASE(lstm_forward) ...@@ -642,44 +644,46 @@ TEST_CASE(lstm_forward)
TEST_CASE(lstm_reverse) TEST_CASE(lstm_reverse)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size std::size_t is = 10; // input size
std::size_t nd = 1; // num directions std::size_t nd = 1; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4*hs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4*hs, hs}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3*hs}}; migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape); auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::lstm{hs, migraphx::op::lstm{
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, hs,
migraphx::op::rnn_direction::reverse, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
clip, input_forget}, migraphx::op::rnn_direction::reverse,
seq, clip,
w, input_forget},
r, seq,
bias, w,
seq_len, r,
ih, bias,
ic, seq_len,
pph); ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx");
...@@ -690,44 +694,46 @@ TEST_CASE(lstm_reverse) ...@@ -690,44 +694,46 @@ TEST_CASE(lstm_reverse)
TEST_CASE(lstm_bidirectional) TEST_CASE(lstm_bidirectional)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size std::size_t is = 10; // input size
std::size_t nd = 2; // num directions std::size_t nd = 2; // num directions
float clip = 0.0f; float clip = 0.0f;
int input_forget = 1; int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4*hs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4*hs, hs}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3*hs}}; migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape); auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape); auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::lstm{hs, migraphx::op::lstm{
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, hs,
migraphx::op::rnn_direction::bidirectional, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
clip, input_forget}, migraphx::op::rnn_direction::bidirectional,
seq, clip,
w, input_forget},
r, seq,
bias, w,
seq_len, r,
ih, bias,
ic, seq_len,
pph); ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment