"vscode:/vscode.git/clone" did not exist on "bd1686548575a4ea00b6fd2e7c198547d5014e58"
Commit b34a8e60 authored by Nives Vukovic
Browse files

Implement layout attribute support for RNN operator

parent 0039b11a
...@@ -33,6 +33,29 @@ namespace migraphx { ...@@ -33,6 +33,29 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
// For RNN nodes with layout=1, swap the first two dimensions of the
// layout-sensitive inputs so they match the form the "rnn" operator expects:
// the sequence input (args[0]) and, when provided, the initial hidden state
// (args[5]) are transposed with permutation {1, 0, 2}.
void rnn_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
    const std::vector<int64_t> layout_perm{1, 0, 2};
    auto swap_front_dims = [&](instruction_ref ins) {
        return info.add_instruction(make_op("transpose", {{"permutation", layout_perm}}), ins);
    };
    args[0] = swap_front_dims(args[0]);
    // args[5] (initial hidden state) is optional; skip the undefined placeholder.
    if(args.size() == 6 && !args[5]->is_undefined())
    {
        args[5] = swap_front_dims(args[5]);
    }
}
// For RNN nodes with layout=1, transpose both outputs back to the
// batch-first form: the full hidden-state sequence with permutation
// {2, 0, 1, 3} and the last hidden state with permutation {1, 0, 2}.
void rnn_transpose_outputs(onnx_parser::node_info& info,
                           instruction_ref& hidden_states,
                           instruction_ref& last_output)
{
    const std::vector<int64_t> hs_perm{2, 0, 1, 3};
    const std::vector<int64_t> last_perm{1, 0, 2};
    hidden_states =
        info.add_instruction(make_op("transpose", {{"permutation", hs_perm}}), hidden_states);
    last_output =
        info.add_instruction(make_op("transpose", {{"permutation", last_perm}}), last_output);
}
struct parse_rnn : op_parser<parse_rnn> struct parse_rnn : op_parser<parse_rnn>
{ {
std::vector<op_desc> operators() const { return {{"RNN"}}; } std::vector<op_desc> operators() const { return {{"RNN"}}; }
...@@ -116,6 +139,12 @@ struct parse_rnn : op_parser<parse_rnn> ...@@ -116,6 +139,12 @@ struct parse_rnn : op_parser<parse_rnn>
clip = parser.parse_value(info.attributes.at("clip")).at<float>(); clip = parser.parse_value(info.attributes.at("clip")).at<float>();
} }
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// if the number of arguments is less than 6, append // if the number of arguments is less than 6, append
// undefined operator to have 6 arguments // undefined operator to have 6 arguments
if(args.size() < 6) if(args.size() < 6)
...@@ -124,6 +153,11 @@ struct parse_rnn : op_parser<parse_rnn> ...@@ -124,6 +153,11 @@ struct parse_rnn : op_parser<parse_rnn>
args.insert(args.end(), (6 - args.size()), ins); args.insert(args.end(), (6 - args.size()), ins);
} }
if(layout != 0)
{
rnn_transpose_inputs(info, args);
}
// first output for the concatenation of hidden states // first output for the concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("rnn", auto hidden_states = info.add_instruction(make_op("rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
...@@ -135,6 +169,11 @@ struct parse_rnn : op_parser<parse_rnn> ...@@ -135,6 +169,11 @@ struct parse_rnn : op_parser<parse_rnn>
// second output for the last hidden state // second output for the last hidden state
auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states); auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states);
if(layout != 0)
{
rnn_transpose_outputs(info, hidden_states, last_output);
}
return {hidden_states, last_output}; return {hidden_states, last_output};
} }
}; };
......
...@@ -7223,6 +7223,130 @@ def reversesequence_time_test(): ...@@ -7223,6 +7223,130 @@ def reversesequence_time_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def rnn_bi_layout_test():
    # Bidirectional RNN with layout=1: seq/h0/outputs use batch-first shapes.
    inputs = [
        helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]),
        helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 20, 10]),
        helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 20, 20]),
        helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 40]),
        helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]),
        helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20]),
    ]
    outputs = [
        helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 2, 20]),
        helper.make_tensor_value_info('output', TensorProto.FLOAT, [3, 2, 20]),
    ]
    node = onnx.helper.make_node('RNN',
                                 inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0'],
                                 outputs=['hs', 'output'],
                                 activations=['tanh', 'sigmoid'],
                                 clip=0,
                                 direction='bidirectional',
                                 hidden_size=20,
                                 layout=1)
    return ([node], inputs, outputs)
@onnx_test()
def rnn_f_layout_test():
    # Forward (single-direction) RNN with layout=1 and all six inputs.
    inputs = [
        helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]),
        helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 20, 10]),
        helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 20, 20]),
        helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 40]),
        helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]),
        helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20]),
    ]
    outputs = [
        helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20]),
        helper.make_tensor_value_info('output', TensorProto.FLOAT, [3, 1, 20]),
    ]
    node = onnx.helper.make_node('RNN',
                                 inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0'],
                                 outputs=['hs', 'output'],
                                 activations=['tanh', 'sigmoid'],
                                 clip=0,
                                 direction='forward',
                                 hidden_size=20,
                                 layout=1)
    return ([node], inputs, outputs)
@onnx_test()
def rnn_f_5arg_layout_test():
    # Forward RNN with layout=1 and only five inputs (no initial hidden state).
    inputs = [
        helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]),
        helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 20, 10]),
        helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 20, 20]),
        helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 40]),
        helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]),
    ]
    outputs = [
        helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20]),
        helper.make_tensor_value_info('output', TensorProto.FLOAT, [3, 1, 20]),
    ]
    node = onnx.helper.make_node('RNN',
                                 inputs=['seq', 'w', 'r', 'bias', 'seq_len'],
                                 outputs=['hs', 'output'],
                                 activations=['tanh', 'sigmoid'],
                                 clip=0,
                                 direction='forward',
                                 hidden_size=20,
                                 layout=1)
    return ([node], inputs, outputs)
@onnx_test()
def rnn_r_layout_test():
    # Reverse-direction RNN with layout=1 and all six inputs.
    inputs = [
        helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]),
        helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 20, 10]),
        helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 20, 20]),
        helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 40]),
        helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]),
        helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20]),
    ]
    outputs = [
        helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20]),
        helper.make_tensor_value_info('output', TensorProto.FLOAT, [3, 1, 20]),
    ]
    node = onnx.helper.make_node('RNN',
                                 inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0'],
                                 outputs=['hs', 'output'],
                                 activations=['tanh', 'sigmoid'],
                                 clip=0,
                                 direction='reverse',
                                 hidden_size=20,
                                 layout=1)
    return ([node], inputs, outputs)
@onnx_test()
def rnn_r_3arg_layout_test():
    # Reverse-direction RNN with layout=1 and only the three required inputs.
    inputs = [
        helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]),
        helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 20, 10]),
        helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 20, 20]),
    ]
    outputs = [
        helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20]),
        helper.make_tensor_value_info('output', TensorProto.FLOAT, [3, 1, 20]),
    ]
    node = onnx.helper.make_node('RNN',
                                 inputs=['seq', 'w', 'r'],
                                 outputs=['hs', 'output'],
                                 activations=['tanh', 'sigmoid'],
                                 clip=0,
                                 direction='reverse',
                                 hidden_size=20,
                                 layout=1)
    return ([node], inputs, outputs)
@onnx_test() @onnx_test()
def roialign_default_test(): def roialign_default_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8])
......
...@@ -100,6 +100,60 @@ TEST_CASE(rnn_test_bidirectional) ...@@ -100,6 +100,60 @@ TEST_CASE(rnn_test_bidirectional)
EXPECT(p == prog); EXPECT(p == prog);
} }
// Parsing a bidirectional ONNX RNN node with layout=1 (batch-first shapes)
// must wrap the "rnn" op with transposes: inputs seq and h0 use permutation
// {1, 0, 2}, the hidden-state output uses {2, 0, 1, 3}, and the last hidden
// state uses {1, 0, 2}. This test builds the expected program by hand and
// compares it to the parse of rnn_bi_layout_test.onnx.
TEST_CASE(rnn_test_bidirectional_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 2;  // num directions
    float clip = 0.0f;

    // layout=1 shapes: seq is {batch, seq_len, input}, h0 is {batch, dirs, hidden}
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};

    migraphx::program p;
    auto* mm     = p.get_main_module();
    auto seq     = mm->add_parameter("seq", seq_shape);
    auto w       = mm->add_parameter("w", w_shape);
    auto r       = mm->add_parameter("r", r_shape);
    auto bias    = mm->add_parameter("bias", bias_shape);
    auto seq_len = mm->add_parameter("seq_len", sl_shape);
    auto ih      = mm->add_parameter("h0", ih_shape);
    // input transposes the parser is expected to insert for layout=1
    std::vector<int64_t> perm{1, 0, 2};
    seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
    ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
    auto out_hs = mm->add_instruction(
        migraphx::make_op(
            "rnn",
            {{"hidden_size", hs},
             {"actv_func",
              migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
                                                                  migraphx::make_op("sigmoid")})},
             {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
             {"clip", clip}}),
        seq,
        w,
        r,
        bias,
        seq_len,
        ih);
    auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
    // output transposes back to batch-first form; the last instruction added
    // (the transposed last_output) is intentionally the program output.
    std::vector<int64_t> perm_hid{2, 0, 1, 3};
    out_hs =
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs);
    mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

    auto prog = optimize_onnx("rnn_bi_layout_test.onnx");

    EXPECT(p == prog);
}
TEST_CASE(rnn_test_one_direction) TEST_CASE(rnn_test_one_direction)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
...@@ -241,6 +295,179 @@ TEST_CASE(rnn_test_one_direction) ...@@ -241,6 +295,179 @@ TEST_CASE(rnn_test_one_direction)
} }
} }
// Parsing single-direction ONNX RNN nodes with layout=1 (batch-first shapes)
// must wrap the "rnn" op with transposes: inputs seq/h0 use permutation
// {1, 0, 2}, the hidden-state output uses {2, 0, 1, 3}, and the last hidden
// state uses {1, 0, 2}. Four variants are checked: forward, reverse,
// 3-input (w, r, seq only), and 5-input (no initial hidden state).
TEST_CASE(rnn_test_one_direction_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 1;  // num directions
    float clip = 0.0f;

    // layout=1 shapes: seq is {batch, seq_len, input}, h0 is {batch, dirs, hidden}
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};

    // forward
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        // input transposes expected for layout=1
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // output transposes back to batch-first; transposed last_output is the
        // final (output) instruction.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("rnn_f_layout_test.onnx");

        EXPECT(p == prog);
    }

    // reverse
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("rnn_r_layout_test.onnx");

        EXPECT(p == prog);
    }

    // 3 arguments: missing inputs are padded with "undefined" placeholders,
    // and only seq needs the input transpose (no initial hidden state).
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            und,
            und,
            und);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("rnn_r_3arg_layout_test.onnx");

        EXPECT(p == prog);
    }

    // 5 arguments: no initial hidden state, so the sixth input is "undefined"
    // and only seq gets the input transpose.
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto und     = mm->add_instruction(migraphx::make_op("undefined"));
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            und);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("rnn_f_5arg_layout_test.onnx");

        EXPECT(p == prog);
    }
}
TEST_CASE(gru_test) TEST_CASE(gru_test)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
......
...@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test): ...@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails # fails
# from OnnxBackendNodeModelTest # from OnnxBackendNodeModelTest
backend_test.exclude(r'test_gru_batchwise_cpu') backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
# from OnnxBackendPyTorchConvertedModelTest # from OnnxBackendPyTorchConvertedModelTest
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu') backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu') backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu')
......
...@@ -348,53 +348,57 @@ TEST_CASE(rnn_forward) ...@@ -348,53 +348,57 @@ TEST_CASE(rnn_forward)
} }
} }
TEST_CASE(rnn_reverse) TEST_CASE(rnn_forward_layout)
{ {
std::size_t batch_size = 2; std::size_t batch_size = 2;
std::size_t seq_len = 2; std::size_t seq_len = 2;
std::size_t hidden_size = 4; std::size_t hidden_size = 4;
std::size_t input_size = 3; std::size_t input_size = 3;
std::size_t num_dirct = 1; std::size_t num_dirct = 1;
std::vector<float> w_data{-0.0296, std::vector<float> w_data{0.4691,
-0.1341, 0.3185,
0.1761, -0.2227,
-0.2325, 0.4423,
-0.0717, -0.0609,
0.1852, -0.2803,
0.2720, 0.1744,
0.1471, 0.3146,
-0.1097,
0.3363,
-0.0587,
-0.2302};
std::vector<float> r_data{0.2528,
-0.2333,
0.3973,
0.1593,
-0.0388,
0.1702,
0.3829,
-0.0712,
-0.1668,
0.3074,
-0.2854,
0.4049, 0.4049,
-0.3737, -0.3973,
-0.1051, -0.0890,
0.4482, -0.1636};
-0.2841};
std::vector<float> bias_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552}; std::vector<float> r_data{-0.0456,
0.1061,
0.1574,
-0.4928,
-0.4300,
-0.1909,
-0.0225,
-0.2668,
0.1840,
-0.4453,
-0.4896,
0.1302,
-0.0929,
0.3545,
-0.4981,
0.0616};
std::vector<float> bias_data{
-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
std::vector<float> input(seq_len * batch_size * input_size, 0); std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0; input[0] = input[1] = 1.0;
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0); migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
float clip = 0.0f; migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; float clip = 0.0f;
// concatenation of hidden states as program output // concatenation of hidden states as program output
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input}); auto seq = mm->add_literal(migraphx::literal{in_shape, input});
...@@ -404,12 +408,18 @@ TEST_CASE(rnn_reverse) ...@@ -404,12 +408,18 @@ TEST_CASE(rnn_reverse)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined")); auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction( std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
auto hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {}}, {"actv_func",
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}), {"clip", clip}}),
seq, seq,
w, w,
...@@ -417,31 +427,122 @@ TEST_CASE(rnn_reverse) ...@@ -417,31 +427,122 @@ TEST_CASE(rnn_reverse)
bias, bias,
und, und,
ih); ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
mm->add_return({hs, lho});
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
auto res_hs = outputs.front();
auto res_lho = outputs.back();
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector<float> lho_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
std::vector<float> lho_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
std::vector<float> hs_data_gold{-0.29385301, EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
0.16796815, EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
0.51075965, }
0.40258689,
-0.13818839, {
0.44124447, migraphx::program p;
0.14365635, auto* mm = p.get_main_module();
0.14803654, auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input});
-0.0070999, auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
0.46251031, auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
-0.20639211, auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
0.37488942, auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
-0.0070999, migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
0.46251031, std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
-0.20639211, auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
0.37488942}; auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
sql,
ih);
auto last_out = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
last_out =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_out);
mm->add_return({out_hs, last_out});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_last_output = outputs.back();
std::vector<float> last_output_data;
std::vector<float> hs_data;
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
arg_last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, 0.03445704, 0.19167931, -0.3946827,
-0.30889652, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, -0.37144644, 0.31708236, 0.13104209, -0.18736027, -0.22276389,
0.44193283, -0.16477929, -0.11893477, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0};
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// rnn last output as program output
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -450,21 +551,873 @@ TEST_CASE(rnn_reverse) ...@@ -450,21 +551,873 @@ TEST_CASE(rnn_reverse)
auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined")); migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{2, 1};
auto sql = mm->add_literal(seq_len_s, len_data);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
auto out_hs = mm->add_instruction( auto out_hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {}}, {"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
sql,
ih);
auto last_out = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
last_out =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_out);
mm->add_return({out_hs, last_out});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_last_output = outputs.back();
std::vector<float> last_output_data;
std::vector<float> hs_data;
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
arg_last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{0.377808,
0.610551,
0.551685,
-0.588848,
0.034457,
0.191679,
-0.394683,
-0.308897,
-0.371446,
0.317082,
0.131042,
-0.18736,
0,
0,
0,
0};
std::vector<float> last_output_data_gold{
0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736};
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// seq_len = 1
{
seq_len = 1;
std::vector<float> input_1(seq_len * batch_size * input_size, 0);
input_1[0] = input_1[1] = 1.0;
migraphx::shape in_shape_1{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape_1, input_1});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
TEST_CASE(rnn_reverse)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{-0.0296,
-0.1341,
0.1761,
-0.2325,
-0.0717,
0.1852,
0.2720,
0.1471,
-0.1097,
0.3363,
-0.0587,
-0.2302};
std::vector<float> r_data{0.2528,
-0.2333,
0.3973,
0.1593,
-0.0388,
0.1702,
0.3829,
-0.0712,
-0.1668,
0.3074,
-0.2854,
0.4049,
-0.3737,
-0.1051,
0.4482,
-0.2841};
std::vector<float> bias_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
// concatenation of hidden states as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// rnn last output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
p.compile(migraphx::make_target("ref"));
auto last_output = p.eval({}).back();
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
std::vector<float> hs_data;
std::vector<float> last_output_data;
auto arg_hs = outputs.front();
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
auto arg_lho = outputs.back();
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635,
0.14803654, -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0};
std::vector<float> last_output_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{2, 1};
auto sql = mm->add_literal(seq_len_s, len_data);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
std::vector<float> hs_data;
std::vector<float> last_output_data;
auto arg_hs = outputs.front();
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
auto arg_lho = outputs.back();
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{-0.293853,
0.167968,
0.51076,
0.402587,
-0.0070999,
0.46251,
-0.206392,
0.374889,
-0.0070999,
0.46251,
-0.206392,
0.374889,
0,
0,
0,
0};
std::vector<float> last_output_data_gold{
-0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
}
// Reference test for the "rnn" operator with direction=reverse under ONNX
// layout=1 (batch-major) semantics.  The migraphx "rnn" op itself consumes
// seq-major data, so each sub-test builds the layout=1 behavior explicitly:
// it transposes the batch-major seq/ih literals into seq-major form before
// the op, and transposes the hidden-state / last-output results back after.
// Golden values match the layout=0 rnn_reverse test above, reordered for the
// batch-major layout.
TEST_CASE(rnn_reverse_layout)
{
    std::size_t batch_size = 2;
    std::size_t seq_len = 2;
    std::size_t hidden_size = 4;
    std::size_t input_size = 3;
    std::size_t num_dirct = 1;
    // Input weight matrix W: {num_dirct, hidden_size, input_size}
    std::vector<float> w_data{-0.0296,
                              -0.1341,
                              0.1761,
                              -0.2325,
                              -0.0717,
                              0.1852,
                              0.2720,
                              0.1471,
                              -0.1097,
                              0.3363,
                              -0.0587,
                              -0.2302};
    // Recurrence weight matrix R: {num_dirct, hidden_size, hidden_size}
    std::vector<float> r_data{0.2528,
                              -0.2333,
                              0.3973,
                              0.1593,
                              -0.0388,
                              0.1702,
                              0.3829,
                              -0.0712,
                              -0.1668,
                              0.3074,
                              -0.2854,
                              0.4049,
                              -0.3737,
                              -0.1051,
                              0.4482,
                              -0.2841};
    // Bias Wb|Rb concatenated: {num_dirct, 2 * hidden_size}
    std::vector<float> bias_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
    // Input sequence is all zeros except the first two elements.
    std::vector<float> input(seq_len * batch_size * input_size, 0);
    input[0] = input[1] = 1.0;
    // Initial hidden state is all zeros.
    std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
    float clip = 0.0f;
    // layout=1 shapes: seq is {batch, seq, input} and ih is {batch, dir, hidden}
    // (the layout=0 variant of this test uses seq-major ordering instead).
    migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    // concatenation of hidden states as program output
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_literal(migraphx::literal{in_shape, input});
        auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        // placeholder for the unused sequence-lengths argument
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));
        // batch-major -> seq-major before feeding the rnn op
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func", {}},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih);
        // hidden states come out as {seq, dir, batch, hidden}; permute to the
        // layout=1 output order {batch, seq, dir, hidden}
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        std::vector<float> hs_data_gold{-0.29385301,
                                        0.16796815,
                                        0.51075965,
                                        0.40258689,
                                        -0.0070999,
                                        0.46251031,
                                        -0.20639211,
                                        0.37488942,
                                        -0.13818839,
                                        0.44124447,
                                        0.14365635,
                                        0.14803654,
                                        -0.0070999,
                                        0.46251031,
                                        -0.20639211,
                                        0.37488942};
        EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
    }
    // rnn last output as program output
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_literal(migraphx::literal{in_shape, input});
        auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));
        // batch-major -> seq-major before feeding the rnn op
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func", {}},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih);
        auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // last output is {dir, batch, hidden}; transpose back to {batch, dir, hidden}
        lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
        p.compile(migraphx::make_target("ref"));
        auto last_output = p.eval({}).back();
        std::vector<float> last_output_data;
        last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
        std::vector<float> last_output_data_gold{-0.29385301,
                                                 0.16796815,
                                                 0.51075965,
                                                 0.40258689,
                                                 -0.13818839,
                                                 0.44124447,
                                                 0.14365635,
                                                 0.14803654};
        EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
    }
    // rnn hidden states and last hidden state output as program outputs
    {
        migraphx::program p;
        auto* mm      = p.get_main_module();
        auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input});
        auto ih       = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto w        = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r        = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias     = mm->add_literal(migraphx::literal{b_shape, bias_data});
        // pad the sequence with two extra all-zero steps (concat on the
        // layout=1 sequence axis 1) so the actual lengths come from sql below
        migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
        std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
        auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
        auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
        // per-batch sequence lengths: every batch uses the original seq_len
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
        auto sql = mm->add_literal(seq_len_s, len_data);
        // batch-major -> seq-major before feeding the rnn op
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func", {}},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih);
        auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // permute both outputs back to the layout=1 orderings
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
        mm->add_return({out_hs, lho});
        p.compile(migraphx::make_target("ref"));
        auto outputs = p.eval({});
        std::vector<float> hs_data;
        std::vector<float> last_output_data;
        auto arg_hs = outputs.front();
        arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
        auto arg_lho = outputs.back();
        arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
        // padded steps beyond each batch's length produce zero hidden states
        std::vector<float> hs_data_gold{
            -0.29385301, 0.16796815,  0.51075965, 0.40258689,  -0.0070999,  0.46251031, -0.20639211,
            0.37488942,  0.0,         0.0,        0.0,         0.0,         0.0,        0.0,
            0.0,         0.0,         -0.13818839, 0.44124447, 0.14365635,  0.14803654, -0.0070999,
            0.46251031,  -0.20639211, 0.37488942, 0.0,         0.0,         0.0,        0.0,
            0.0,         0.0,         0.0,        0.0};
        std::vector<float> last_output_data_gold{-0.29385301,
                                                 0.16796815,
                                                 0.51075965,
                                                 0.40258689,
                                                 -0.13818839,
                                                 0.44124447,
                                                 0.14365635,
                                                 0.14803654};
        EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
        EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
    }
    // rnn hidden states and last hidden state output as program outputs
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_literal(migraphx::literal{in_shape, input});
        auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        // variable per-batch sequence lengths (batch 0 runs 2 steps, batch 1 only 1)
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data{2, 1};
        auto sql = mm->add_literal(seq_len_s, len_data);
        // batch-major -> seq-major before feeding the rnn op
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func", {}},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih);
        auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // permute both outputs back to the layout=1 orderings
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
        mm->add_return({out_hs, lho});
        p.compile(migraphx::make_target("ref"));
        auto outputs = p.eval({});
        std::vector<float> hs_data;
        std::vector<float> last_output_data;
        auto arg_hs = outputs.front();
        arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
        auto arg_lho = outputs.back();
        arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
        // batch 1's second step is beyond its length, hence the trailing zeros
        std::vector<float> hs_data_gold{-0.293853,
                                        0.167968,
                                        0.51076,
                                        0.402587,
                                        -0.0070999,
                                        0.46251,
                                        -0.206392,
                                        0.374889,
                                        -0.0070999,
                                        0.46251,
                                        -0.206392,
                                        0.374889,
                                        0,
                                        0,
                                        0,
                                        0};
        std::vector<float> last_output_data_gold{
            -0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
        EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
        EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
    }
}
TEST_CASE(rnn_bidirectional)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{0.4691, 0.3185, -0.2227, 0.4423, -0.0609, -0.2803,
0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636,
-0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852,
0.2720, 0.1471, -0.1097, 0.3363, -0.0587, -0.2302};
std::vector<float> r_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225,
-0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545,
-0.4981, 0.0616, 0.2528, -0.2333, 0.3973, 0.1593, -0.0388,
0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049,
-0.3737, -0.1051, 0.4482, -0.2841};
std::vector<float> bias_data{-0.4938,
0.4355,
-0.3186,
0.2094,
0.1037,
-0.1071,
0.4504,
-0.3990,
-0.3188,
0.1341,
-0.4446,
0.1389,
0.3117,
0.3664,
0.2352,
0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
float clip = 0.0f;
// concatenation of hidden state and last hs output for program outputs
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.back();
std::vector<float> hs_data;
arg_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> last_output_data;
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236,
0.13104209, -0.18736027, -0.29385301, 0.16796815, 0.51075965, 0.40258689,
-0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931,
-0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942};
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// last rnn output for program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{1, 2};
auto sql = mm->add_literal(seq_len_s, len_data);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r,
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.back();
std::vector<float> hs_data;
std::vector<float> last_output_data;
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
0.377808, 0.610551, 0.551685, -0.588848, -0.371446, 0.317082, 0.131042, -0.18736,
-0.169158, 0.193817, 0.206679, 0.586097, -0.138188, 0.441244, 0.143656, 0.148037,
0, 0, 0, 0, -0.222764, 0.441933, -0.164779, -0.118935,
0, 0, 0, 0, -0.0070999, 0.46251, -0.206392, 0.374889};
std::vector<float> last_output_data_gold{0.377808,
0.610551,
0.551685,
-0.588848,
-0.222764,
0.441933,
-0.164779,
-0.118935,
-0.169158,
0.193817,
0.206679,
0.586097,
-0.138188,
0.441244,
0.143656,
0.148037};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// 4 args
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto out_hs = mm->add_instruction(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}), {"clip", clip}}),
seq, seq,
w, w,
r, r,
bias, bias);
und,
ih);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -473,7 +1426,15 @@ TEST_CASE(rnn_reverse) ...@@ -473,7 +1426,15 @@ TEST_CASE(rnn_reverse)
std::vector<float> last_output_data; std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{-0.29385301, std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815, 0.16796815,
0.51075965, 0.51075965,
0.40258689, 0.40258689,
...@@ -481,137 +1442,100 @@ TEST_CASE(rnn_reverse) ...@@ -481,137 +1442,100 @@ TEST_CASE(rnn_reverse)
0.44124447, 0.44124447,
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// rnn hidden states and last hidden state output as program outputs // 3 args
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input}); auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
auto out_hs = mm->add_instruction( mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {}}, {"actv_func",
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}), {"clip", clip}}),
seq, seq,
w, w,
r, r);
bias,
sql,
ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({}); auto last_output = p.eval({}).back();
std::vector<float> hs_data;
std::vector<float> last_output_data; std::vector<float> last_output_data;
auto arg_hs = outputs.front(); last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
auto arg_lho = outputs.back();
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{
-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635,
0.14803654, -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0};
std::vector<float> last_output_data_gold{-0.29385301, std::vector<float> last_output_data_gold{
0.16796815, 0.6570473, 0.36392266, 0.45342238, -0.45127486, 0., 0., 0., 0.,
0.51075965, -0.16225325, -0.29515147, 0.39617197, 0.27068236, 0., 0., 0., 0.,
0.40258689, 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.,
-0.13818839, 0., 0., 0., 0., 0., 0., 0., 0.};
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// rnn hidden states and last hidden state output as program outputs // concatenation of hidden state for program output
{ {
seq_len = 1;
std::vector<float> input_1(seq_len * batch_size * input_size, 0);
input_1[0] = input_1[1] = 1.0;
migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input}); auto seq = mm->add_literal(migraphx::literal{in_shape_1, input_1});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int32_t> len_data{2, 1}; mm->add_instruction(
auto sql = mm->add_literal(seq_len_s, len_data);
auto out_hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {}}, {"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}), {"clip", clip}}),
seq, seq,
w, w,
r, r,
bias, bias,
sql, und,
ih); ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
auto outputs = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
std::vector<float> last_output_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
auto arg_hs = outputs.front();
arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); });
auto arg_lho = outputs.back(); std::vector<float> hs_data_gold{0.37780784,
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); 0.61055139,
std::vector<float> hs_data_gold{-0.293853, 0.55168478,
0.167968, -0.5888475,
0.51076, -0.37144644,
0.402587, 0.31708236,
-0.0070999, 0.13104209,
0.46251, -0.18736027,
-0.206392, -0.16915828,
0.374889, 0.1938169,
0.20667936,
0.58609703,
-0.0070999, -0.0070999,
0.46251, 0.46251031,
-0.206392, -0.20639211,
0.374889, 0.37488942};
0,
0,
0,
0};
std::vector<float> last_output_data_gold{
-0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
} }
TEST_CASE(rnn_bidirectional) TEST_CASE(rnn_bidirectional_layout)
{ {
std::size_t batch_size = 2; std::size_t batch_size = 2;
std::size_t seq_len = 2; std::size_t seq_len = 2;
...@@ -650,8 +1574,8 @@ TEST_CASE(rnn_bidirectional) ...@@ -650,8 +1574,8 @@ TEST_CASE(rnn_bidirectional)
input[0] = input[1] = 1.0; input[0] = input[1] = 1.0;
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0); std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
...@@ -660,13 +1584,18 @@ TEST_CASE(rnn_bidirectional) ...@@ -660,13 +1584,18 @@ TEST_CASE(rnn_bidirectional)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input}); auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined")); auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
auto out_hs = mm->add_instruction( auto out_hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
...@@ -681,6 +1610,10 @@ TEST_CASE(rnn_bidirectional) ...@@ -681,6 +1610,10 @@ TEST_CASE(rnn_bidirectional)
und, und,
ih); ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
mm->add_return({out_hs, lho}); mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -694,25 +1627,25 @@ TEST_CASE(rnn_bidirectional) ...@@ -694,25 +1627,25 @@ TEST_CASE(rnn_bidirectional)
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.29385301, 0.16796815,
0.13104209, -0.18736027, -0.29385301, 0.16796815, 0.51075965, 0.40258689, 0.51075965, 0.40258689, 0.03445704, 0.19167931, -0.3946827, -0.30889652,
-0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931, -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.37144644, 0.31708236,
-0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477, 0.13104209, -0.18736027, -0.13818839, 0.44124447, 0.14365635, 0.14803654,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, -0.22276389, 0.44193283, -0.16477929, -0.11893477, -0.0070999, 0.46251031,
-0.20639211, 0.37488942}; -0.20639211, 0.37488942};
std::vector<float> last_output_data_gold{0.03445704, std::vector<float> last_output_data_gold{0.03445704,
0.19167931, 0.19167931,
-0.3946827, -0.3946827,
-0.30889652, -0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301, -0.29385301,
0.16796815, 0.16796815,
0.51075965, 0.51075965,
0.40258689, 0.40258689,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.13818839, -0.13818839,
0.44124447, 0.44124447,
0.14365635, 0.14365635,
...@@ -735,6 +1668,10 @@ TEST_CASE(rnn_bidirectional) ...@@ -735,6 +1668,10 @@ TEST_CASE(rnn_bidirectional)
std::vector<int32_t> len_data{1, 2}; std::vector<int32_t> len_data{1, 2};
auto sql = mm->add_literal(seq_len_s, len_data); auto sql = mm->add_literal(seq_len_s, len_data);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
auto out_hs = mm->add_instruction( auto out_hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
...@@ -750,6 +1687,10 @@ TEST_CASE(rnn_bidirectional) ...@@ -750,6 +1687,10 @@ TEST_CASE(rnn_bidirectional)
sql, sql,
ih); ih);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
mm->add_return({out_hs, lho}); mm->add_return({out_hs, lho});
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -763,27 +1704,27 @@ TEST_CASE(rnn_bidirectional) ...@@ -763,27 +1704,27 @@ TEST_CASE(rnn_bidirectional)
arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.377808, 0.610551, 0.551685, -0.588848, -0.371446, 0.317082, 0.131042, -0.18736, 0.377808, 0.610551, 0.551685, -0.588848, -0.169158, 0.193817, 0.206679, 0.586097,
-0.169158, 0.193817, 0.206679, 0.586097, -0.138188, 0.441244, 0.143656, 0.148037, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, -0.222764, 0.441933, -0.164779, -0.118935, -0.371446, 0.317082, 0.131042, -0.18736, -0.138188, 0.441244, 0.143656, 0.148037,
0, 0, 0, 0, -0.0070999, 0.46251, -0.206392, 0.374889}; -0.222764, 0.441933, -0.164779, -0.118935, -0.0070999, 0.46251, -0.206392, 0.374889};
std::vector<float> last_output_data_gold{0.377808, std::vector<float> last_output_data_gold{0.377808,
0.610551, 0.610551,
0.551685, 0.551685,
-0.588848, -0.588848,
-0.222764,
0.441933,
-0.164779,
-0.118935,
-0.169158, -0.169158,
0.193817, 0.193817,
0.206679, 0.206679,
0.586097, 0.586097,
-0.222764,
0.441933,
-0.164779,
-0.118935,
-0.138188, -0.138188,
0.441244, 0.441244,
0.143656, 0.143656,
0.148037}; 0.148037};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
...@@ -797,6 +1738,9 @@ TEST_CASE(rnn_bidirectional) ...@@ -797,6 +1738,9 @@ TEST_CASE(rnn_bidirectional)
auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto out_hs = mm->add_instruction( auto out_hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
...@@ -810,8 +1754,8 @@ TEST_CASE(rnn_bidirectional) ...@@ -810,8 +1754,8 @@ TEST_CASE(rnn_bidirectional)
w, w,
r, r,
bias); bias);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto last_output = p.eval({}).back(); auto last_output = p.eval({}).back();
...@@ -822,19 +1766,18 @@ TEST_CASE(rnn_bidirectional) ...@@ -822,19 +1766,18 @@ TEST_CASE(rnn_bidirectional)
0.19167931, 0.19167931,
-0.3946827, -0.3946827,
-0.30889652, -0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301, -0.29385301,
0.16796815, 0.16796815,
0.51075965, 0.51075965,
0.40258689, 0.40258689,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.13818839, -0.13818839,
0.44124447, 0.44124447,
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
...@@ -846,7 +1789,10 @@ TEST_CASE(rnn_bidirectional) ...@@ -846,7 +1789,10 @@ TEST_CASE(rnn_bidirectional)
auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction( std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto out_hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
...@@ -858,6 +1804,9 @@ TEST_CASE(rnn_bidirectional) ...@@ -858,6 +1804,9 @@ TEST_CASE(rnn_bidirectional)
seq, seq,
w, w,
r); r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto last_output = p.eval({}).back(); auto last_output = p.eval({}).back();
...@@ -865,10 +1814,11 @@ TEST_CASE(rnn_bidirectional) ...@@ -865,10 +1814,11 @@ TEST_CASE(rnn_bidirectional)
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.6570473, 0.36392266, 0.45342238, -0.45127486, 0., 0., 0., 0., 0.6570473, 0.36392266, 0.45342238, -0.45127486, -0.16225325, -0.29515147, 0.39617197,
-0.16225325, -0.29515147, 0.39617197, 0.27068236, 0., 0., 0., 0., 0.27068236, 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0.,
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.}; 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
...@@ -878,7 +1828,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -878,7 +1828,7 @@ TEST_CASE(rnn_bidirectional)
seq_len = 1; seq_len = 1;
std::vector<float> input_1(seq_len * batch_size * input_size, 0); std::vector<float> input_1(seq_len * batch_size * input_size, 0);
input_1[0] = input_1[1] = 1.0; input_1[0] = input_1[1] = 1.0;
migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape_1{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -888,7 +1838,12 @@ TEST_CASE(rnn_bidirectional) ...@@ -888,7 +1838,12 @@ TEST_CASE(rnn_bidirectional)
auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined")); auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
auto out_hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"rnn", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
...@@ -901,6 +1856,9 @@ TEST_CASE(rnn_bidirectional) ...@@ -901,6 +1856,9 @@ TEST_CASE(rnn_bidirectional)
bias, bias,
und, und,
ih); ih);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -910,19 +1868,18 @@ TEST_CASE(rnn_bidirectional) ...@@ -910,19 +1868,18 @@ TEST_CASE(rnn_bidirectional)
0.61055139, 0.61055139,
0.55168478, 0.55168478,
-0.5888475, -0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
-0.16915828, -0.16915828,
0.1938169, 0.1938169,
0.20667936, 0.20667936,
0.58609703, 0.58609703,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
-0.0070999, -0.0070999,
0.46251031, 0.46251031,
-0.20639211, -0.20639211,
0.37488942}; 0.37488942};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Verify a reverse-direction RNN fed with 4 inputs (seq/w/r/bias) under the
// ONNX layout=1 (batch-first) convention: the sequence is transposed to
// seq-first before the rnn op and the hidden-state output is transposed
// back to batch-first afterwards.
struct test_rnn_4args_layout : verify_program<test_rnn_4args_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nb = 2; // batch size
        const std::size_t ns = 5; // sequence length
        const std::size_t nh = 4; // hidden size
        const std::size_t ni = 3; // input size
        const std::size_t nd = 1; // number of directions
        const float clip_val = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_s{migraphx::shape::float_type, {nb, ns, ni}};
        const migraphx::shape w_s{migraphx::shape::float_type, {nd, nh, ni}};
        const migraphx::shape r_s{migraphx::shape::float_type, {nd, nh, nh}};
        const migraphx::shape b_s{migraphx::shape::float_type, {nd, 2 * nh}};

        auto seq  = mod->add_parameter("seq", seq_s);
        auto w    = mod->add_parameter("w", w_s);
        auto r    = mod->add_parameter("r", r_s);
        auto bias = mod->add_parameter("bias", b_s);

        // layout=1: convert batch-first input to the seq-first form the op expects
        const std::vector<int64_t> to_seq_first{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), seq);

        auto hs = mod->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", nh},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias);

        // restore batch-first ordering of the all-hidden-states output
        const std::vector<int64_t> hs_to_batch_first{2, 0, 1, 3};
        mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", hs_to_batch_first}}), hs);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Verify a bidirectional RNN driven with only seq/w/r under the ONNX
// layout=1 (batch-first) convention: the input is transposed to seq-first
// before the op and the last-hidden-state output is transposed back.
struct test_rnn_bi_3args_layout : verify_program<test_rnn_bi_3args_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nb = 2;  // batch size
        const std::size_t ns = 10; // sequence length
        const std::size_t nh = 4;  // hidden size
        const std::size_t ni = 3;  // input size
        const std::size_t nd = 2;  // two directions (bidirectional)
        const float clip_val = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_s{migraphx::shape::float_type, {nb, ns, ni}};
        const migraphx::shape w_s{migraphx::shape::float_type, {nd, nh, ni}};
        const migraphx::shape r_s{migraphx::shape::float_type, {nd, nh, nh}};

        auto seq = mod->add_parameter("seq", seq_s);
        auto w   = mod->add_parameter("w", w_s);
        auto r   = mod->add_parameter("r", r_s);

        // batch-first -> seq-first for the rnn op
        const std::vector<int64_t> swap_bs{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap_bs}}), seq);

        auto output = mod->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", nh},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("tanh")})},
                 {"direction",
                  migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip_val}}),
            seq,
            w,
            r);

        auto last_output =
            mod->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
        // last hidden state back to batch-first
        last_output = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap_bs}}), last_output);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Verify a bidirectional RNN with the full 6-argument form (bias, an
// undefined sequence-length slot, and an initial hidden state) using the
// ONNX layout=1 (batch-first) convention: seq and ih are transposed to
// seq-first / direction-first before the op, and the last-hidden-state
// output is transposed back afterwards.
struct test_rnn_bidirectional_layout : verify_program<test_rnn_bidirectional_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nb = 2; // batch size
        const std::size_t ns = 1; // sequence length
        const std::size_t nh = 4; // hidden size
        const std::size_t ni = 3; // input size
        const std::size_t nd = 2; // two directions (bidirectional)
        const float clip_val = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_s{migraphx::shape::float_type, {nb, ns, ni}};
        const migraphx::shape w_s{migraphx::shape::float_type, {nd, nh, ni}};
        const migraphx::shape r_s{migraphx::shape::float_type, {nd, nh, nh}};
        const migraphx::shape b_s{migraphx::shape::float_type, {nd, 2 * nh}};
        // layout=1 puts batch ahead of direction in the initial hidden state
        const migraphx::shape ih_s{migraphx::shape::float_type, {nb, nd, nh}};

        auto seq  = mod->add_parameter("seq", seq_s);
        auto w    = mod->add_parameter("w", w_s);
        auto r    = mod->add_parameter("r", r_s);
        auto bias = mod->add_parameter("bias", b_s);
        auto ih   = mod->add_parameter("ih", ih_s);
        auto und  = mod->add_instruction(migraphx::make_op("undefined"));

        // batch-first -> seq-first (and batch/direction swap for ih)
        const std::vector<int64_t> swap01{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), seq);
        ih = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), ih);

        auto output = mod->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", nh},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("tanh")})},
                 {"direction",
                  migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            und,
            ih);

        auto last_output =
            mod->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
        // last hidden state back to batch-first
        last_output = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), last_output);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Verify a forward RNN with bias and an initial hidden state under the
// ONNX layout=1 (batch-first) convention. Both outputs are checked: the
// all-hidden-states tensor and the last hidden state, each transposed
// back to the batch-first layout before being returned.
struct test_rnn_forward_layout : verify_program<test_rnn_forward_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nb = 2; // batch size
        const std::size_t ns = 1; // sequence length
        const std::size_t nh = 4; // hidden size
        const std::size_t ni = 3; // input size
        const std::size_t nd = 1; // number of directions
        const float clip_val = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_s{migraphx::shape::float_type, {nb, ns, ni}};
        const migraphx::shape w_s{migraphx::shape::float_type, {nd, nh, ni}};
        const migraphx::shape r_s{migraphx::shape::float_type, {nd, nh, nh}};
        const migraphx::shape b_s{migraphx::shape::float_type, {nd, 2 * nh}};
        // layout=1 puts batch ahead of direction in the initial hidden state
        const migraphx::shape ih_s{migraphx::shape::float_type, {nb, nd, nh}};

        auto seq  = mod->add_parameter("seq", seq_s);
        auto w    = mod->add_parameter("w", w_s);
        auto r    = mod->add_parameter("r", r_s);
        auto bias = mod->add_parameter("bias", b_s);
        auto ih   = mod->add_parameter("ih", ih_s);
        auto und  = mod->add_instruction(migraphx::make_op("undefined"));

        // batch-first -> seq-first (and batch/direction swap for ih)
        const std::vector<int64_t> swap01{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), seq);
        ih = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), ih);

        auto hs = mod->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", nh},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            und,
            ih);

        auto lho = mod->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
        // both outputs back to batch-first
        const std::vector<int64_t> hs_to_batch_first{2, 0, 1, 3};
        hs = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", hs_to_batch_first}}), hs);
        lho = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), lho);
        mod->add_return({hs, lho});
        return prog;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Verify a reverse-direction RNN with bias and an initial hidden state under
// the ONNX layout=1 (batch-first) convention: seq and ih are transposed to
// seq-first / direction-first before the op, and the all-hidden-states output
// is transposed back to batch-first afterwards.
//
// NOTE: renamed from test_rnn_reverse to test_rnn_reverse_layout for
// consistency with the sibling layout tests (test_rnn_4args_layout,
// test_rnn_forward_layout, ...) and to avoid colliding with the plain
// seq-first test_rnn_reverse elsewhere in the suite (two different class
// definitions under one name across translation units would violate the ODR).
struct test_rnn_reverse_layout : verify_program<test_rnn_reverse_layout>
{
    migraphx::program create_program() const
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 1;
        std::size_t hidden_size = 4;
        std::size_t input_size  = 3;
        std::size_t num_dirct   = 1;
        float clip              = 0.0f;

        migraphx::program p;
        auto* mm = p.get_main_module();

        // layout=1: batch leads in both the sequence and the initial hidden state
        migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};

        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));

        // batch-first -> seq-first (and batch/direction swap for ih)
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);

        auto hs = mm->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih);

        // hidden-states output back to batch-first
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Verify a forward RNN with per-batch sequence lengths (a literal
// sequence_lens input) and an initial hidden state under the ONNX
// layout=1 (batch-first) convention. Both outputs — the all-hidden-states
// tensor and the last hidden state — are transposed back to batch-first
// and returned.
struct test_rnn_sql_1_layout : verify_program<test_rnn_sql_1_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nb = 2;  // batch size
        const std::size_t ns = 10; // sequence length
        const std::size_t nh = 4;  // hidden size
        const std::size_t ni = 3;  // input size
        const std::size_t nd = 1;  // number of directions
        const float clip_val = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_s{migraphx::shape::float_type, {nb, ns, ni}};
        const migraphx::shape w_s{migraphx::shape::float_type, {nd, nh, ni}};
        const migraphx::shape r_s{migraphx::shape::float_type, {nd, nh, nh}};
        const migraphx::shape b_s{migraphx::shape::float_type, {nd, 2 * nh}};
        const migraphx::shape sl_s{migraphx::shape::int32_type, {nb}};
        // layout=1 puts batch ahead of direction in the initial hidden state
        const migraphx::shape ih_s{migraphx::shape::float_type, {nb, nd, nh}};

        auto seq  = mod->add_parameter("seq", seq_s);
        auto w    = mod->add_parameter("w", w_s);
        auto r    = mod->add_parameter("r", r_s);
        auto bias = mod->add_parameter("bias", b_s);
        // actual lengths per batch element (shorter than the padded ns)
        const std::vector<int> seq_lens{5, 7};
        auto sql = mod->add_literal(migraphx::literal{sl_s, seq_lens});
        auto ih  = mod->add_parameter("ih", ih_s);

        // batch-first -> seq-first (and batch/direction swap for ih)
        const std::vector<int64_t> swap01{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), seq);
        ih = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), ih);

        auto hs = mod->add_instruction(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", nh},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("tanh"), migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih);

        auto last_hs = mod->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
        // both outputs back to batch-first
        const std::vector<int64_t> hs_to_batch_first{2, 0, 1, 3};
        hs = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", hs_to_batch_first}}), hs);
        last_hs = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", swap01}}), last_hs);
        mod->add_return({hs, last_hs});
        return prog;
    }
    std::string section() const { return "rnn"; }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment