Unverified commit b2a40ea6 authored by nives-vukovic, committed by GitHub

Add layout attribute support for LSTM operator (#2343)

Since ONNX opset version 14, the LSTM operator has a layout attribute that selects between two predefined layouts for its input and output shapes. Add support for the attribute in the parser, along with the corresponding reference, onnx, and verify tests.
parent 45ccd757
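For context, a minimal sketch of how the attribute is used when building a model with onnx.helper; the names and sizes are illustrative and simply mirror the tests added below. With layout=0 (the default) the sequence input is [seq_length, batch_size, input_size] and the hidden-state output is [seq_length, num_directions, batch_size, hidden_size]; with layout=1 they become [batch_size, seq_length, input_size] and [batch_size, seq_length, num_directions, hidden_size], and initial_h/initial_c likewise move the batch dimension first.

from onnx import helper, TensorProto

# batch_size=3, seq_length=5, input_size=10, hidden_size=20, num_directions=1
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])   # batch-first because layout=1
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20])  # [batch, seq, dirs, hidden]
node = helper.make_node('LSTM',
                        inputs=['seq', 'w', 'r'],
                        outputs=['hs'],
                        hidden_size=20,
                        direction='forward',
                        layout=1)
# node, the value infos above, and a graph/model wrapper would then go through helper.make_graph/make_model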
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
std::vector<int64_t> perm{1, 0, 2};
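// with layout=1 the inputs are batch-first: {1, 0, 2} moves X from [batch_size, seq_length, input_size]
// to [seq_length, batch_size, input_size], and initial_h/initial_c from [batch_size, num_directions, hidden_size]
// to [num_directions, batch_size, hidden_size], which is the ordering the lstm op works with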
args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]);
if(args.size() >= 6 and not args[5]->is_undefined())
{
args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]);
}
if(args.size() >= 7 and not args[6]->is_undefined())
{
args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]);
}
}
void lstm_transpose_outputs(onnx_parser::node_info& info,
instruction_ref& hidden_states,
instruction_ref& last_output,
instruction_ref& last_cell_output)
{
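// undo the layout conversion on the outputs:
// hidden_states: [seq_length, num_directions, batch_size, hidden_size] -> [batch_size, seq_length, num_directions, hidden_size]
// last_output / last_cell_output: [num_directions, batch_size, hidden_size] -> [batch_size, num_directions, hidden_size]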
std::vector<int64_t> perm_hs{2, 0, 1, 3};
hidden_states =
info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states);
std::vector<int64_t> perm_last{1, 0, 2};
last_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output);
last_cell_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output);
}
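As a quick sanity check of the two permutations used above, a small numpy sketch (the sizes are hypothetical, not taken from the tests):

import numpy as np

batch, seq_len, dirs, input_size, hidden = 3, 5, 1, 10, 20

x_layout1 = np.zeros((batch, seq_len, input_size))
x_seq_major = x_layout1.transpose(1, 0, 2)               # perm {1, 0, 2}
assert x_seq_major.shape == (seq_len, batch, input_size)

y_seq_major = np.zeros((seq_len, dirs, batch, hidden))   # lstm op hidden-state output
y_layout1 = y_seq_major.transpose(2, 0, 1, 3)            # perm {2, 0, 1, 3}
assert y_layout1.shape == (batch, seq_len, dirs, hidden)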
struct parse_lstm : op_parser<parse_lstm>
{
std::vector<op_desc> operators() const { return {{"LSTM"}}; }
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
}
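// layout attribute (opset 14+): 0 = sequence-major shapes (default), 1 = batch-major shapes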
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined operators to make 8 arguments
if(args.size() < 8)
{
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins);
}
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output};
}
};
...
@@ -4484,6 +4484,177 @@ def lrn_test():
return ([node], [x], [y])
@onnx_test()
def lstm_bi_layout_cell_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60])
cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
[3, 2, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['', '', 'cellout'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='bidirectional',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_bi_layout_last_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60])
hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 2, 20])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[3, 2, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['hs', 'output'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='bidirectional',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_hs_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[3, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['hs', 'output'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='forward',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_cell_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
[3, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['', '', 'cellout'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='forward',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_r_layout_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['hs'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='reverse',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs])
@onnx_test()
def lstm_r_layout_hs_cell_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[3, 1, 20])
cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
[3, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['', 'output', 'cellout'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='reverse',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [output, cellout])
@onnx_test()
def matmul_bmbm_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7])
...
@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
}
}
TEST_CASE(lstm_forward_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 8 args, hs and last output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
auto prog = optimize_onnx("lstm_f_layout_hs_test.onnx");
EXPECT(p == prog);
}
// 8 args, cell output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_f_layout_cell_test.onnx");
EXPECT(p == prog);
}
}
// activation functions
TEST_CASE(lstm_forward_actv_func)
{
@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
}
}
TEST_CASE(lstm_reverse_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 8 args, hs output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
auto prog = optimize_onnx("lstm_r_layout_test.onnx");
EXPECT(p == prog);
}
// 8 args, last and cell output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
last_output = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}),
last_output);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_r_layout_hs_cell_test.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t sl = 5; // sequence len
@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
}
}
TEST_CASE(lstm_bidirectional_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 8 args, hs and last output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx");
EXPECT(p == prog);
}
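// 8 args, cell output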
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_bi_actv_funcs)
{
std::size_t sl = 5; // sequence len
...
@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# from OnnxBackendNodeModelTest
backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_lstm_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
# from OnnxBackendPyTorchConvertedModelTest
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
...
@@ -3228,6 +3228,264 @@ TEST_CASE(lstm_forward)
}
}
TEST_CASE(lstm_forward_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, hidden state concatenation as output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
und);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0417273, -0.272355, 0.206765, 0.223879, 0.0742487, -0.0800085, 0.259897,
0.0670196, -0.00532985, 0.0440265, 0.29654, -0.0463156, -0.0847427, 0.0874114,
0.304256, -0.0585745, 0.138193, -0.0322939, -0.0891815, 0.15773, 0.184266,
0.0610048, -0.138041, 0.0963885, 0.0498799, 0.125772, 0.0533032, -0.131413,
-0.0223018, 0.131113, 0.135643, -0.056620, 0.19139, -0.127708, -0.409371,
-0.136186, 0.0213755, -0.146027, -0.0324509, -0.0620429, 0.0988431, -0.018085,
-0.159434, 0.030266, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// forward, last_output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
und);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.0847427,
0.0874114,
0.304256,
-0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// forward, last_cell_output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
und);
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.111454,
0.247794,
0.471087,
-0.220574,
-0.048196,
0.263184,
0.283258,
-0.14882,
0.605585,
0.078598,
-0.64457,
0.119811};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_forward_more)
{
std::size_t batch_size = 3;
@@ -3519,7 +3777,7 @@ TEST_CASE(lstm_forward_more)
}
}
TEST_CASE(lstm_forward_more_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
@@ -3527,32 +3785,668 @@ TEST_CASE(lstm_reverse)
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, 3 args
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, -0.0786602, -0.0613048, 0.179592,
-0.071286, -0.102509, -0.0372696, 0.252296, -0.144544, -0.165194, -0.0372928,
0.273786, -0.100877, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.074206,
0.0124086, -0.139544, 0.108016, 0.00496085, 0.0662588, -0.048577, -0.187329,
-0.0458544, -0.0401315, 0.0737483, -0.064505, 0.101585, 0.0687269, -0.161725,
-0.25617, -0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.0855831, -0.0171894,
-0.140202, 0.0828391, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// forward, 8 args
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.186991, -0.0624168, 0.205513,
0.0836373, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.058052, 0.0795391,
0.266617, -0.0128746, 0.294074, -0.0319677, -0.0955337, 0.104168, 0.421857,
0.0459771, -0.144955, 0.0720673, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.022618, -0.121195, -0.4065,
-0.252054, -0.0300906, -0.0890598, -0.135266, -0.0413375, 0.103489, 0.0142918,
-0.123408, 0.0401075, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// forward, last_output as program output, sequence length shorter
// than max_seq_len
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
und);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.0847427,
0.0874114,
0.304256,
-0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// seq_len = 1
{
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.079753,
-0.289854,
0.160043,
0.115056,
0.294074,
-0.0319677,
-0.0955337,
0.104168,
0.022618,
-0.121195,
-0.4065,
-0.252054};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
TEST_CASE(lstm_reverse)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
-0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
-0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.5289,
1.0986,
0.6091,
1.6462,
0.8720,
0.5349,
-0.1962,
-1.7416,
-0.9912,
1.2831,
1.0896,
-0.6959};
std::vector<float> ic_data{-0.8323,
0.3998,
0.1831,
0.5938,
2.7096,
-0.1790,
0.0022,
-0.8040,
0.1578,
0.0567,
0.8069,
-0.5141};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// reverse, concatenation of hidden states as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, sequence lengths are the same, but less than max_seq_lens
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// variable sequence lengths
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{3, 2, 1};
auto sql = mm->add_literal(seq_len_s, len_data);
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449,
0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761,
0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0,
0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_reverse_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
-0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
-0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
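    // Batch-first (layout == 1) variant: seq below is stored as [batch_size, seq_len, input_size]
    // and the initial hidden/cell states as [batch_size, num_dirct, hidden_size] (see the shapes that follow).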
    std::vector<float> input_data{
        -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
        0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
        1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
        -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
    std::vector<float> ih_data{1.5289,
                               1.0986,
...@@ -3593,14 +4487,15 @@ TEST_CASE(lstm_reverse)
                               -0.5811,
                               0.2166};
    migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    float clip = 0.0f;
    // reverse, concatenation of hidden states as program output
    {
        migraphx::program p;
...@@ -3614,7 +4509,13 @@ TEST_CASE(lstm_reverse)
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
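        // transpose the batch-first seq and initial states to the default
        // [seq_len, batch_size, input_size] / [num_dirct, batch_size, hidden_size] ordering
        // expected by the lstm op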
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
...@@ -3633,18 +4534,21 @@ TEST_CASE(lstm_reverse)
            ih,
            ic,
            pph);
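        // permutation {2, 0, 1, 3} moves the batch dimension of the concatenated hidden-state
        // output to the front, giving the batch-first ordering the gold data below is written in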
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, -0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, -0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, 0.177273, -0.0774616, 0.789732, 0.128538, 0.20909, 0.0553812, 0.928866,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, 0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741, 0.130438,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, 0.946669, 0.0868676, 0.044508, -0.373961, -0.224905, 0.32421, 0.344048,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.271694, -0.063456, 0.148524, 0.05108, -0.0234895, -0.0252804, 0.267356,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; 0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
...@@ -3661,14 +4565,20 @@ TEST_CASE(lstm_reverse)
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});

        migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
        std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
        auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
        auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
        auto sql = mm->add_literal(seq_len_s, len_data);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
...@@ -3687,22 +4597,26 @@ TEST_CASE(lstm_reverse)
            ih,
            ic,
            pph);
p.compile(migraphx::make_target("ref")); std::vector<int64_t> perm_hid{2, 0, 1, 3};
auto hs_concat = p.eval({}).back(); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
std::vector<float> output_data; p.compile(migraphx::make_target("ref"));
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); auto hs_concat = p.eval({}).back();
std::vector<float> output_data_gold{ std::vector<float> output_data;
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, std::vector<float> output_data_gold{
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, -0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, -0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.177273, -0.0774616, 0.0, 0.0, 0.0, 0.0, 0.0,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0, 0.0, 0.0, 0.0, 0.789732, 0.128538, 0.20909, 0.0553812,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.928866, 0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.130438, 0.946669, 0.0868676, 0.044508, -0.373961, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.224905,
0.0, 0.0}; 0.32421, 0.344048, 0.271694, -0.063456, 0.148524, 0.05108, -0.0234895,
-0.0252804, 0.267356, 0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211,
-0.161537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
...@@ -3722,7 +4636,13 @@ TEST_CASE(lstm_reverse)
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data{3, 2, 1};
        auto sql = mm->add_literal(seq_len_s, len_data);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
...@@ -3741,18 +4661,22 @@ TEST_CASE(lstm_reverse)
            ih,
            ic,
            pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449, -0.126517, 0.0359124, 0.107453, -0.0617278, -0.168327, 0.00023761, 0.167567,
0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761, -0.0621982, -0.204545, 0.0146403, 0.210057, 0.0296268, 0, 0,
0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0, 0, 0, 0.911307, 0.11468, 0.114449, 0.0196755, 0.96657,
0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268, 0.0755112, 0.0620917, -0.264845, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0.102969, 0.295872, 0.515859,
0, 0, 0, 0, 0, 0, 0, 0.246501, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0}; 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
...@@ -3763,6 +4687,10 @@ TEST_CASE(lstm_reverse)
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
...@@ -3777,46 +4705,8 @@ TEST_CASE(lstm_reverse)
            seq,
            w,
            r);
        auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
...@@ -3900,19 +4790,214 @@ TEST_CASE(lstm_reverse_actv) ...@@ -3900,19 +4790,214 @@ TEST_CASE(lstm_reverse_actv)
                               0.8069,
                               -0.5141};
    std::vector<float> pph_data{-0.8271,
                                -0.5683,
                                0.4562,
                                -1.2545,
                                1.2729,
                                -0.4082,
                                -0.4392,
                                -0.9406,
                                0.7794,
                                1.8194,
                                -0.5811,
                                0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(
std::vector<migraphx::operation>{migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{
migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.132123,
-0.37531,
-0.12943,
-0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{
seq_len = 1;
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.104351,
-0.0471426,
-0.0905753,
0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458,
0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349,
-0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212,
-0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790,
0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
...@@ -3920,95 +5005,200 @@ TEST_CASE(lstm_reverse_actv)
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// concatenation of hidden states as program output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
            0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
            0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
            0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
            0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
            -0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }

    // last hidden state as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); mm->add_instruction(
auto hs = mm->add_instruction(
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{ migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, migraphx::make_op("tanh"),
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r); r);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.132123, std::vector<float> output_data_gold{
-0.37531, -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
-0.12943, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.00798307, -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
-0.133882, 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
-0.0251383, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
0.0486486, -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
-0.0220606, 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
0.292495, -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.233866, 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
0.48646, -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.481844}; 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
    // sequence length is 1, concatenation of hidden states as program output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        seq_len = 1;
        migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        std::vector<float> input_data1{
            -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
        auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
...@@ -4017,7 +5207,7 @@ TEST_CASE(lstm_reverse_actv)
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
...@@ -4027,23 +5217,16 @@ TEST_CASE(lstm_reverse_actv)
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698,
            -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
            -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
            -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
}
TEST_CASE(lstm_bidirectional_layout)
{
    std::size_t batch_size = 3;
    std::size_t seq_len = 4;
...@@ -4087,20 +5270,20 @@ TEST_CASE(lstm_bidirectional)
        -0.4386, 0.4208, 0.0717, 0.3789};
    std::vector<float> input_data{
        -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
        0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
        1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
        -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
    std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 1.5289, 1.0986,
                               0.6091, 1.6462, 0.5671, 0.0458, 0.4514, -0.8968,
                               0.8720, 0.5349, -0.1962, -1.7416, -0.9201, 0.1962,
                               0.5771, -0.5332, -0.9912, 1.2831, 1.0896, -0.6959};
    std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, -0.8323, 0.3998,
                               0.1831, 0.5938, 1.1055, -0.1212, -0.9097, 0.7831,
                               2.7096, -0.1790, 0.0022, -0.8040, -1.6991, -1.9498,
                               -1.2567, -0.4114, 0.1578, 0.0567, 0.8069, -0.5141};
    std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
                                -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
...@@ -4108,12 +5291,12 @@ TEST_CASE(lstm_bidirectional)
                                -1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
                                0.7794, 1.8194, -0.5811, 0.2166};
    float clip = 0.0f;
    migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};

    // concatenation of hidden states as program output
...@@ -4128,7 +5311,13 @@ TEST_CASE(lstm_bidirectional)
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
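        // layout == 1: transpose the batch-first seq and initial states to the default
        // ordering before building the lstm op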
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
...@@ -4147,25 +5336,29 @@ TEST_CASE(lstm_bidirectional)
            ih,
            ic,
            pph);
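        // permute the concatenated hidden states back to the batch-first ordering used by the
        // layout == 1 gold values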
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, 0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157, -0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.178681, -0.266999, 0.0459032, 0.0414126, 0.272303, 0.0393149, -0.182201,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, -0.185038, -0.026845, 0.177273, -0.0774616, 0.294074, -0.0319677, -0.0955337,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, 0.104168, 0.789732, 0.128538, 0.20909, 0.0553812, 0.421857, 0.0459771,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032, -0.144955, 0.0720673, 0.928866, 0.113685, 0.220626, -0.0432316, 0.218258,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565, 0.269741, 0.130438,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669, 0.0868676, 0.044508,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, -0.373961, 0.022618, -0.121195, -0.4065, -0.252054, -0.224905, 0.32421,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.344048, 0.271694, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.063456,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, 0.148524, 0.05108, -0.0234895, 0.103489, 0.0142918, -0.123408, 0.0401075,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, -0.0252804, 0.267356, 0.146353, 0.0789186, 0.187761, 0.0501726, -0.121584,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; 0.0606723, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
...@@ -4181,6 +5374,12 @@ TEST_CASE(lstm_bidirectional)
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
...@@ -4200,15 +5399,17 @@ TEST_CASE(lstm_bidirectional)
            ih,
            ic,
            pph);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, -0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; 0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
...@@ -4224,6 +5425,12 @@ TEST_CASE(lstm_bidirectional)
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
...@@ -4243,15 +5450,17 @@ TEST_CASE(lstm_bidirectional)
            ih,
            ic,
            pph);
        auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, -0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, 0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; 0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
...@@ -4262,7 +5471,11 @@ TEST_CASE(lstm_bidirectional)
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
...@@ -4276,25 +5489,28 @@ TEST_CASE(lstm_bidirectional)
            seq,
            w,
            r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, -0.0327039, -0.0543852, 0.114378, -0.0768855, -0.162851, -0.102647, -0.113827,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, -0.142818, -0.0786602, -0.0613048, 0.179592, -0.071286, -0.123496, -0.153616,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, -0.032874, -0.195349, -0.102509, -0.0372696, 0.252296, -0.144544, -0.1073,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, -0.150145, 0.015065, -0.192699, -0.165194, -0.0372928, 0.273786, -0.100877,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, -0.021205, -0.125423, 0.0206439, -0.187097, 0.0319021, -0.00298698, -0.0623361,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, 0.0598866, 0.0513685, 0.0547876, 0.0201981, -0.00808453, 0.074206, 0.0124086,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, -0.139544, 0.108016, 0.0192675, -0.108636, 0.098927, -0.140733, 0.00496085,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0662588, -0.048577, -0.187329, -0.112764, -0.120496, 0.155754, 0.148256,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, -0.0458544, -0.0401315, 0.0737483, -0.064505, -0.0051453, -0.0767618, -0.0735348,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, -0.0826436, 0.101585, 0.0687269, -0.161725, -0.25617, -0.00520328, 0.0945081,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, 0.264123, 0.410805, -0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.162602,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0143099, -0.0455534, 0.0151574, 0.0855831, -0.0171894, -0.140202, 0.0828391,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, 0.208491, 0.348432, 0.0291103, 0.230275, 0.136898, 0.00160891, -0.184812,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; 0.147774, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
...@@ -4303,13 +5519,17 @@ TEST_CASE(lstm_bidirectional)
        migraphx::program p;
        auto* mm = p.get_main_module();
        seq_len = 1;
        migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
        std::vector<float> input_data1{
            -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
        auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
...@@ -4323,15 +5543,19 @@ TEST_CASE(lstm_bidirectional)
            seq,
            w,
            r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back(); auto hs_concat = p.eval({}).back();
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0327039, -0.0543852, 0.114378, -0.0768855, -0.104351, -0.0471426,
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0905753, 0.01506, 0.0319021, -0.00298698, -0.0623361, 0.0598866,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, 0.059797, 0.104239, -0.0266768, 0.0727547, 0.101585, 0.0687269,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; -0.161725, -0.25617, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -4577,6 +5801,275 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
} }
} }
TEST_CASE(lstm_bidirectional_var_seq_lens_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 1.5289, 1.0986,
0.6091, 1.6462, 0.5671, 0.0458, 0.4514, -0.8968,
0.8720, 0.5349, -0.1962, -1.7416, -0.9201, 0.1962,
0.5771, -0.5332, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, -0.8323, 0.3998,
0.1831, 0.5938, 1.1055, -0.1212, -0.9097, 0.7831,
2.7096, -0.1790, 0.0022, -0.8040, -1.6991, -1.9498,
-1.2567, -0.4114, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// variable sequence lengths per batch; hidden states, last hidden state, and last cell output as program outputs
{
std::vector<int> sl_data{1, 2, 3};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
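// inputs are given in the layout=1 (batch-major) ordering; transpose seq, ih, and ic into the
// sequence-major layout the lstm operator expects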
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
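// transpose the outputs back to the batch-major (layout=1) ordering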
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco);
mm->add_return({out_hs, lho, lco});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.at(1);
auto arg_lco = outputs.at(2);
std::vector<float> output_data;
std::vector<float> last_output_data;
std::vector<float> last_cell_data;
arg_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
arg_lho.visit([&](auto output) { last_output_data.assign(output.begin(), output.end()); });
arg_lco.visit([&](auto output) { last_cell_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804,
0.0745128, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.911307, 0.11468, 0.114449, 0.0196755, 0.421857, 0.0459771,
-0.144955, 0.0720673, 0.96657, 0.0755112, 0.0620917, -0.264845, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0.022618, -0.121195, -0.4065, -0.252054, -0.262807, 0.275286,
0.358395, 0.266267, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.128254,
0.125398, 0.0665142, -0.163651, 0.103489, 0.0142918, -0.123408, 0.0401075,
-0.0644683, 0.371512, 0.212431, -0.116131, 0, 0, 0,
0, 0, 0, 0, 0};
std::vector<float> last_output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804, 0.0745128,
0.421857, 0.0459771, -0.144955, 0.0720673, 0.911307, 0.11468, 0.114449, 0.0196755,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.262807, 0.275286, 0.358395, 0.266267};
std::vector<float> last_cell_data_gold{
0.600582, -0.601197, 0.353558, 0.789097, -0.326822, 0.301121, 0.219523, 0.415242,
0.737121, 0.134902, -0.303595, 0.241948, 2.08242, 0.442513, 0.187127, 0.0577626,
0.391174, 0.0308845, -0.561745, 0.0730323, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_cell_data, last_cell_data_gold));
}
// padded input sequence with the full sequence length for every batch entry
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
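// transpose the batch-major inputs to sequence-major before running the lstm operator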
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
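// transpose the outputs back to batch-major ordering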
std::vector<int64_t> perm_hid{2, 0, 1, 3};
hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco);
mm->add_return({hs, lho, lco});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto res_hs = outputs.at(0);
auto res_lho = outputs.at(1);
auto res_lco = outputs.at(2);
std::vector<float> hs_data;
std::vector<float> lho_data;
std::vector<float> lco_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
res_lco.visit([&](auto output) { lco_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138,
-0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549,
0.178681, -0.266999, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.182201,
-0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746,
-0.185038, -0.026845, 0.177273, -0.0774616, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0.294074,
-0.0319677, -0.0955337, 0.104168, 0.789732, 0.128538, 0.20909, 0.0553812,
0.421857, 0.0459771, -0.144955, 0.0720673, 0.928866, 0.113685, 0.220626,
-0.0432316, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565,
0.269741, 0.130438, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669,
0.0868676, 0.044508, -0.373961, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0.022618, -0.121195,
-0.4065, -0.252054, -0.224905, 0.32421, 0.344048, 0.271694, -0.0300906,
-0.0890598, -0.135266, -0.0413375, -0.063456, 0.148524, 0.05108, -0.0234895,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.0252804, 0.267356, 0.146353,
0.0789186, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.0681467, 0.382748,
0.230211, -0.161537, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0};
std::vector<float> lho_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694};
std::vector<float> lco_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334,
0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813,
0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lco_data, lco_data_gold));
}
}
TEST_CASE(lstm_bidirectional_actv_func)
{
std::size_t batch_size = 3;
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_bidirct_3args_layout : verify_program<test_lstm_bidirct_3args_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
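// seq is batch-first (layout=1); transpose it to sequence-first for the lstm op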
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
return p;
}
std::string section() const { return "rnn"; }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_bidirct_last_layout : verify_program<test_lstm_bidirct_last_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
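// transpose the batch-first inputs (seq, ih, ic) into the default layout=0 ordering expected by the lstm op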
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto output = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
return p;
}
std::string section() const { return "rnn"; }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_forward_hs_layout : verify_program<test_lstm_forward_hs_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
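// convert the batch-first seq, ih, and ic to the default layout=0 ordering before the lstm op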
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
return p;
}
std::string section() const { return "rnn"; }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_forward_last_layout : verify_program<test_lstm_forward_last_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape l_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto len = mm->add_literal(migraphx::literal(l_shape, {1, 2}));
auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape);
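// transpose seq, ih, and ic from the layout=1 (batch-first) ordering to layout=0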
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto output = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
len,
ih,
ic,
pph);
auto last_output =
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
return p;
}
std::string section() const { return "rnn"; }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_reverse_3args_cell_layout : verify_program<test_lstm_reverse_3args_cell_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
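// the sequence input is batch-first; transpose it to sequence-first for the lstm op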
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r);
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
return p;
}
std::string section() const { return "rnn"; }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_reverse_3args_layout : verify_program<test_lstm_reverse_3args_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
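// transpose the batch-first sequence input to sequence-first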
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
return p;
}
std::string section() const { return "rnn"; }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_three_outputs_layout : verify_program<test_lstm_three_outputs_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
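// transpose the batch-first sequence to the sequence-first ordering used by the lstm op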
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r);
auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
last_hs =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_hs);
last_cell =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
mm->add_return({hs, last_hs, last_cell});
return p;
}
std::string section() const { return "rnn"; }
};