Unverified Commit b2a40ea6 authored by nives-vukovic's avatar nives-vukovic Committed by GitHub
Browse files

Add layout attribute support for LSTM operator (#2343)

Since ONNX opset version 14, the layout attribute has been available on the LSTM operator; it allows two predefined layouts for the input and output shapes. Add corresponding reference, onnx, and verify tests.
parent 45ccd757
...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv ...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
} }
} }
// Convert LSTM inputs from batch-major layout (ONNX layout == 1) into the
// sequence-major layout expected by the "lstm" operator.  Always transposes
// the sequence input (args[0]); the initial hidden state (args[5]) and the
// initial cell state (args[6]) are transposed only when they are present and
// not the undefined placeholder.
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
    const std::vector<int64_t> to_seq_major{1, 0, 2};
    auto swap_batch_seq = [&](instruction_ref ins) {
        return info.add_instruction(make_op("transpose", {{"permutation", to_seq_major}}), ins);
    };
    args[0] = swap_batch_seq(args[0]);
    for(std::size_t idx : {std::size_t{5}, std::size_t{6}})
    {
        if(args.size() > idx and not args[idx]->is_undefined())
        {
            args[idx] = swap_batch_seq(args[idx]);
        }
    }
}
// Convert LSTM outputs back to batch-major layout (ONNX layout == 1).
// The full hidden-state output is rank 4, so it uses permutation {2, 0, 1, 3};
// the last hidden/cell outputs are rank 3 and use {1, 0, 2}.  All three
// instruction refs are updated in place, in this exact order.
void lstm_transpose_outputs(onnx_parser::node_info& info,
                            instruction_ref& hidden_states,
                            instruction_ref& last_output,
                            instruction_ref& last_cell_output)
{
    auto relayout = [&](instruction_ref ins, const std::vector<int64_t>& perm) {
        return info.add_instruction(make_op("transpose", {{"permutation", perm}}), ins);
    };
    hidden_states    = relayout(hidden_states, {2, 0, 1, 3});
    last_output      = relayout(last_output, {1, 0, 2});
    last_cell_output = relayout(last_cell_output, {1, 0, 2});
}
struct parse_lstm : op_parser<parse_lstm> struct parse_lstm : op_parser<parse_lstm>
{ {
std::vector<op_desc> operators() const { return {{"LSTM"}}; } std::vector<op_desc> operators() const { return {{"LSTM"}}; }
...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>(); input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
} }
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined opeator to make 6 arguments // append undefined opeator to make 6 arguments
if(args.size() < 8) if(args.size() < 8)
{ {
...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins); args.insert(args.end(), 8 - args.size(), ins);
} }
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm", auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output = auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states); info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output}; return {hidden_states, last_output, last_cell_output};
} }
}; };
......
...@@ -4484,6 +4484,177 @@ def lrn_test(): ...@@ -4484,6 +4484,177 @@ def lrn_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def lstm_bi_layout_cell_test():
    # Bidirectional LSTM with layout=1 (batch-major): seq is
    # [batch, seq_len, input] and h0/c0/Y_c are [batch, num_dir, hidden].
    # Only the last cell state output (Y_c) is requested.
    mk = helper.make_tensor_value_info
    seq = mk('seq', TensorProto.FLOAT, [3, 5, 10])
    w = mk('w', TensorProto.FLOAT, [2, 80, 10])
    r = mk('r', TensorProto.FLOAT, [2, 80, 20])
    bias = mk('bias', TensorProto.FLOAT, [2, 160])
    seq_len = mk('seq_len', TensorProto.INT32, [3])
    h0 = mk('h0', TensorProto.FLOAT, [3, 2, 20])
    c0 = mk('c0', TensorProto.FLOAT, [3, 2, 20])
    pph = mk('pph', TensorProto.FLOAT, [2, 60])
    cellout = mk('cellout', TensorProto.FLOAT, [3, 2, 20])
    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', '', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='bidirectional',
        hidden_size=20,
        input_forget=1,
        layout=1)
    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_bi_layout_last_test():
    # Bidirectional LSTM with layout=1 (batch-major): the full hidden-state
    # output Y is [batch, seq_len, num_dir, hidden] and the last hidden
    # output Y_h is [batch, num_dir, hidden].
    mk = helper.make_tensor_value_info
    seq = mk('seq', TensorProto.FLOAT, [3, 5, 10])
    w = mk('w', TensorProto.FLOAT, [2, 80, 10])
    r = mk('r', TensorProto.FLOAT, [2, 80, 20])
    bias = mk('bias', TensorProto.FLOAT, [2, 160])
    seq_len = mk('seq_len', TensorProto.INT32, [3])
    h0 = mk('h0', TensorProto.FLOAT, [3, 2, 20])
    c0 = mk('c0', TensorProto.FLOAT, [3, 2, 20])
    pph = mk('pph', TensorProto.FLOAT, [2, 60])
    hs = mk('hs', TensorProto.FLOAT, [3, 5, 2, 20])
    output = mk('output', TensorProto.FLOAT, [3, 2, 20])
    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs', 'output'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='bidirectional',
        hidden_size=20,
        input_forget=1,
        layout=1)
    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_hs_test():
    # Forward (single-direction) LSTM with layout=1: requests the full
    # hidden-state output Y and the last hidden output Y_h, both batch-major.
    mk = helper.make_tensor_value_info
    seq = mk('seq', TensorProto.FLOAT, [3, 5, 10])
    w = mk('w', TensorProto.FLOAT, [1, 80, 10])
    r = mk('r', TensorProto.FLOAT, [1, 80, 20])
    bias = mk('bias', TensorProto.FLOAT, [1, 160])
    seq_len = mk('seq_len', TensorProto.INT32, [3])
    h0 = mk('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = mk('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = mk('pph', TensorProto.FLOAT, [1, 60])
    hs = mk('hs', TensorProto.FLOAT, [3, 5, 1, 20])
    output = mk('output', TensorProto.FLOAT, [3, 1, 20])
    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs', 'output'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='forward',
        hidden_size=20,
        input_forget=1,
        layout=1)
    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_cell_test():
    # Forward (single-direction) LSTM with layout=1: only the last cell
    # state output Y_c is requested ([batch, num_dir, hidden]).
    mk = helper.make_tensor_value_info
    seq = mk('seq', TensorProto.FLOAT, [3, 5, 10])
    w = mk('w', TensorProto.FLOAT, [1, 80, 10])
    r = mk('r', TensorProto.FLOAT, [1, 80, 20])
    bias = mk('bias', TensorProto.FLOAT, [1, 160])
    seq_len = mk('seq_len', TensorProto.INT32, [3])
    h0 = mk('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = mk('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = mk('pph', TensorProto.FLOAT, [1, 60])
    cellout = mk('cellout', TensorProto.FLOAT, [3, 1, 20])
    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', '', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='forward',
        hidden_size=20,
        input_forget=1,
        layout=1)
    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_r_layout_test():
    # Reverse-direction LSTM with layout=1: only the full hidden-state
    # output Y is requested ([batch, seq_len, num_dir, hidden]).
    mk = helper.make_tensor_value_info
    seq = mk('seq', TensorProto.FLOAT, [3, 5, 10])
    w = mk('w', TensorProto.FLOAT, [1, 80, 10])
    r = mk('r', TensorProto.FLOAT, [1, 80, 20])
    bias = mk('bias', TensorProto.FLOAT, [1, 160])
    seq_len = mk('seq_len', TensorProto.INT32, [3])
    h0 = mk('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = mk('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = mk('pph', TensorProto.FLOAT, [1, 60])
    hs = mk('hs', TensorProto.FLOAT, [3, 5, 1, 20])
    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='reverse',
        hidden_size=20,
        input_forget=1,
        layout=1)
    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs])
@onnx_test()
def lstm_r_layout_hs_cell_test():
    # Reverse-direction LSTM with layout=1: requests the last hidden (Y_h)
    # and last cell (Y_c) outputs, both [batch, num_dir, hidden].
    mk = helper.make_tensor_value_info
    seq = mk('seq', TensorProto.FLOAT, [3, 5, 10])
    w = mk('w', TensorProto.FLOAT, [1, 80, 10])
    r = mk('r', TensorProto.FLOAT, [1, 80, 20])
    bias = mk('bias', TensorProto.FLOAT, [1, 160])
    seq_len = mk('seq_len', TensorProto.INT32, [3])
    h0 = mk('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = mk('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = mk('pph', TensorProto.FLOAT, [1, 60])
    output = mk('output', TensorProto.FLOAT, [3, 1, 20])
    cellout = mk('cellout', TensorProto.FLOAT, [3, 1, 20])
    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', 'output', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='reverse',
        hidden_size=20,
        input_forget=1,
        layout=1)
    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [output, cellout])
@onnx_test() @onnx_test()
def matmul_bmbm_test(): def matmul_bmbm_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7])
......
...@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward) ...@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
} }
} }
// Checks ONNX parsing of a forward LSTM with layout=1 (batch-major).
// The expected program transposes seq/h0/c0 to sequence-major before the
// "lstm" op and transposes the requested outputs back to batch-major.
// EXPECT(p == prog) compares instruction-by-instruction, so the insertion
// order below must match the parser's exactly.
TEST_CASE(lstm_forward_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 1;  // num directions
    float clip       = 0.0f;
    int input_forget = 1;

    // layout=1 shapes: seq is {batch, seq_len, input} and h0/c0 are
    // {batch, num_dirs, hidden}; w/r/bias/pph do not depend on layout.
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // 8 args, hs and last output
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
        // Inputs are transposed to sequence-major before the lstm op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // Rank-4 hidden-state output needs perm {2,0,1,3}; rank-3 last
        // output reuses perm {1,0,2}.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("lstm_f_layout_hs_test.onnx");
        EXPECT(p == prog);
    }

    // 8 args, cell output
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        // Only the cell output is requested; it is transposed back to
        // batch-major with the rank-3 permutation.
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);

        auto prog = optimize_onnx("lstm_f_layout_cell_test.onnx");
        EXPECT(p == prog);
    }
}
// activation functions // activation functions
TEST_CASE(lstm_forward_actv_func) TEST_CASE(lstm_forward_actv_func)
{ {
...@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse) ...@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
} }
} }
// Checks ONNX parsing of a reverse-direction LSTM with layout=1
// (batch-major).  Same structure as lstm_forward_layout: inputs are
// transposed to sequence-major before "lstm" and the requested outputs are
// transposed back; EXPECT(p == prog) requires this exact instruction order.
TEST_CASE(lstm_reverse_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 1;  // num directions
    float clip       = 0.0f;
    int input_forget = 1;

    // layout=1 shapes: seq is {batch, seq_len, input} and h0/c0 are
    // {batch, num_dirs, hidden}.
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // 8 args, hs output
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        // Full rank-4 hidden-state output transposed back to batch-major.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);

        auto prog = optimize_onnx("lstm_r_layout_test.onnx");
        EXPECT(p == prog);
    }

    // 8 args, last and cell output
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        // Both last-hidden and last-cell outputs requested; each rank-3
        // output is transposed back with perm {1,0,2}.
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        auto last_cell   = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        last_output = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}),
                                          last_output);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);

        auto prog = optimize_onnx("lstm_r_layout_hs_cell_test.onnx");
        EXPECT(p == prog);
    }
}
TEST_CASE(lstm_bidirectional) TEST_CASE(lstm_bidirectional)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
...@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional) ...@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
} }
} }
// Checks ONNX parsing of a bidirectional LSTM with layout=1 (batch-major).
// num_directions is 2, so the activation list holds six functions (three per
// direction).  EXPECT(p == prog) requires the exact instruction order below.
TEST_CASE(lstm_bidirectional_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 2;  // num directions
    float clip       = 0.0f;
    int input_forget = 1;

    // layout=1 shapes: seq is {batch, seq_len, input} and h0/c0 are
    // {batch, num_dirs, hidden}.
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // 8 args, hs and last output
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // Rank-4 hidden-state output uses perm {2,0,1,3}; the rank-3 last
        // output uses perm {1,0,2}.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx");
        EXPECT(p == prog);
    }

    // 8 args, cell output
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("sigmoid"),
                                                                     migraphx::make_op("tanh"),
                                                                     migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        // Only the cell output is requested; transpose it back to
        // batch-major.
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);

        auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx");
        EXPECT(p == prog);
    }
}
TEST_CASE(lstm_bi_actv_funcs) TEST_CASE(lstm_bi_actv_funcs)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
......
...@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test): ...@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails # fails
# from OnnxBackendNodeModelTest # from OnnxBackendNodeModelTest
backend_test.exclude(r'test_gru_batchwise_cpu') backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_lstm_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu') backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
# from OnnxBackendPyTorchConvertedModelTest # from OnnxBackendPyTorchConvertedModelTest
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu') backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
......
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Verifies a bidirectional 3-argument LSTM built from batch-major
// (layout=1) input: seq is transposed to sequence-major before the op and
// the full hidden-state output is transposed back to batch-major.
struct test_lstm_bidirct_3args_layout : verify_program<test_lstm_bidirct_3args_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t bs   = 2; // batch size
        const std::size_t sl   = 3; // sequence length
        const std::size_t hsz  = 5; // hidden size
        const std::size_t isz  = 8; // input size
        const std::size_t nd   = 2; // num directions
        const float clip       = 0.0f;

        migraphx::program p;
        auto* mm = p.get_main_module();
        // layout=1: seq is {batch, seq_len, input}.
        migraphx::shape in_shape{migraphx::shape::float_type, {bs, sl, isz}};
        migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hsz, isz}};
        migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hsz, hsz}};
        auto seq = mm->add_parameter("seq", in_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        const std::vector<int64_t> to_seq_major{1, 0, 2};
        seq = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_major}}), seq);
        // NOTE(review): only two activation functions are supplied for a
        // bidirectional op here, matching the original test — presumably
        // the parser/op fills in the rest; confirm against the lstm op.
        auto full_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hsz},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"), migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            seq,
            w,
            r);
        const std::vector<int64_t> to_batch_major{2, 0, 1, 3};
        mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_batch_major}}), full_hs);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Verifies a bidirectional 8-argument LSTM built from batch-major
// (layout=1) inputs: seq/ih/ic are transposed to sequence-major before the
// op and the last hidden-state output is transposed back to batch-major.
// The sequence-length slot is filled with an undefined placeholder.
struct test_lstm_bidirct_last_layout : verify_program<test_lstm_bidirct_last_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t bs  = 2; // batch size
        const std::size_t sl  = 3; // sequence length
        const std::size_t hsz = 5; // hidden size
        const std::size_t isz = 8; // input size
        const std::size_t nd  = 2; // num directions
        const float clip      = 0.0f;

        migraphx::program p;
        auto* mm = p.get_main_module();
        // layout=1: seq is {batch, seq_len, input}; ih/ic are
        // {batch, num_dirs, hidden}.
        migraphx::shape in_shape{migraphx::shape::float_type, {bs, sl, isz}};
        migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hsz, isz}};
        migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hsz, hsz}};
        migraphx::shape b_shape{migraphx::shape::float_type, {nd, 8 * hsz}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hsz}};
        migraphx::shape ic_shape{migraphx::shape::float_type, {bs, nd, hsz}};
        migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hsz}};

        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        auto ic   = mm->add_parameter("ic", ic_shape);
        auto pph  = mm->add_parameter("pph", pph_shape);
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));

        const std::vector<int64_t> to_seq_major{1, 0, 2};
        auto relayout = [&](auto ins) {
            return mm->add_instruction(
                migraphx::make_op("transpose", {{"permutation", to_seq_major}}), ins);
        };
        seq = relayout(seq);
        ih  = relayout(ih);
        ic  = relayout(ic);
        auto full_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hsz},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), full_hs);
        // Rank-3 last output goes back to batch-major with the same perm.
        mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_major}}), last_hs);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Verifies a forward 8-argument LSTM built from batch-major (layout=1)
// inputs: seq/ih/ic are transposed to sequence-major before the op and the
// full rank-4 hidden-state output is transposed back to batch-major.  The
// sequence-length slot is filled with an undefined placeholder.
struct test_lstm_forward_hs_layout : verify_program<test_lstm_forward_hs_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t bs  = 2; // batch size
        const std::size_t sl  = 3; // sequence length
        const std::size_t hsz = 5; // hidden size
        const std::size_t isz = 8; // input size
        const std::size_t nd  = 1; // num directions
        const float clip      = 0.0f;

        migraphx::program p;
        auto* mm = p.get_main_module();
        // layout=1: seq is {batch, seq_len, input}; ih/ic are
        // {batch, num_dirs, hidden}.
        migraphx::shape in_shape{migraphx::shape::float_type, {bs, sl, isz}};
        migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hsz, isz}};
        migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hsz, hsz}};
        migraphx::shape b_shape{migraphx::shape::float_type, {nd, 8 * hsz}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hsz}};
        migraphx::shape ic_shape{migraphx::shape::float_type, {bs, nd, hsz}};
        migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hsz}};

        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        auto ic   = mm->add_parameter("ic", ic_shape);
        auto pph  = mm->add_parameter("pph", pph_shape);
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));

        const std::vector<int64_t> to_seq_major{1, 0, 2};
        auto relayout = [&](auto ins) {
            return mm->add_instruction(
                migraphx::make_op("transpose", {{"permutation", to_seq_major}}), ins);
        };
        seq = relayout(seq);
        ih  = relayout(ih);
        ic  = relayout(ic);
        auto full_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hsz},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        // Rank-4 hidden-state output goes back to batch-major.
        const std::vector<int64_t> to_batch_major{2, 0, 1, 3};
        mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_batch_major}}), full_hs);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_forward_last_layout : verify_program<test_lstm_forward_last_layout>
{
    // Forward LSTM with batch-first (ONNX layout=1) tensors and a per-batch
    // sequence-length literal; checks the last hidden-state output after it is
    // transposed back to batch-first.
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1;
        const float clip          = 0.0f;
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape in_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};
        migraphx::shape b_shape{migraphx::shape::float_type, {ndirct, 8 * nhidden}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape l_shape{migraphx::shape::int32_type, {nbatch}};
        migraphx::shape ic_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape pph_shape{migraphx::shape::float_type, {ndirct, 3 * nhidden}};
        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        auto len  = mm->add_literal(migraphx::literal(l_shape, {1, 2}));
        auto ic   = mm->add_parameter("ic", ic_shape);
        auto pph  = mm->add_parameter("pph", pph_shape);
        // Convert each batch-first input to the seq-first layout of the lstm op.
        std::vector<int64_t> perm{1, 0, 2};
        auto to_seq_first = [&](auto ins) {
            return mm->add_instruction(
                migraphx::make_op("transpose", {{"permutation", perm}}), ins);
        };
        seq = to_seq_first(seq);
        ih  = to_seq_first(ih);
        ic  = to_seq_first(ic);
        std::vector<migraphx::operation> actv_funcs{
            migraphx::make_op("sigmoid"), migraphx::make_op("tanh"), migraphx::make_op("tanh")};
        auto output = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func", migraphx::to_value(actv_funcs)},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            len,
            ih,
            ic,
            pph);
        auto last_output =
            mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len);
        // The same 3-d permutation restores the batch-first layout of the result.
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_reverse_3args_cell_layout : verify_program<test_lstm_reverse_3args_cell_layout>
{
    // Reverse-direction LSTM given only the seq/w/r arguments, with the input
    // in batch-first (ONNX layout=1) form; checks the last cell-state output
    // transposed back to batch-first.
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1;
        const float clip          = 0.0f;
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape in_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};
        auto seq = mm->add_parameter("seq", in_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        // Batch-first input -> seq-first for the lstm op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        std::vector<migraphx::operation> actv_funcs{
            migraphx::make_op("sigmoid"), migraphx::make_op("tanh"), migraphx::make_op("tanh")};
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func", migraphx::to_value(actv_funcs)},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r);
        auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        // Restore batch-first layout on the cell-state result.
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_reverse_3args_layout : verify_program<test_lstm_reverse_3args_layout>
{
    // Reverse-direction LSTM given only the seq/w/r arguments, with the input
    // in batch-first (ONNX layout=1) form; checks the full hidden-state
    // sequence transposed back to batch-first.
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1;
        const float clip          = 0.0f;
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape in_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};
        auto seq = mm->add_parameter("seq", in_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        // Batch-first input -> seq-first for the lstm op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        std::vector<migraphx::operation> actv_funcs{
            migraphx::make_op("sigmoid"), migraphx::make_op("tanh"), migraphx::make_op("tanh")};
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func", migraphx::to_value(actv_funcs)},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip}}),
            seq,
            w,
            r);
        // 4-d permutation brings the hidden-state sequence back to batch-first.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_three_outputs_layout : verify_program<test_lstm_three_outputs_layout>
{
    // Forward LSTM with batch-first (ONNX layout=1) input returning all three
    // outputs: the full hidden-state sequence plus the last hidden and cell
    // states, each transposed back to batch-first before being returned.
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1;
        const float clip          = 0.0f;
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape in_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_shape{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};
        auto seq = mm->add_parameter("seq", in_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        // Batch-first input -> seq-first for the lstm op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        std::vector<migraphx::operation> actv_funcs{
            migraphx::make_op("sigmoid"), migraphx::make_op("tanh"), migraphx::make_op("tanh")};
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func", migraphx::to_value(actv_funcs)},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            seq,
            w,
            r);
        auto last_hs   = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        // 4-d permutation for the sequence output, 3-d for the last-state outputs.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        last_hs =
            mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_hs);
        last_cell =
            mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
        mm->add_return({hs, last_hs, last_cell});
        return p;
    }
    std::string section() const { return "rnn"; }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment