Commit 61775eab authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ref_fp8

parents a5c38ebe e7e5ba23
......@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
// Matches a @literal instruction whose elements are all equal (e.g. a
// splatted per-tensor scale). Non-literals never match.
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
    if(ins->name() != "@literal")
        return false;
    bool all_same = false;
    ins->get_literal().visit([&](auto s) {
        // Guard: an empty literal has no front(); previously s.begin() + 1
        // would have been undefined behavior. Treat empty as "no match".
        if(s.begin() == s.end())
            return;
        all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
            return float_equal(scale, s.front());
        });
    });
    return all_same;
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
// Skips over value-preserving layout instructions (broadcast, multibroadcast,
// contiguous, transpose) before applying the given matchers ms...
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
{
    return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
......
......@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
// Transposes the LSTM inputs with permutation {1, 0, 2}; applied by the parser
// when the ONNX attribute layout == 1 so the inputs match what the migraphx
// lstm op expects. args[0] is the sequence; args[5]/args[6] (initial hidden
// and cell state) are transposed only when present and defined.
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
    const std::vector<int64_t> batch_seq_perm{1, 0, 2};
    auto transpose_arg = [&](instruction_ref arg) {
        return info.add_instruction(make_op("transpose", {{"permutation", batch_seq_perm}}), arg);
    };
    args[0] = transpose_arg(args[0]);
    if(args.size() >= 6 and not args[5]->is_undefined())
        args[5] = transpose_arg(args[5]);
    if(args.size() >= 7 and not args[6]->is_undefined())
        args[6] = transpose_arg(args[6]);
}
// Transposes the three LSTM outputs back for layout == 1: the 4D hidden-state
// sequence uses permutation {2, 0, 1, 3}; the two 3D "last" outputs use
// {1, 0, 2}. All three references are updated in place.
void lstm_transpose_outputs(onnx_parser::node_info& info,
                            instruction_ref& hidden_states,
                            instruction_ref& last_output,
                            instruction_ref& last_cell_output)
{
    const std::vector<int64_t> hs_perm{2, 0, 1, 3};
    const std::vector<int64_t> last_perm{1, 0, 2};
    hidden_states =
        info.add_instruction(make_op("transpose", {{"permutation", hs_perm}}), hidden_states);
    for(auto* out : {&last_output, &last_cell_output})
    {
        *out = info.add_instruction(make_op("transpose", {{"permutation", last_perm}}), *out);
    }
}
struct parse_lstm : op_parser<parse_lstm>
{
std::vector<op_desc> operators() const { return {{"LSTM"}}; }
......@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
}
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined operators to pad the argument list to 8
if(args.size() < 8)
{
......@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins);
}
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
......@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output};
}
};
......
......@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s;
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
struct match_find_quantizable_ops
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
static bool
is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
}
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
});
});
return all_same;
}
return all_zeros;
}
struct match_find_quantizable_ops
{
static auto
scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
while(next_ins != qop)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
match::arg(0)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
return;
// Only support scalar and 1D scales
if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref out_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
m.replace_instruction(qop, dq);
}
};
......
......@@ -4484,6 +4484,177 @@ def lrn_test():
return ([node], [x], [y])
@onnx_test()
def lstm_bi_layout_cell_test():
    # Bidirectional LSTM (hidden_size=20) with layout=1 (batch-first shapes);
    # only the third output (last cell state, 'cellout') is requested.
    seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10])
    r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160])
    seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
    h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20])
    c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20])
    pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60])
    cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
                                            [3, 2, 20])

    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', '', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='bidirectional',
        hidden_size=20,
        input_forget=1,
        layout=1)

    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_bi_layout_last_test():
    # Bidirectional LSTM with layout=1; requests the full hidden-state
    # sequence ('hs') and the last hidden state ('output').
    seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10])
    r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160])
    seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
    h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20])
    c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20])
    pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60])
    hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 2, 20])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                           [3, 2, 20])

    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs', 'output'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='bidirectional',
        hidden_size=20,
        input_forget=1,
        layout=1)

    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_hs_test():
    # Forward LSTM with layout=1; requests the hidden-state sequence ('hs')
    # and the last hidden state ('output').
    seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
    r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
    seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
    h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
    hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                           [3, 1, 20])

    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs', 'output'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='forward',
        hidden_size=20,
        input_forget=1,
        layout=1)

    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_cell_test():
    # Forward LSTM with layout=1; only the last cell state ('cellout') is
    # requested.
    seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
    r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
    seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
    h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
    cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
                                            [3, 1, 20])

    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', '', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='forward',
        hidden_size=20,
        input_forget=1,
        layout=1)

    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_r_layout_test():
    # Reverse LSTM with layout=1; only the hidden-state sequence ('hs') is
    # requested.
    seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
    r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
    seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
    h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
    hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20])

    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='reverse',
        hidden_size=20,
        input_forget=1,
        layout=1)

    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs])
@onnx_test()
def lstm_r_layout_hs_cell_test():
    # Reverse LSTM with layout=1; requests the last hidden state ('output')
    # and the last cell state ('cellout'), skipping the sequence output.
    seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
    r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
    seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
    h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
    c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
    pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                           [3, 1, 20])
    cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
                                            [3, 1, 20])

    node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', 'output', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='reverse',
        hidden_size=20,
        input_forget=1,
        layout=1)

    return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [output, cellout])
@onnx_test()
def matmul_bmbm_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7])
......
......@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
}
}
// Verifies the parser output for a forward LSTM with ONNX attribute layout=1:
// the expected program transposes seq/h0/c0 before the lstm op and transposes
// the requested outputs afterwards. Compared by exact program equality, so the
// instruction sequence here must match the parser precisely.
TEST_CASE(lstm_forward_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 1;  // num directions
    float clip = 0.0f;
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // 8 args, hs and last output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w = mm->add_parameter("w", w_shape);
        auto r = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih = mm->add_parameter("h0", ih_shape);
        auto ic = mm->add_parameter("c0", ih_shape);
        auto pph = mm->add_parameter("pph", pph_shape);
        // layout=1 inputs are transposed to the op's expected ordering
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // outputs are transposed back to the layout=1 ordering
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("lstm_f_layout_hs_test.onnx");
        EXPECT(p == prog);
    }

    // 8 args, cell output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w = mm->add_parameter("w", w_shape);
        auto r = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih = mm->add_parameter("h0", ih_shape);
        auto ic = mm->add_parameter("c0", ih_shape);
        auto pph = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);

        auto prog = optimize_onnx("lstm_f_layout_cell_test.onnx");
        EXPECT(p == prog);
    }
}
// activation functions
TEST_CASE(lstm_forward_actv_func)
{
......@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
}
}
// Verifies parser output for a reverse-direction LSTM with layout=1. Same
// transpose pattern as the forward case; compared by exact program equality.
TEST_CASE(lstm_reverse_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 1;  // num directions
    float clip = 0.0f;
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // 8 args, hs output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w = mm->add_parameter("w", w_shape);
        auto r = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih = mm->add_parameter("h0", ih_shape);
        auto ic = mm->add_parameter("c0", ih_shape);
        auto pph = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        // only the hidden-state sequence is requested, so only it is transposed back
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);

        auto prog = optimize_onnx("lstm_r_layout_test.onnx");
        EXPECT(p == prog);
    }

    // 8 args, last and cell output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w = mm->add_parameter("w", w_shape);
        auto r = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih = mm->add_parameter("h0", ih_shape);
        auto ic = mm->add_parameter("c0", ih_shape);
        auto pph = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        last_output = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}),
                                          last_output);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);

        auto prog = optimize_onnx("lstm_r_layout_hs_cell_test.onnx");
        EXPECT(p == prog);
    }
}
TEST_CASE(lstm_bidirectional)
{
std::size_t sl = 5; // sequence len
......@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
}
}
// Verifies parser output for a bidirectional LSTM with layout=1 (nd = 2, six
// activation functions). Compared by exact program equality.
TEST_CASE(lstm_bidirectional_layout)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 2;  // num directions
    float clip = 0.0f;
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // 8 args, hs and last output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w = mm->add_parameter("w", w_shape);
        auto r = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih = mm->add_parameter("h0", ih_shape);
        auto ic = mm->add_parameter("c0", ih_shape);
        auto pph = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx");
        EXPECT(p == prog);
    }

    // 8 args, cell output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w = mm->add_parameter("w", w_shape);
        auto r = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih = mm->add_parameter("h0", ih_shape);
        auto ic = mm->add_parameter("c0", ih_shape);
        auto pph = mm->add_parameter("pph", pph_shape);
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto out_hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);

        auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx");
        EXPECT(p == prog);
    }
}
TEST_CASE(lstm_bi_actv_funcs)
{
std::size_t sl = 5; // sequence len
......
......@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# from OnnxBackendNodeModelTest
backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_lstm_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
# from OnnxBackendPyTorchConvertedModelTest
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
......
......@@ -636,13 +636,12 @@ TEST_CASE(dot_float)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto scale_a = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
auto scale_a_lit = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a_lit);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b_lit = mm->add_literal(5.0);
auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b_lit);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale_a_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_a_lit);
auto scale_b_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_b_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg)
migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -851,10 +851,10 @@ TEST_CASE(conv_float)
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(10.0f);
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
......@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}};
std::vector<float> vec(sc.elements(), 100.0f);
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()};
auto d_scale = mm->add_literal(100.0f);
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -930,20 +928,21 @@ TEST_CASE(conv_half)
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0}));
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto s1_mb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1);
auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
......@@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw);
// else submod
auto* else_mod = p.create_module("If_6_else");
auto sax = else_mod->add_literal(2.0f);
auto sax_lit = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax);
auto sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax_lit);
auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw_lit = else_mod->add_literal(1.66667f);
auto ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw_lit);
auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f);
so1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto ssw_mb = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
ssw_lit);
auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
......
......@@ -3228,6 +3228,264 @@ TEST_CASE(lstm_forward)
}
}
// Ref-target checks for a forward, single-direction LSTM driven with
// batch-first tensors (ONNX "layout=1" style): the [batch, seq, input] /
// [batch, num_dir, hidden] literals are transposed into the sequence-first
// layout the migraphx "lstm" op expects, and each output is transposed back
// before comparison against precomputed gold values.
TEST_CASE(lstm_forward_layout)
{
    std::size_t batch_size = 3;
    std::size_t seq_len = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size = 3;
    std::size_t num_dirct = 1;
    // Fixed weight/bias/state fixtures shared by all three sub-tests below.
    std::vector<float> w_data{
        0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
        0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
        -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
        -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
        -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
    std::vector<float> r_data{
        0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
        -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
        -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
        -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
        0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
        0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
        -0.2169, -0.1344, 0.3468, -0.2260};
    std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
                                 -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
                                 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
                                 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
                                 -0.3025, 0.3637, -0.3181, -0.4655};
    std::vector<float> input_data{
        -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
        0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
        1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
        -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
    std::vector<float> ih_data{1.9104,
                               -1.9004,
                               0.3337,
                               0.5741,
                               0.5671,
                               0.0458,
                               0.4514,
                               -0.8968,
                               -0.9201,
                               0.1962,
                               0.5771,
                               -0.5332};
    std::vector<float> ic_data{0.9569,
                               -0.5981,
                               1.1312,
                               1.0945,
                               1.1055,
                               -0.1212,
                               -0.9097,
                               0.7831,
                               -1.6991,
                               -1.9498,
                               -1.2567,
                               -0.4114};
    std::vector<float> pph_data{1.84369764,
                                0.68413646,
                                -0.44892886,
                                -1.50904413,
                                0.3860796,
                                -0.52186625,
                                1.08474445,
                                -1.80867321,
                                1.32594529,
                                0.4336262,
                                -0.83699064,
                                0.49162736};
    float clip = 0.0f;
    // Batch-first shapes: seq is {batch, seq, input}; initial states are
    // {batch, num_dir, hidden} (sequence-first variants would lead with seq/num_dir).
    migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    // forward, hidden state concatenation as output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        // Placeholder for the optional lstm inputs that are not provided.
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
        // Batch-first -> sequence-first layout expected by the "lstm" op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            und);
        // Transpose the full hidden-state output back to a batch-first layout
        // before comparing against the gold values.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        std::vector<float> hs_data_gold{
            0.0417273, -0.272355, 0.206765, 0.223879, 0.0742487, -0.0800085, 0.259897,
            0.0670196, -0.00532985, 0.0440265, 0.29654, -0.0463156, -0.0847427, 0.0874114,
            0.304256, -0.0585745, 0.138193, -0.0322939, -0.0891815, 0.15773, 0.184266,
            0.0610048, -0.138041, 0.0963885, 0.0498799, 0.125772, 0.0533032, -0.131413,
            -0.0223018, 0.131113, 0.135643, -0.056620, 0.19139, -0.127708, -0.409371,
            -0.136186, 0.0213755, -0.146027, -0.0324509, -0.0620429, 0.0988431, -0.018085,
            -0.159434, 0.030266, 0.142701, 0.0342236, -0.198664, 0.0702607};
        EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
    }
    // forward, last_output as program output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
        // Batch-first -> sequence-first layout expected by the "lstm" op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            und);
        // Restore batch-first layout on the final hidden state.
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        p.compile(migraphx::make_target("ref"));
        auto last_hs = p.eval({}).back();
        std::vector<float> output_data;
        last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{-0.0847427,
                                            0.0874114,
                                            0.304256,
                                            -0.0585745,
                                            -0.0223018,
                                            0.131113,
                                            0.135643,
                                            -0.0566208,
                                            0.142701,
                                            0.0342236,
                                            -0.198664,
                                            0.0702607};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
    // forward, last_cell_output as program output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
        // Batch-first -> sequence-first layout expected by the "lstm" op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            und);
        // Restore batch-first layout on the final cell state.
        auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
        p.compile(migraphx::make_target("ref"));
        auto last_hs = p.eval({}).back();
        std::vector<float> output_data;
        last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{-0.111454,
                                            0.247794,
                                            0.471087,
                                            -0.220574,
                                            -0.048196,
                                            0.263184,
                                            0.283258,
                                            -0.14882,
                                            0.605585,
                                            0.078598,
                                            -0.64457,
                                            0.119811};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
}
TEST_CASE(lstm_forward_more)
{
std::size_t batch_size = 3;
......@@ -3519,7 +3777,7 @@ TEST_CASE(lstm_forward_more)
}
}
// Additional ref-target checks for a forward LSTM with batch-first
// ("layout=1" style) tensors: a minimal 3-argument call, an 8-argument call
// including peephole weights, a run with explicit per-batch sequence lengths
// shorter than max_seq_len, and a single-step (seq_len == 1) run.
//
// NOTE: this block previously contained merge/diff residue — a duplicate
// `TEST_CASE(lstm_reverse)` name, a leftover hunk marker that dropped the
// `hidden_size` declaration, and stale pre-merge w/r/bias initializer lines
// interleaved with the post-merge data (doubled `};` terminators). The
// reconstructed fixtures below match those used by lstm_forward_layout.
TEST_CASE(lstm_forward_more_layout)
{
    std::size_t batch_size = 3;
    std::size_t seq_len = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size = 3;
    std::size_t num_dirct = 1;
    std::vector<float> w_data{
        0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
        0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
        -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
        -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
        -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
    std::vector<float> r_data{
        0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
        -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
        -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
        -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
        0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
        0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
        -0.2169, -0.1344, 0.3468, -0.2260};
    std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
                                 -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
                                 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
                                 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
                                 -0.3025, 0.3637, -0.3181, -0.4655};
    std::vector<float> input_data{
        -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
        0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
        1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
        -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
    std::vector<float> ih_data{1.9104,
                               -1.9004,
                               0.3337,
                               0.5741,
                               0.5671,
                               0.0458,
                               0.4514,
                               -0.8968,
                               -0.9201,
                               0.1962,
                               0.5771,
                               -0.5332};
    std::vector<float> ic_data{0.9569,
                               -0.5981,
                               1.1312,
                               1.0945,
                               1.1055,
                               -0.1212,
                               -0.9097,
                               0.7831,
                               -1.6991,
                               -1.9498,
                               -1.2567,
                               -0.4114};
    std::vector<float> pph_data{1.84369764,
                                0.68413646,
                                -0.44892886,
                                -1.50904413,
                                0.3860796,
                                -0.52186625,
                                1.08474445,
                                -1.80867321,
                                1.32594529,
                                0.4336262,
                                -0.83699064,
                                0.49162736};
    float clip = 0.0f;
    // Batch-first shapes (seq is {batch, seq, input}; states {batch, num_dir, hidden}).
    migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    // forward, 3 args
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        // Batch-first -> sequence-first layout expected by the "lstm" op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r);
        // Transpose the full hidden-state output back to a batch-first layout.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        p.compile(migraphx::make_target("ref"));
        auto last_hs = p.eval({}).back();
        std::vector<float> output_data;
        last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            -0.0327039, -0.0543852, 0.114378, -0.0768855, -0.0786602, -0.0613048, 0.179592,
            -0.071286, -0.102509, -0.0372696, 0.252296, -0.144544, -0.165194, -0.0372928,
            0.273786, -0.100877, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.074206,
            0.0124086, -0.139544, 0.108016, 0.00496085, 0.0662588, -0.048577, -0.187329,
            -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.101585, 0.0687269, -0.161725,
            -0.25617, -0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.0855831, -0.0171894,
            -0.140202, 0.0828391, 0.136898, 0.00160891, -0.184812, 0.147774};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
    // forward, 8 args
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        // Placeholder for the sequence-length input, which is not provided here.
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
        // Batch-first -> sequence-first layout expected by the "lstm" op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        // Transpose the full hidden-state output back to a batch-first layout.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        std::vector<float> hs_data_gold{
            0.079753, -0.289854, 0.160043, 0.115056, 0.186991, -0.0624168, 0.205513,
            0.0836373, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.058052, 0.0795391,
            0.266617, -0.0128746, 0.294074, -0.0319677, -0.0955337, 0.104168, 0.421857,
            0.0459771, -0.144955, 0.0720673, 0.218258, 0.0944405, 0.0431211, -0.132394,
            0.0309878, 0.0971544, 0.149294, -0.0492549, 0.022618, -0.121195, -0.4065,
            -0.252054, -0.0300906, -0.0890598, -0.135266, -0.0413375, 0.103489, 0.0142918,
            -0.123408, 0.0401075, 0.187761, 0.0501726, -0.121584, 0.0606723};
        EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
    }
    // forward, last_output as program output, sequence length shorter
    // than max_seq_len
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        // Pad the sequence with two extra (all-zero) steps along the batch-first
        // sequence axis so max_seq_len exceeds the actual sequence lengths.
        migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
        std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
        auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
        auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
        // Per-batch sequence lengths (all equal to the original seq_len).
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
        auto sql = mm->add_literal(seq_len_s, len_data);
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
        // Batch-first -> sequence-first layout expected by the "lstm" op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            und);
        // Restore batch-first layout on the final hidden state.
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        p.compile(migraphx::make_target("ref"));
        auto last_hs = p.eval({}).back();
        std::vector<float> output_data;
        last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        // Padding beyond the given sequence lengths must not change the result:
        // gold values match the unpadded last_output run in lstm_forward_layout.
        std::vector<float> output_data_gold{-0.0847427,
                                            0.0874114,
                                            0.304256,
                                            -0.0585745,
                                            -0.0223018,
                                            0.131113,
                                            0.135643,
                                            -0.0566208,
                                            0.142701,
                                            0.0342236,
                                            -0.198664,
                                            0.0702607};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
    // seq_len = 1
    {
        seq_len = 1;
        migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
        std::vector<float> input_data1{
            -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
        auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = mm->add_instruction(migraphx::make_op("undefined"));
        // Batch-first -> sequence-first layout expected by the "lstm" op.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        // Restore batch-first layout on the final hidden state.
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        std::vector<float> hs_data_gold{0.079753,
                                        -0.289854,
                                        0.160043,
                                        0.115056,
                                        0.294074,
                                        -0.0319677,
                                        -0.0955337,
                                        0.104168,
                                        0.022618,
                                        -0.121195,
                                        -0.4065,
                                        -0.252054};
        EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
    }
}
TEST_CASE(lstm_reverse)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
-0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
-0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.5289,
1.0986,
0.6091,
1.6462,
0.8720,
0.5349,
-0.1962,
-1.7416,
-0.9912,
1.2831,
1.0896,
-0.6959};
std::vector<float> ic_data{-0.8323,
0.3998,
0.1831,
0.5938,
2.7096,
-0.1790,
0.0022,
-0.8040,
0.1578,
0.0567,
0.8069,
-0.5141};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// reverse, concatenation of hidden states as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, sequence lengths are the same, but less than max_seq_lens
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// variable sequence lengths
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{3, 2, 1};
auto sql = mm->add_literal(seq_len_s, len_data);
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449,
0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761,
0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0,
0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_reverse_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
-0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
-0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.5289,
1.0986,
......@@ -3593,14 +4487,15 @@ TEST_CASE(lstm_reverse)
-0.5811,
0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// reverse, concatenation of hidden states as program output
{
migraphx::program p;
......@@ -3614,7 +4509,13 @@ TEST_CASE(lstm_reverse)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
......@@ -3633,18 +4534,21 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
-0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681,
-0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845,
0.177273, -0.0774616, 0.789732, 0.128538, 0.20909, 0.0553812, 0.928866,
0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741, 0.130438,
0.946669, 0.0868676, 0.044508, -0.373961, -0.224905, 0.32421, 0.344048,
0.271694, -0.063456, 0.148524, 0.05108, -0.0234895, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
......@@ -3661,14 +4565,20 @@ TEST_CASE(lstm_reverse)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p);
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
mm->add_instruction(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
......@@ -3687,22 +4597,26 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681,
-0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845,
0.177273, -0.0774616, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.789732, 0.128538, 0.20909, 0.0553812,
0.928866, 0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741,
0.130438, 0.946669, 0.0868676, 0.044508, -0.373961, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.224905,
0.32421, 0.344048, 0.271694, -0.063456, 0.148524, 0.05108, -0.0234895,
-0.0252804, 0.267356, 0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211,
-0.161537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
......@@ -3722,7 +4636,13 @@ TEST_CASE(lstm_reverse)
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{3, 2, 1};
auto sql = mm->add_literal(seq_len_s, len_data);
mm->add_instruction(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
......@@ -3741,18 +4661,22 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449,
0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761,
0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0,
0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
-0.126517, 0.0359124, 0.107453, -0.0617278, -0.168327, 0.00023761, 0.167567,
-0.0621982, -0.204545, 0.0146403, 0.210057, 0.0296268, 0, 0,
0, 0, 0.911307, 0.11468, 0.114449, 0.0196755, 0.96657,
0.0755112, 0.0620917, -0.264845, 0, 0, 0, 0,
0, 0, 0, 0, -0.102969, 0.295872, 0.515859,
0.246501, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
......@@ -3763,6 +4687,10 @@ TEST_CASE(lstm_reverse)
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
......@@ -3777,46 +4705,8 @@ TEST_CASE(lstm_reverse)
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
......@@ -3900,19 +4790,214 @@ TEST_CASE(lstm_reverse_actv)
0.8069,
-0.5141};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(
std::vector<migraphx::operation>{migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{
migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.132123,
-0.37531,
-0.12943,
-0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{
seq_len = 1;
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.104351,
-0.0471426,
-0.0905753,
0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458,
0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349,
-0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212,
-0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790,
0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
......@@ -3920,95 +5005,200 @@ TEST_CASE(lstm_reverse_actv)
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// concatenation of hidden states as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(
std::vector<migraphx::operation>{migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
// last hidden state as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{
migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.132123,
-0.37531,
-0.12943,
-0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
// sequence length is 1, contenation of hidden state as program output
{
seq_len = 1;
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto* mm = p.get_main_module();
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
......@@ -4017,7 +5207,7 @@ TEST_CASE(lstm_reverse_actv)
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
......@@ -4027,23 +5217,16 @@ TEST_CASE(lstm_reverse_actv)
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.104351,
-0.0471426,
-0.0905753,
0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698,
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional)
TEST_CASE(lstm_bidirectional_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
......@@ -4087,20 +5270,20 @@ TEST_CASE(lstm_bidirectional)
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458,
0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349,
-0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212,
-0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790,
0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 1.5289, 1.0986,
0.6091, 1.6462, 0.5671, 0.0458, 0.4514, -0.8968,
0.8720, 0.5349, -0.1962, -1.7416, -0.9201, 0.1962,
0.5771, -0.5332, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, -0.8323, 0.3998,
0.1831, 0.5938, 1.1055, -0.1212, -0.9097, 0.7831,
2.7096, -0.1790, 0.0022, -0.8040, -1.6991, -1.9498,
-1.2567, -0.4114, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
......@@ -4108,12 +5291,12 @@ TEST_CASE(lstm_bidirectional)
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// concatenation of hidden states as program output
......@@ -4128,7 +5311,13 @@ TEST_CASE(lstm_bidirectional)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
......@@ -4147,25 +5336,29 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138,
-0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549,
0.178681, -0.266999, 0.0459032, 0.0414126, 0.272303, 0.0393149, -0.182201,
-0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.789732, 0.128538, 0.20909, 0.0553812, 0.421857, 0.0459771,
-0.144955, 0.0720673, 0.928866, 0.113685, 0.220626, -0.0432316, 0.218258,
0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565, 0.269741, 0.130438,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669, 0.0868676, 0.044508,
-0.373961, 0.022618, -0.121195, -0.4065, -0.252054, -0.224905, 0.32421,
0.344048, 0.271694, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.063456,
0.148524, 0.05108, -0.0234895, 0.103489, 0.0142918, -0.123408, 0.0401075,
-0.0252804, 0.267356, 0.146353, 0.0789186, 0.187761, 0.0501726, -0.121584,
0.0606723, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
......@@ -4181,6 +5374,12 @@ TEST_CASE(lstm_bidirectional)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
......@@ -4200,15 +5399,17 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
-0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
......@@ -4224,6 +5425,12 @@ TEST_CASE(lstm_bidirectional)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
......@@ -4243,15 +5450,17 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
-0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334,
0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813,
0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
......@@ -4262,7 +5471,11 @@ TEST_CASE(lstm_bidirectional)
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
......@@ -4276,25 +5489,28 @@ TEST_CASE(lstm_bidirectional)
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
-0.0327039, -0.0543852, 0.114378, -0.0768855, -0.162851, -0.102647, -0.113827,
-0.142818, -0.0786602, -0.0613048, 0.179592, -0.071286, -0.123496, -0.153616,
-0.032874, -0.195349, -0.102509, -0.0372696, 0.252296, -0.144544, -0.1073,
-0.150145, 0.015065, -0.192699, -0.165194, -0.0372928, 0.273786, -0.100877,
-0.021205, -0.125423, 0.0206439, -0.187097, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.0513685, 0.0547876, 0.0201981, -0.00808453, 0.074206, 0.0124086,
-0.139544, 0.108016, 0.0192675, -0.108636, 0.098927, -0.140733, 0.00496085,
0.0662588, -0.048577, -0.187329, -0.112764, -0.120496, 0.155754, 0.148256,
-0.0458544, -0.0401315, 0.0737483, -0.064505, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.101585, 0.0687269, -0.161725, -0.25617, -0.00520328, 0.0945081,
0.264123, 0.410805, -0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.162602,
0.0143099, -0.0455534, 0.0151574, 0.0855831, -0.0171894, -0.140202, 0.0828391,
0.208491, 0.348432, 0.0291103, 0.230275, 0.136898, 0.00160891, -0.184812,
0.147774, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
......@@ -4303,13 +5519,17 @@ TEST_CASE(lstm_bidirectional)
migraphx::program p;
auto* mm = p.get_main_module();
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
......@@ -4323,15 +5543,19 @@ TEST_CASE(lstm_bidirectional)
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698,
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
-0.0327039, -0.0543852, 0.114378, -0.0768855, -0.104351, -0.0471426,
-0.0905753, 0.01506, 0.0319021, -0.00298698, -0.0623361, 0.0598866,
0.059797, 0.104239, -0.0266768, 0.0727547, 0.101585, 0.0687269,
-0.161725, -0.25617, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
......@@ -4577,6 +5801,275 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
}
}
TEST_CASE(lstm_bidirectional_var_seq_lens_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 1.5289, 1.0986,
0.6091, 1.6462, 0.5671, 0.0458, 0.4514, -0.8968,
0.8720, 0.5349, -0.1962, -1.7416, -0.9201, 0.1962,
0.5771, -0.5332, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, -0.8323, 0.3998,
0.1831, 0.5938, 1.1055, -0.1212, -0.9097, 0.7831,
2.7096, -0.1790, 0.0022, -0.8040, -1.6991, -1.9498,
-1.2567, -0.4114, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// concatenation of hidden states as program output
{
std::vector<int> sl_data{1, 2, 3};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco);
mm->add_return({out_hs, lho, lco});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.at(1);
auto arg_lco = outputs.at(2);
std::vector<float> output_data;
std::vector<float> last_output_data;
std::vector<float> last_cell_data;
arg_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
arg_lho.visit([&](auto output) { last_output_data.assign(output.begin(), output.end()); });
arg_lco.visit([&](auto output) { last_cell_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804,
0.0745128, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.911307, 0.11468, 0.114449, 0.0196755, 0.421857, 0.0459771,
-0.144955, 0.0720673, 0.96657, 0.0755112, 0.0620917, -0.264845, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0.022618, -0.121195, -0.4065, -0.252054, -0.262807, 0.275286,
0.358395, 0.266267, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.128254,
0.125398, 0.0665142, -0.163651, 0.103489, 0.0142918, -0.123408, 0.0401075,
-0.0644683, 0.371512, 0.212431, -0.116131, 0, 0, 0,
0, 0, 0, 0, 0};
std::vector<float> last_output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804, 0.0745128,
0.421857, 0.0459771, -0.144955, 0.0720673, 0.911307, 0.11468, 0.114449, 0.0196755,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.262807, 0.275286, 0.358395, 0.266267};
std::vector<float> last_cell_data_gold{
0.600582, -0.601197, 0.353558, 0.789097, -0.326822, 0.301121, 0.219523, 0.415242,
0.737121, 0.134902, -0.303595, 0.241948, 2.08242, 0.442513, 0.187127, 0.0577626,
0.391174, 0.0308845, -0.561745, 0.0730323, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_cell_data, last_cell_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco);
mm->add_return({hs, lho, lco});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto res_hs = outputs.at(0);
auto res_lho = outputs.at(1);
auto res_lco = outputs.at(2);
std::vector<float> hs_data;
std::vector<float> lho_data;
std::vector<float> lco_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
res_lco.visit([&](auto output) { lco_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138,
-0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549,
0.178681, -0.266999, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.182201,
-0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746,
-0.185038, -0.026845, 0.177273, -0.0774616, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0.294074,
-0.0319677, -0.0955337, 0.104168, 0.789732, 0.128538, 0.20909, 0.0553812,
0.421857, 0.0459771, -0.144955, 0.0720673, 0.928866, 0.113685, 0.220626,
-0.0432316, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565,
0.269741, 0.130438, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669,
0.0868676, 0.044508, -0.373961, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0.022618, -0.121195,
-0.4065, -0.252054, -0.224905, 0.32421, 0.344048, 0.271694, -0.0300906,
-0.0890598, -0.135266, -0.0413375, -0.063456, 0.148524, 0.05108, -0.0234895,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.0252804, 0.267356, 0.146353,
0.0789186, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.0681467, 0.382748,
0.230211, -0.161537, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0};
std::vector<float> lho_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694};
std::vector<float> lco_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334,
0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813,
0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lco_data, lco_data_gold));
}
}
TEST_CASE(lstm_bidirectional_actv_func)
{
std::size_t batch_size = 3;
......
......@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m)
sqdq.apply(m);
}
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref broadcast_scale(migraphx::module& m,
migraphx::instruction_ref scale,
migraphx::instruction_ref shift)
const std::vector<std::size_t>& out_lens,
std::size_t axis)
{
auto lens = x->get_shape().lens();
if(scale->get_shape().lens() == out_lens)
return scale;
migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1)
auto scale_lens = scale->get_shape().lens();
if(scale_lens.front() == 1 and scale_lens.size() == 1)
scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), scale);
else
scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}}), scale);
return scale_mb;
}
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale,
migraphx::instruction_ref shift,
std::size_t q_axis = 1)
{
auto lens = x->get_shape().lens();
auto scale_mb = broadcast_scale(m, scale, lens, q_axis);
auto shift_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift);
return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb);
......@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale)
migraphx::instruction_ref scale,
std::size_t q_axis = 1)
{
auto lens = x->get_shape().lens();
migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1)
scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else
scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
auto lens = x->get_shape().lens();
auto scale_mb = broadcast_scale(m, scale, lens, q_axis);
return m.add_instruction(migraphx::make_op(name), x, scale_mb);
}
migraphx::instruction_ref add_scale_mul(migraphx::module& m,
migraphx::instruction_ref scale1,
migraphx::instruction_ref scale2,
std::size_t axis1,
std::size_t axis2,
const std::vector<std::size_t>& out_lens)
{
auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1);
auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2);
return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb);
}
TEST_CASE(remove_qdq)
{
migraphx::shape sh1{migraphx::shape::float_type, {100, 100}};
......@@ -159,18 +180,62 @@ TEST_CASE(dot)
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::shape sh3{migraphx::shape::float_type, {1280}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m1.add_literal(0.4f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 0);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 0);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m2.add_literal(0.4f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 0);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 1);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale = add_scale_mul(m2, scale1, scale2, 0, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
......@@ -178,6 +243,180 @@ TEST_CASE(dot)
EXPECT(m1 == m2);
}
// Verifies that the simplify-qdq pass handles a dot whose second operand is
// multibroadcast between its dequantize and the dot: the broadcast is moved
// onto the quantized int8 value, the float dot becomes a quant_dot, and a
// single dequantizelinear with the fused (scale * scale) output scale follows.
TEST_CASE(dot_broadcasted)
{
    migraphx::shape sh1{migraphx::shape::float_type, {2, 1280, 1000}};
    migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};

    // Input module: q/dq pairs around both operands, with the second operand
    // broadcast to the batched shape {2, 1000, 1024} before the float dot.
    migraphx::module m1;
    {
        auto t1    = m1.add_parameter("t1", sh1);
        auto t2    = m1.add_parameter("t2", sh2);
        auto scale = m1.add_literal(0.5f);
        auto zero  = m1.add_literal(std::int8_t{0});
        auto q1    = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
        auto d1    = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
        auto q2    = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
        auto d2    = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
        auto d2_mb = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), d2);
        auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
        m1.add_return({dot});
    }

    // Expected module: quant_dot on the int8 values (broadcast applied to the
    // quantized operand) followed by one dequantizelinear whose scale is the
    // product of the two input scales, broadcast to the dot's output shape.
    migraphx::module m2;
    {
        auto t1    = m2.add_parameter("t1", sh1);
        auto t2    = m2.add_parameter("t2", sh2);
        auto scale = m2.add_literal(0.5f);
        auto zero  = m2.add_literal(std::int8_t{0});
        auto q1    = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
        auto q2    = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
        auto q2_mb = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), q2);
        auto dot       = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb);
        auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
        auto d3        = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
        m2.add_return({d3});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}
// Verifies that the simplify-qdq pass handles a transposed second operand:
// the transpose that sat between the dequantize and the dot is applied to the
// quantized int8 value instead, and the float dot is rewritten to quant_dot
// followed by a single dequantizelinear with the fused output scale.
TEST_CASE(dot_transposed)
{
    migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
    migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};

    // Input module: both operands wrapped in q/dq pairs; the second operand is
    // transposed to {1000, 1024} before feeding the float dot.
    migraphx::module m1;
    {
        auto t1    = m1.add_parameter("t1", sh1);
        auto t2    = m1.add_parameter("t2", sh2);
        auto scale = m1.add_literal(0.5f);
        auto zero  = m1.add_literal(std::int8_t{0});
        auto q1    = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
        auto d1    = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
        auto q2    = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
        auto d2    = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
        auto d2_t =
            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
        auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_t);
        m1.add_return({dot});
    }

    // Expected module: the transpose now operates on the quantized value, the
    // dot is a quant_dot, and one dequantizelinear applies scale * scale.
    migraphx::module m2;
    {
        auto t1    = m2.add_parameter("t1", sh1);
        auto t2    = m2.add_parameter("t2", sh2);
        auto scale = m2.add_literal(0.5f);
        auto zero  = m2.add_literal(std::int8_t{0});
        auto q1    = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
        auto q2    = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
        auto q2_t =
            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
        auto dot       = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t);
        auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
        auto d3        = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
        m2.add_return({d3});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}
// Combines per-axis (multi-value) scales with a transpose plus multibroadcast
// on the second operand. The pass must keep the per-operand quantize axes
// (axis 2 of t1, axis 0 of t2), move the transpose/broadcast onto the int8
// value, and fuse the two scale vectors into one broadcast output scale on
// axes 2 and 3 of the quant_dot result.
TEST_CASE(dot_multi_scale_transposed_broadcasted)
{
    migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}};
    migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
    // Per-row scale for t1 (length 1280, quantize axis 2) and per-row scale
    // for t2 (length 1024, quantize axis 0).
    migraphx::shape sh3{migraphx::shape::float_type, {1280}};
    migraphx::shape sh4{migraphx::shape::float_type, {1024}};

    // Input module: per-axis q/dq pairs, then transpose + broadcast of the
    // dequantized second operand before the float dot.
    migraphx::module m1;
    {
        auto t1     = m1.add_parameter("t1", sh1);
        auto t2     = m1.add_parameter("t2", sh2);
        auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
        auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0));
        auto zero   = m1.add_literal(std::int8_t{0});
        auto q1     = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2);
        auto d1     = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2);
        auto q2     = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0);
        auto d2     = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0);
        auto d2_t =
            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
        auto d2_mb = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t);
        auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
        m1.add_return({dot});
    }

    // Expected module: transpose/broadcast applied to the quantized operand,
    // quant_dot, then one dequantizelinear whose scale is scale1 (axis 2)
    // multiplied with scale2 (axis 3) broadcast over the output shape.
    migraphx::module m2;
    {
        auto t1     = m2.add_parameter("t1", sh1);
        auto t2     = m2.add_parameter("t2", sh2);
        auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0));
        auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0));
        auto zero   = m2.add_literal(std::int8_t{0});
        auto q1     = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2);
        auto q2     = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0);
        auto q2_t =
            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
        auto q2_mb = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t);
        auto dot       = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb);
        auto out_scale = add_scale_mul(m2, scale1, scale2, 2, 3, dot->get_shape().lens());
        auto d3        = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
        m2.add_return({d3});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}
// Negative case: t1 carries a per-axis scale (length 1000) on axis 1, which is
// the contracted dimension of the dot. The pass cannot fuse a per-axis scale
// on the reduction axis into a quant_dot, so instead of quantizing it removes
// the q/dq pairs entirely — the expected module is just the plain float dot.
TEST_CASE(dot_multi_scale_unsupported_axis)
{
    migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
    migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
    // Per-axis scale along the contraction dimension (unsupported).
    migraphx::shape sh3{migraphx::shape::float_type, {1000}};

    migraphx::module m1;
    {
        auto t1     = m1.add_parameter("t1", sh1);
        auto t2     = m1.add_parameter("t2", sh2);
        auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
        auto scale2 = m1.add_literal(0.4f);
        auto zero   = m1.add_literal(std::int8_t{0});
        auto q1     = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 1);
        auto d1     = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 1);
        auto q2     = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1);
        auto d2     = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1);
        auto dot    = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
        m1.add_return({dot});
    }

    // Expected: no quantization left at all — dot on the raw float inputs.
    migraphx::module m2;
    {
        auto t1  = m2.add_parameter("t1", sh1);
        auto t2  = m2.add_parameter("t2", sh2);
        auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
        m2.add_return({dot});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}
TEST_CASE(dot_non_zero_point)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
......@@ -269,18 +508,18 @@ TEST_CASE(dot_add)
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto ab = m2.add_parameter("ab", sh3);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto ab = m2.add_parameter("ab", sh3);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add});
}
......@@ -320,26 +559,80 @@ TEST_CASE(conv)
auto weights = m2.add_parameter("weights", s4);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
m2.add_return({d6});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_multi_scale)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
migraphx::shape s8{migraphx::shape::float_type, {1280}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s7);
auto weights = m1.add_parameter("weights", s4);
auto w_scale = m1.add_literal(migraphx::generate_literal(s8, 0));
auto inp_scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, w_scale, zero, 0);
auto q1 = add_quantize_op(m1, "quantizelinear", input, inp_scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, inp_scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
d5,
d1);
m1.add_return({c1});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s7);
auto weights = m2.add_parameter("weights", s4);
auto w_scale = m2.add_literal(migraphx::generate_literal(s8, 0));
auto inp_scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q_inp = add_quantize_op(m2, "quantizelinear", input, inp_scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q_inp,
weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
m2.add_return({d6});
auto out_scale = add_scale_mul(m2, inp_scale, w_scale, 1, 1, c1->get_shape().lens());
auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
m2.add_return({d1});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_multi_scale)
TEST_CASE(conv_multi_scale_unsupported_axis)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
......@@ -430,20 +723,20 @@ TEST_CASE(conv_bias_add)
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto b1 = m2.add_instruction(
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1);
m2.add_return({a1});
......@@ -519,22 +812,21 @@ TEST_CASE(conv_pooling_dot)
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto scale2 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction(
auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1);
auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap =
......@@ -545,10 +837,11 @@ TEST_CASE(conv_pooling_dot)
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 0, dot->get_shape().lens());
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, out_scale2);
auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Verify test for a bidirectional LSTM given only the three mandatory inputs
// (sequence, input weights, recurrent weights) with batch-first layout: the
// sequence is transposed to seq-major before the lstm op and the hidden-state
// output is transposed back afterwards. Only two activation functions are
// supplied; the op fills in the remainder for the reverse direction.
struct test_lstm_bidirct_3args_layout : verify_program<test_lstm_bidirct_3args_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t bs     = 2; // batch size
        const std::size_t sl     = 3; // sequence length
        const std::size_t hs     = 5; // hidden size
        const std::size_t is     = 8; // input feature size
        const std::size_t ndirct = 2; // bidirectional
        const float clip_val     = 0.0f;

        migraphx::program prog;
        auto* mmod = prog.get_main_module();

        // Batch-first input shape plus the stacked weight shapes (4 gates).
        migraphx::shape in_shape{migraphx::shape::float_type, {bs, sl, is}};
        migraphx::shape w_shape{migraphx::shape::float_type, {ndirct, 4 * hs, is}};
        migraphx::shape r_shape{migraphx::shape::float_type, {ndirct, 4 * hs, hs}};

        auto seq_param = mmod->add_parameter("seq", in_shape);
        auto w_param   = mmod->add_parameter("w", w_shape);
        auto r_param   = mmod->add_parameter("r", r_shape);

        // Re-order the sequence to {seq_len, batch, input} for the lstm op.
        const std::vector<int64_t> to_seq_major{1, 0, 2};
        auto seq_t = mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_major}}), seq_param);

        auto hidden = mmod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip_val}}),
            seq_t,
            w_param,
            r_param);

        // Restore a batch-first view of the full hidden-state output.
        const std::vector<int64_t> to_batch_major{2, 0, 1, 3};
        mmod->add_instruction(migraphx::make_op("transpose", {{"permutation", to_batch_major}}),
                              hidden);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Verify test for a bidirectional LSTM with the full argument list (bias,
// initial hidden/cell state, peepholes; sequence lengths passed as undefined)
// in batch-first layout. Inputs are transposed to seq-major before the lstm
// op, the last hidden-state output is extracted, and the result is transposed
// back to batch-first.
struct test_lstm_bidirct_last_layout : verify_program<test_lstm_bidirct_last_layout>
{
    migraphx::program create_program() const
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 3;
        std::size_t hidden_size = 5;
        std::size_t input_size  = 8;
        std::size_t num_dirct   = 2;
        float clip              = 0.0f;

        migraphx::program p;
        auto* mm = p.get_main_module();

        // Shapes: 4 stacked gates for w/r, 8*hidden for the combined bias,
        // 3*hidden for the peephole weights; ih/ic are batch-first here.
        migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
        migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
        migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};

        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        auto ic   = mm->add_parameter("ic", ic_shape);
        auto pph  = mm->add_parameter("pph", pph_shape);
        // Placeholder for the optional sequence-lengths input.
        auto und = mm->add_instruction(migraphx::make_op("undefined"));

        // Convert batch-first inputs to the seq-major layout the op expects.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);

        auto output = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);

        // Extract the last hidden state and transpose it back to batch-first.
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        return p;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Batch-first (layout) variant of the forward LSTM full-hidden-states test:
// inputs are built batch-first, swapped to sequence-first for the lstm op,
// and the full hidden-state tensor is permuted back afterwards.
struct test_lstm_forward_hs_layout : verify_program<test_lstm_forward_hs_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1; // single (forward) direction
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mm = prog.get_main_module();

        // Shapes follow the ONNX LSTM convention, except seq/ih/ic are batch-first here.
        migraphx::shape seq_s{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_s{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_s{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};
        migraphx::shape b_s{migraphx::shape::float_type, {ndirct, 8 * nhidden}};
        migraphx::shape ih_s{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape ic_s{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape pph_s{migraphx::shape::float_type, {ndirct, 3 * nhidden}};

        auto seq  = mm->add_parameter("seq", seq_s);
        auto w    = mm->add_parameter("w", w_s);
        auto r    = mm->add_parameter("r", r_s);
        auto bias = mm->add_parameter("bias", b_s);
        auto ih   = mm->add_parameter("ih", ih_s);
        auto ic   = mm->add_parameter("ic", ic_s);
        auto pph  = mm->add_parameter("pph", pph_s);
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));

        // Swap batch and sequence dimensions.
        const std::vector<int64_t> perm{1, 0, 2};
        auto to_seq_first = [&](auto ins) {
            return mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}),
                                       ins);
        };
        seq = to_seq_first(seq);
        ih  = to_seq_first(ih);
        ic  = to_seq_first(ic);

        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        // Permute the 4-D hidden-states output back to the batch-first layout.
        const std::vector<int64_t> perm_hid{2, 0, 1, 3};
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Batch-first (layout) variant of the forward LSTM last-output test with
// explicit per-batch sequence lengths: inputs are swapped to sequence-first
// for the lstm op, and the last hidden state is swapped back.
struct test_lstm_forward_last_layout : verify_program<test_lstm_forward_last_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1; // single (forward) direction
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mm = prog.get_main_module();

        // Shapes follow the ONNX LSTM convention, except seq/ih/ic are batch-first here.
        migraphx::shape seq_s{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_s{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_s{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};
        migraphx::shape b_s{migraphx::shape::float_type, {ndirct, 8 * nhidden}};
        migraphx::shape ih_s{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape len_s{migraphx::shape::int32_type, {nbatch}};
        migraphx::shape ic_s{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape pph_s{migraphx::shape::float_type, {ndirct, 3 * nhidden}};

        auto seq  = mm->add_parameter("seq", seq_s);
        auto w    = mm->add_parameter("w", w_s);
        auto r    = mm->add_parameter("r", r_s);
        auto bias = mm->add_parameter("bias", b_s);
        auto ih   = mm->add_parameter("ih", ih_s);
        // Actual sequence length for each batch entry (shorter than nseq).
        auto len = mm->add_literal(migraphx::literal(len_s, {1, 2}));
        auto ic  = mm->add_parameter("ic", ic_s);
        auto pph = mm->add_parameter("pph", pph_s);

        // Swap batch and sequence dimensions.
        const std::vector<int64_t> perm{1, 0, 2};
        auto to_seq_first = [&](auto ins) {
            return mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}),
                                       ins);
        };
        seq = to_seq_first(seq);
        ih  = to_seq_first(ih);
        ic  = to_seq_first(ic);

        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            len,
            ih,
            ic,
            pph);
        // Last hidden state respecting the per-batch lengths, then back to batch-first.
        auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs, len);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_hs);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Batch-first (layout) variant of the minimal 3-argument reverse LSTM test:
// only seq/w/r are supplied; the last cell state is extracted and swapped
// back to the batch-first layout.
struct test_lstm_reverse_3args_cell_layout : verify_program<test_lstm_reverse_3args_cell_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1; // single (reverse) direction
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mm = prog.get_main_module();

        // seq is batch-first; w and r follow the ONNX LSTM convention.
        migraphx::shape seq_s{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_s{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_s{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};

        auto seq = mm->add_parameter("seq", seq_s);
        auto w   = mm->add_parameter("w", w_s);
        auto r   = mm->add_parameter("r", r_s);

        // Swap batch and sequence dimensions before running the lstm op.
        const std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);

        auto hs = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip_val}}),
            seq,
            w,
            r);
        // Last cell state, then back to the batch-first layout.
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment