#include #include #include #include #include #include #include #include "test.hpp" TEST_CASE(rnn_test_bidirectional) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 2; // num directions float clip = 0.0f; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto out_hs = p.add_instruction(migraphx::op::rnn{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx"); EXPECT(p == prog); } TEST_CASE(rnn_test_one_direction) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 1; // num directions float clip = 0.0f; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; // forward { 
migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto out_hs = p.add_instruction(migraphx::op::rnn{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx"); EXPECT(p == prog); } // reverse { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto out_hs = p.add_instruction(migraphx::op::rnn{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx"); EXPECT(p == prog); } // 3 argumments { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::rnn{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); EXPECT(p == prog); } // 5 argumments { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = 
p.add_parameter("seq_len", sl_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::rnn{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, bias, seq_len, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); EXPECT(p == prog); } } TEST_CASE(gru_test) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 2; // num directions float clip = 0.0f; // forward { nd = 1; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx"); EXPECT(p == prog); } // reverse { nd = 1; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto 
seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx"); EXPECT(p == prog); } // bidirectional { nd = 2; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::relu{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); EXPECT(p == prog); } } TEST_CASE(gru_test_args) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 2; // num directions float clip = 0.0f; // 3 arguments { nd = 1; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", 
migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip}, seq, w, r, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx"); EXPECT(p == prog); } // 4 arguments { nd = 1; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, bias, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx"); EXPECT(p == prog); } // 5 arguments { nd = 2; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, seq_len, und); 
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx"); EXPECT(p == prog); } } TEST_CASE(gru_test_actv_funcs) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 2; // num directions float clip = 0.0f; // bidirection, 0 actv function { nd = 2; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction( migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); EXPECT(p == prog); } // bidirection, 1 actv function { nd = 2; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction( migraphx::op::gru{ hs, {migraphx::op::tanh{}}, 
migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); EXPECT(p == prog); } // bidirection, 2 actv functions { nd = 2; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx"); EXPECT(p == prog); } // bidirection, 3 actv functions { nd = 2; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction( migraphx::op::gru{hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, 
bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); EXPECT(p == prog); } // forward, 0 actv function { nd = 1; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); EXPECT(p == prog); } // reverse, 1 actv function { nd = 1; migraphx::program p; auto seq = p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto out_hs = p.add_instruction( migraphx::op::gru{ hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, bias, seq_len, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); EXPECT(p == prog); } } 
// Forward LSTM: first with all eight inputs present, then with only the first 3..7 inputs
// supplied. Inputs that are not supplied are filled with one shared `undefined`
// instruction. Each expected program ends with rnn_last_output and lstm_last_cell_output.
TEST_CASE(lstm_forward)
{
    const std::size_t sl   = 5;  // sequence length
    const std::size_t bs   = 3;  // batch size
    const std::size_t hs   = 20; // hidden size
    const std::size_t is   = 10; // input size
    const std::size_t nd   = 1;  // number of directions
    const float clip       = 0.0f;
    const int input_forget = 1;

    const auto f32 = migraphx::shape::float_type;
    const migraphx::shape seq_shape{f32, {sl, bs, is}};
    const migraphx::shape w_shape{f32, {nd, 4 * hs, is}};
    const migraphx::shape r_shape{f32, {nd, 4 * hs, hs}};
    const migraphx::shape bias_shape{f32, {nd, 8 * hs}};
    const migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    const migraphx::shape ih_shape{f32, {nd, bs, hs}};
    const migraphx::shape pph_shape{f32, {nd, 3 * hs}};

    // Every case below expects this same operator; only the supplied inputs differ.
    const migraphx::op::lstm lstm_op{
        hs,
        {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
        migraphx::op::rnn_direction::forward,
        clip,
        input_forget};

    // all 8 inputs
    {
        migraphx::program p;
        auto seq     = p.add_parameter("seq", seq_shape);
        auto w       = p.add_parameter("w", w_shape);
        auto r       = p.add_parameter("r", r_shape);
        auto bias    = p.add_parameter("bias", bias_shape);
        auto seq_len = p.add_parameter("seq_len", sl_shape);
        auto h0      = p.add_parameter("h0", ih_shape);
        auto c0      = p.add_parameter("c0", ih_shape);
        auto pph     = p.add_parameter("pph", pph_shape);

        auto hidden = p.add_instruction(lstm_op, seq, w, r, bias, seq_len, h0, c0, pph);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx");
        EXPECT(p == prog);
    }

    // 3 args
    {
        migraphx::program p;
        auto seq   = p.add_parameter("seq", seq_shape);
        auto w     = p.add_parameter("w", w_shape);
        auto r     = p.add_parameter("r", r_shape);
        auto undef = p.add_instruction(migraphx::op::undefined{});

        auto hidden = p.add_instruction(lstm_op, seq, w, r, undef, undef, undef, undef, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f3args.onnx");
        EXPECT(p == prog);
    }

    // 4 args
    {
        migraphx::program p;
        auto seq   = p.add_parameter("seq", seq_shape);
        auto w     = p.add_parameter("w", w_shape);
        auto r     = p.add_parameter("r", r_shape);
        auto bias  = p.add_parameter("bias", bias_shape);
        auto undef = p.add_instruction(migraphx::op::undefined{});

        auto hidden = p.add_instruction(lstm_op, seq, w, r, bias, undef, undef, undef, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f4args.onnx");
        EXPECT(p == prog);
    }

    // 5 args
    {
        migraphx::program p;
        auto seq     = p.add_parameter("seq", seq_shape);
        auto w       = p.add_parameter("w", w_shape);
        auto r       = p.add_parameter("r", r_shape);
        auto bias    = p.add_parameter("bias", bias_shape);
        auto seq_len = p.add_parameter("seq_len", sl_shape);
        auto undef   = p.add_instruction(migraphx::op::undefined{});

        auto hidden = p.add_instruction(lstm_op, seq, w, r, bias, seq_len, undef, undef, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f5args.onnx");
        EXPECT(p == prog);
    }

    // 6 args
    {
        migraphx::program p;
        auto seq     = p.add_parameter("seq", seq_shape);
        auto w       = p.add_parameter("w", w_shape);
        auto r       = p.add_parameter("r", r_shape);
        auto bias    = p.add_parameter("bias", bias_shape);
        auto seq_len = p.add_parameter("seq_len", sl_shape);
        auto h0      = p.add_parameter("h0", ih_shape);
        auto undef   = p.add_instruction(migraphx::op::undefined{});

        auto hidden = p.add_instruction(lstm_op, seq, w, r, bias, seq_len, h0, undef, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f6args.onnx");
        EXPECT(p == prog);
    }

    // 7 args
    {
        migraphx::program p;
        auto seq     = p.add_parameter("seq", seq_shape);
        auto w       = p.add_parameter("w", w_shape);
        auto r       = p.add_parameter("r", r_shape);
        auto bias    = p.add_parameter("bias", bias_shape);
        auto seq_len = p.add_parameter("seq_len", sl_shape);
        auto h0      = p.add_parameter("h0", ih_shape);
        auto c0      = p.add_parameter("c0", ih_shape);
        auto undef   = p.add_instruction(migraphx::op::undefined{});

        auto hidden = p.add_instruction(lstm_op, seq, w, r, bias, seq_len, h0, c0, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f7args.onnx");
        EXPECT(p == prog);
    }
}

// activation functions
// Forward LSTM where the expected operator carries 0, 1 or 2 activation functions.
TEST_CASE(lstm_forward_actv_func)
{
    const std::size_t sl   = 5;  // sequence length
    const std::size_t bs   = 3;  // batch size
    const std::size_t hs   = 20; // hidden size
    const std::size_t is   = 10; // input size
    const std::size_t nd   = 1;  // number of directions
    const float clip       = 0.0f;
    const int input_forget = 1;

    const auto f32 = migraphx::shape::float_type;
    const auto fwd = migraphx::op::rnn_direction::forward;
    const migraphx::shape seq_shape{f32, {sl, bs, is}};
    const migraphx::shape w_shape{f32, {nd, 4 * hs, is}};
    const migraphx::shape r_shape{f32, {nd, 4 * hs, hs}};
    const migraphx::shape bias_shape{f32, {nd, 8 * hs}};
    const migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};

    // no activation function specified
    {
        migraphx::program p;
        auto seq   = p.add_parameter("seq", seq_shape);
        auto w     = p.add_parameter("w", w_shape);
        auto r     = p.add_parameter("r", r_shape);
        auto undef = p.add_instruction(migraphx::op::undefined{});

        const migraphx::op::lstm lstm_op{hs, {}, fwd, clip, input_forget};
        auto hidden = p.add_instruction(lstm_op, seq, w, r, undef, undef, undef, undef, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f0af.onnx");
        EXPECT(p == prog);
    }

    // 1 activation function specified
    {
        migraphx::program p;
        auto seq   = p.add_parameter("seq", seq_shape);
        auto w     = p.add_parameter("w", w_shape);
        auto r     = p.add_parameter("r", r_shape);
        auto bias  = p.add_parameter("bias", bias_shape);
        auto undef = p.add_instruction(migraphx::op::undefined{});

        const migraphx::op::lstm lstm_op{
            hs, {migraphx::op::sigmoid{}}, fwd, clip, input_forget};
        auto hidden = p.add_instruction(lstm_op, seq, w, r, bias, undef, undef, undef, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
        EXPECT(p == prog);
    }

    // 2 activation functions specified
    {
        migraphx::program p;
        auto seq     = p.add_parameter("seq", seq_shape);
        auto w       = p.add_parameter("w", w_shape);
        auto r       = p.add_parameter("r", r_shape);
        auto bias    = p.add_parameter("bias", bias_shape);
        auto seq_len = p.add_parameter("seq_len", sl_shape);
        auto undef   = p.add_instruction(migraphx::op::undefined{});

        const migraphx::op::lstm lstm_op{
            hs, {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, fwd, clip, input_forget};
        auto hidden = p.add_instruction(lstm_op, seq, w, r, bias, seq_len, undef, undef, undef);
        p.add_instruction(migraphx::op::rnn_last_output{}, hidden);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hidden);

        auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx");
        EXPECT(p == prog);
    }
}
TEST_CASE(lstm_reverse) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 1; // num directions float clip = 0.0f; int input_forget = 1; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto ic = p.add_parameter("c0", ih_shape); auto pph = p.add_parameter("pph", pph_shape); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip, input_forget}, seq, w, r, bias, seq_len, ih, ic, pph); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx"); EXPECT(p == prog); } // 5 args { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, 
migraphx::op::rnn_direction::reverse, clip, input_forget}, seq, w, r, bias, seq_len, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_r5args.onnx"); EXPECT(p == prog); } // no activation function specified { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget}, seq, w, r, und, und, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_r0af.onnx"); EXPECT(p == prog); } } TEST_CASE(lstm_bidirectional) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 2; // num directions float clip = 0.0f; int input_forget = 1; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto ic = p.add_parameter("c0", ih_shape); auto pph = 
p.add_parameter("pph", pph_shape); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, bias, seq_len, ih, ic, pph); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx"); EXPECT(p == prog); } // 3 args { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, und, und, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx"); EXPECT(p == prog); } // 4 args { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, bias, und, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx"); EXPECT(p == prog); } // 5 args { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = 
p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, bias, seq_len, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx"); EXPECT(p == prog); } // 6 args { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, bias, seq_len, ih, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx"); EXPECT(p == prog); } // 7 args { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto ic = p.add_parameter("c0", ih_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, 
bias, seq_len, ih, ic, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx"); EXPECT(p == prog); } } TEST_CASE(lstm_bi_actv_funcs) { std::size_t sl = 5; // sequence len std::size_t bs = 3; // batch size std::size_t hs = 20; // hidden size std::size_t is = 10; // input size std::size_t nd = 2; // num directions float clip = 0.0f; int input_forget = 1; migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}}; migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}}; migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}}; migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; // 0 activation function { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::lstm{ hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, und, und, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx"); EXPECT(p == prog); } // 1 activation function { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::lstm{hs, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip, 
input_forget}, seq, w, r, bias, und, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi1af.onnx"); EXPECT(p == prog); } // 2 activation functions { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, bias, seq_len, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi2af.onnx"); EXPECT(p == prog); } // 4 activation functions { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, bias, seq_len, ih, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi4af.onnx"); EXPECT(p == prog); } // 5 activation functions { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = 
p.add_parameter("r", r_shape); auto bias = p.add_parameter("bias", bias_shape); auto seq_len = p.add_parameter("seq_len", sl_shape); auto ih = p.add_parameter("h0", ih_shape); auto ic = p.add_parameter("c0", ih_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, bias, seq_len, ih, ic, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi5af.onnx"); EXPECT(p == prog); } // 6 activation functions { migraphx::program p; auto seq = p.add_parameter("seq", seq_shape); auto w = p.add_parameter("w", w_shape); auto r = p.add_parameter("r", r_shape); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, seq, w, r, und, und, und, und, und); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); auto prog = migraphx::parse_onnx("onnx_lstm_bi6af.onnx"); EXPECT(p == prog); } } int main(int argc, const char* argv[]) { test::run(argc, argv); }