Commit eb0d8fee authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into driver

parents 65ef35cd 0d796941
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
TEST_CASE(rnn_test_bidirectional)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
}
TEST_CASE(rnn_test_one_direction)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
// forward
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
}
// 3 argumments
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
}
// 5 argumments
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
}
// bidirectional
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
}
// 4 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
}
// 5 arguments
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirection, 0 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
}
// bidirection, 1 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
}
// bidirection, 2 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
}
// bidirection, 3 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
}
// forward, 0 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
}
// reverse, 1 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::relu{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_forward)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx");
EXPECT(p == prog);
}
// 3 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f3args.onnx");
EXPECT(p == prog);
}
// 4 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f4args.onnx");
EXPECT(p == prog);
}
// 5 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f5args.onnx");
EXPECT(p == prog);
}
// 6 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f6args.onnx");
EXPECT(p == prog);
}
// 7 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f7args.onnx");
EXPECT(p == prog);
}
}
// activation functions
TEST_CASE(lstm_forward_actv_func)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
// no activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f0af.onnx");
EXPECT(p == prog);
}
// 1 activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
EXPECT(p == prog);
}
// 2 activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_reverse)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx");
EXPECT(p == prog);
}
// 5 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_r5args.onnx");
EXPECT(p == prog);
}
// no activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_r0af.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
EXPECT(p == prog);
}
// 3 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
EXPECT(p == prog);
}
// 4 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
EXPECT(p == prog);
}
// 5 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
EXPECT(p == prog);
}
// 6 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
EXPECT(p == prog);
}
// 7 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_bi_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 0 activation function
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
EXPECT(p == prog);
}
// 1 activation function
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi1af.onnx");
EXPECT(p == prog);
}
// 2 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi2af.onnx");
EXPECT(p == prog);
}
// 4 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi4af.onnx");
EXPECT(p == prog);
}
// 5 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi5af.onnx");
EXPECT(p == prog);
}
// 6 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi6af.onnx");
EXPECT(p == prog);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraphx::parse_onnx("conv.onnx");
......@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
......@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool)
auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
......@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
......@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape()}, l9);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9);
auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
......@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test)
auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraphx::parse_onnx("imagescaler_test.onnx");
......@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
......@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx");
......@@ -392,8 +392,8 @@ TEST_CASE(unknown_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::unknown{"Unknown"}, l2);
auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::op::unknown{"Unknown"}, l2);
auto prog = migraphx::parse_onnx("unknown_test.onnx");
EXPECT(p == prog);
......@@ -466,599 +466,12 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog);
}
TEST_CASE(rnn_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirectional
{
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
}
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
}
// 3 argumments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
}
// 5 argumments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
}
// bidirectional
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
}
// 4 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
}
// 5 arguments
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirection, 0 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
}
// bidirection, 1 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
}
// bidirection, 2 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
}
// bidirection, 3 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
}
// forward, 0 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
}
// reverse, 1 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(flatten_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::flatten{1}, l0);
p.add_instruction(migraphx::op::flatten{2}, l0);
p.add_instruction(migraphx::op::flatten{1}, l0);
auto prog = migraphx::parse_onnx("flatten_test.onnx");
EXPECT(p == prog);
......@@ -1108,6 +521,15 @@ TEST_CASE(constant_test)
EXPECT(p == prog);
}
TEST_CASE(constant_test_scalar)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar.onnx");
EXPECT(p == prog);
}
TEST_CASE(constant_fill_test)
{
{
......@@ -1144,18 +566,140 @@ TEST_CASE(gemm_test)
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto alpha = 2.f;
p.add_instruction(migraphx::op::dot{alpha}, t0, t1);
auto beta = 2.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1);
auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_ex)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f;
auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_ex_brcst)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto t2 = p.add_instruction(migraphx::op::multibroadcast{out_lens}, l2);
auto alpha = 0.5f;
auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2);
auto prog = migraphx::parse_onnx("gemm_test_ex1.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_vv)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, sl1);
auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sr0);
auto prog = migraphx::parse_onnx("matmul_vv.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_vm)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1);
p.add_instruction(migraphx::op::squeeze{{0}}, res);
auto prog = migraphx::parse_onnx("matmul_vm.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_vbm)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
std::cout << "ONNX_TEST" << std::endl;
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
std::cout << "After Dot" << std::endl;
p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_vbm.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_mv)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1);
p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_mv.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_bmv)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto bsl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, sl1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1);
p.add_instruction(migraphx::op::squeeze{{2}}, res);
auto prog = migraphx::parse_onnx("matmul_bmv.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_bmbm)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
auto prog = migraphx::parse_onnx("matmul_bmbm.onnx");
EXPECT(p == prog);
}
TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
......@@ -1186,7 +730,9 @@ TEST_CASE(group_conv_test)
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, l0, l1);
migraphx::parse_onnx("group_conv_test.onnx");
auto prog = migraphx::parse_onnx("group_conv_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pad_test)
......@@ -1194,13 +740,14 @@ TEST_CASE(pad_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0);
migraphx::parse_onnx("pad_test.onnx");
auto prog = migraphx::parse_onnx("pad_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(lrn_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
migraphx::op::lrn op;
op.size = 5;
......@@ -1208,7 +755,53 @@ TEST_CASE(lrn_test)
op.beta = 0.75;
op.bias = 1.0;
p.add_instruction(op, l0);
migraphx::parse_onnx("lrn_test.onnx");
auto prog = migraphx::parse_onnx("lrn_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(add_fp16_test)
{
migraphx::program p;
auto l0 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {1.5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}});
p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_onnx("add_fp16_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(logsoftmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = migraphx::parse_onnx("logsoftmax_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(no_pad_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_onnx("no_pad_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
auto prog = migraphx::parse_onnx("clip_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment