Commit 61775eab authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ref_fp8

parents a5c38ebe e7e5ba23
......@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
......
......@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
std::vector<int64_t> perm{1, 0, 2};
args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]);
if(args.size() >= 6 and not args[5]->is_undefined())
{
args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]);
}
if(args.size() >= 7 and not args[6]->is_undefined())
{
args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]);
}
}
void lstm_transpose_outputs(onnx_parser::node_info& info,
instruction_ref& hidden_states,
instruction_ref& last_output,
instruction_ref& last_cell_output)
{
std::vector<int64_t> perm_hs{2, 0, 1, 3};
hidden_states =
info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states);
std::vector<int64_t> perm_last{1, 0, 2};
last_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output);
last_cell_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output);
}
struct parse_lstm : op_parser<parse_lstm>
{
std::vector<op_desc> operators() const { return {{"LSTM"}}; }
......@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
}
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 8)
{
......@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins);
}
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
......@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output};
}
};
......
......@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s;
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
struct match_find_quantizable_ops
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
static bool
is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
}
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
});
});
return all_same;
}
return all_zeros;
}
struct match_find_quantizable_ops
{
static auto
scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
while(next_ins != qop)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
match::arg(0)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
return;
// Only support scalar and 1D scales
if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref out_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
m.replace_instruction(qop, dq);
}
};
......
......@@ -4484,6 +4484,177 @@ def lrn_test():
return ([node], [x], [y])
@onnx_test()
def lstm_bi_layout_cell_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60])
cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
[3, 2, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['', '', 'cellout'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='bidirectional',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_bi_layout_last_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60])
hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 2, 20])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[3, 2, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['hs', 'output'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='bidirectional',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_hs_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[3, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['hs', 'output'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='forward',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output])
@onnx_test()
def lstm_f_layout_cell_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
[3, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['', '', 'cellout'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='forward',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout])
@onnx_test()
def lstm_r_layout_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['hs'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='reverse',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs])
@onnx_test()
def lstm_r_layout_hs_cell_test():
seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10])
r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160])
seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3])
h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20])
c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20])
pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[3, 1, 20])
cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT,
[3, 1, 20])
node = onnx.helper.make_node(
'LSTM',
inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
outputs=['', 'output', 'cellout'],
activations=['sigmoid', 'tanh', 'tanh'],
clip=0,
direction='reverse',
hidden_size=20,
input_forget=1,
layout=1)
return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [output, cellout])
@onnx_test()
def matmul_bmbm_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7])
......
......@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
}
}
TEST_CASE(lstm_forward_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 8 args, hs and last output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
auto prog = optimize_onnx("lstm_f_layout_hs_test.onnx");
EXPECT(p == prog);
}
// 8 args, cell output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_f_layout_cell_test.onnx");
EXPECT(p == prog);
}
}
// activation functions
TEST_CASE(lstm_forward_actv_func)
{
......@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
}
}
TEST_CASE(lstm_reverse_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 8 args, hs output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
auto prog = optimize_onnx("lstm_r_layout_test.onnx");
EXPECT(p == prog);
}
// 8 args, last and cell output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
last_output = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}),
last_output);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_r_layout_hs_cell_test.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t sl = 5; // sequence len
......@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
}
}
TEST_CASE(lstm_bidirectional_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
int input_forget = 1;
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 0 activation function
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_bi_actv_funcs)
{
std::size_t sl = 5; // sequence len
......
......@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# from OnnxBackendNodeModelTest
backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_lstm_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
# from OnnxBackendPyTorchConvertedModelTest
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
......
......@@ -636,13 +636,12 @@ TEST_CASE(dot_float)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto scale_a = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
auto scale_a_lit = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a_lit);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b_lit = mm->add_literal(5.0);
auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b_lit);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale_a_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_a_lit);
auto scale_b_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_b_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg)
migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -851,10 +851,10 @@ TEST_CASE(conv_float)
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(10.0f);
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
......@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}};
std::vector<float> vec(sc.elements(), 100.0f);
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()};
auto d_scale = mm->add_literal(100.0f);
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -930,20 +928,21 @@ TEST_CASE(conv_half)
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0}));
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto s1_mb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1);
auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
......@@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw);
// else submod
auto* else_mod = p.create_module("If_6_else");
auto sax = else_mod->add_literal(2.0f);
auto sax_lit = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax);
auto sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax_lit);
auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw_lit = else_mod->add_literal(1.66667f);
auto ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw_lit);
auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f);
so1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto ssw_mb = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
ssw_lit);
auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
......
This diff is collapsed.
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_bidirct_3args_layout : verify_program<test_lstm_bidirct_3args_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
return p;
}
std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_bidirct_last_layout : verify_program<test_lstm_bidirct_last_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto output = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
return p;
}
std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_forward_hs_layout : verify_program<test_lstm_forward_hs_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
return p;
}
std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_forward_last_layout : verify_program<test_lstm_forward_last_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape l_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", b_shape);
auto ih = mm->add_parameter("ih", ih_shape);
auto len = mm->add_literal(migraphx::literal(l_shape, {1, 2}));
auto ic = mm->add_parameter("ic", ic_shape);
auto pph = mm->add_parameter("pph", pph_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto output = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq,
w,
r,
bias,
len,
ih,
ic,
pph);
auto last_output =
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
return p;
}
std::string section() const { return "rnn"; }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_lstm_reverse_3args_cell_layout : verify_program<test_lstm_reverse_3args_cell_layout>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
seq,
w,
r);
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
return p;
}
std::string section() const { return "rnn"; }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment