Commit 83dbf407 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine processing of inputs of rnn operator.

parent 7bab863d
...@@ -12,7 +12,7 @@ void auto_contiguous::apply(program& p) const ...@@ -12,7 +12,7 @@ void auto_contiguous::apply(program& p) const
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard()) if(not s.standard() and s.elements() != 0)
{ {
auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins); auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
p.replace_instruction(ins, c); p.replace_instruction(ins, c);
......
...@@ -1187,6 +1187,18 @@ struct rnn_last_output ...@@ -1187,6 +1187,18 @@ struct rnn_last_output
} }
}; };
struct undefined
{
std::string name() const { return "undefined"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(0);
return {};
}
argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -763,6 +763,14 @@ struct onnx_parser ...@@ -763,6 +763,14 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>(); clip = parse_value(attributes.at("clip")).at<float>();
} }
// if the number of arguments is less than 6, append
// undefined operator to have 6 arguments
if (args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), (6 - args.size()), ins);
}
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
// first output for the concatenation of hidden states // first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
...@@ -822,6 +830,12 @@ struct onnx_parser ...@@ -822,6 +830,12 @@ struct onnx_parser
} }
} }
void parse_undefined(const std::string &name)
{
auto ins = prog.add_instruction(op::undefined{});
instructions[name] = ins;
}
void parse_node(const std::string& name) void parse_node(const std::string& name)
{ {
if(name.empty()) if(name.empty())
...@@ -832,25 +846,16 @@ struct onnx_parser ...@@ -832,25 +846,16 @@ struct onnx_parser
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
// For RNN, LSTM, and GRU operators, one of the input arguments
// is prim::Undefined, and it is ignored by protobuf. We use a
// hack to ignore this argument for these three operators
const std::string& op_type = node.op_type();
if((op_type == "RNN" || op_type == "LSTM" || op_type == "GRU") && input.empty())
{
continue;
}
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
{ {
assert(name != input); assert(name != input);
this->parse_node(input); this->parse_node(input);
args.push_back(instructions.at(input));
} }
else else if (input.empty())
{ {
args.push_back(instructions.at(input)); this->parse_undefined(input);
} }
args.push_back(instructions.at(input));
} }
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
......
...@@ -16,10 +16,10 @@ void rewrite_rnn::apply(program& prog) const ...@@ -16,10 +16,10 @@ void rewrite_rnn::apply(program& prog) const
// rewrite rnn operator // rewrite rnn operator
if(ins->name() == "rnn") if(ins->name() == "rnn")
{ {
// could be 3 to 6 inputs, but the 5th input is undefined in // could be 3 to 6 inputs, but the parse_rnn function will
// pytorch exported onnx, and it is ignored by protobuf. So // append undefined operators to make 6 arguments when parsing
// for input arguments 5 and 6, we need to check the shape, // an onnx file. Another case is user can have only 3 arguments
// then based on the shape to judge the specific input info // when writing their program.
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
...@@ -44,7 +44,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -44,7 +44,7 @@ void rewrite_rnn::apply(program& prog) const
// process bias // process bias
instruction_ref bias_forward, bias_reverse; instruction_ref bias_forward, bias_reverse;
bias_forward = bias_reverse = prog.end(); bias_forward = bias_reverse = prog.end();
if(args.size() >= 4) if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
...@@ -53,12 +53,10 @@ void rewrite_rnn::apply(program& prog) const ...@@ -53,12 +53,10 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument // process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored) // or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward, ih_reverse; instruction_ref ih_forward, ih_reverse;
if(args.size() == 6 || if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{ {
auto arg_ih = (args.size() == 6) ? args[5] : args[4]; ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih); ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
} }
else else
{ {
...@@ -120,17 +118,16 @@ void rewrite_rnn::apply(program& prog) const ...@@ -120,17 +118,16 @@ void rewrite_rnn::apply(program& prog) const
// process bias and initial hidden state // process bias and initial hidden state
instruction_ref bias = prog.end(); instruction_ref bias = prog.end();
if(args.size() >= 4) if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{ {
bias = args[3]; bias = args[3];
} }
// process intial hidden state // process intial hidden state
instruction_ref ih; instruction_ref ih;
if(args.size() == 6 || if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{ {
ih = (args.size() == 6) ? args[5] : args[4]; ih = args[5];
} }
else else
{ {
......
...@@ -1403,6 +1403,7 @@ TEST_CASE(rnn_forward) ...@@ -1403,6 +1403,7 @@ TEST_CASE(rnn_forward)
auto r = p.add_literal(migraphx::literal{r_shape, rf_data}); auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1412,6 +1413,7 @@ TEST_CASE(rnn_forward) ...@@ -1412,6 +1413,7 @@ TEST_CASE(rnn_forward)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -1453,6 +1455,7 @@ TEST_CASE(rnn_forward) ...@@ -1453,6 +1455,7 @@ TEST_CASE(rnn_forward)
auto r = p.add_literal(migraphx::literal{r_shape, rf_data}); auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1463,6 +1466,7 @@ TEST_CASE(rnn_forward) ...@@ -1463,6 +1466,7 @@ TEST_CASE(rnn_forward)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
...@@ -1539,6 +1543,7 @@ TEST_CASE(rnn_reverse) ...@@ -1539,6 +1543,7 @@ TEST_CASE(rnn_reverse)
auto r = p.add_literal(migraphx::literal{r_shape, rr_data}); auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1548,6 +1553,7 @@ TEST_CASE(rnn_reverse) ...@@ -1548,6 +1553,7 @@ TEST_CASE(rnn_reverse)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -1589,6 +1595,7 @@ TEST_CASE(rnn_reverse) ...@@ -1589,6 +1595,7 @@ TEST_CASE(rnn_reverse)
auto r = p.add_literal(migraphx::literal{r_shape, rr_data}); auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1599,6 +1606,7 @@ TEST_CASE(rnn_reverse) ...@@ -1599,6 +1606,7 @@ TEST_CASE(rnn_reverse)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
...@@ -1713,6 +1721,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1713,6 +1721,7 @@ TEST_CASE(rnn_bidirectional)
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end()); bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1722,6 +1731,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1722,6 +1731,7 @@ TEST_CASE(rnn_bidirectional)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -1762,6 +1772,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1762,6 +1772,7 @@ TEST_CASE(rnn_bidirectional)
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end()); bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1772,6 +1783,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1772,6 +1783,7 @@ TEST_CASE(rnn_bidirectional)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
......
...@@ -1107,7 +1107,8 @@ struct test_rnn_forward ...@@ -1107,7 +1107,8 @@ struct test_rnn_forward
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1117,6 +1118,7 @@ struct test_rnn_forward ...@@ -1117,6 +1118,7 @@ struct test_rnn_forward
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
...@@ -1147,6 +1149,7 @@ struct test_rnn_forward10 ...@@ -1147,6 +1149,7 @@ struct test_rnn_forward10
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1157,6 +1160,7 @@ struct test_rnn_forward10 ...@@ -1157,6 +1160,7 @@ struct test_rnn_forward10
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
...@@ -1187,6 +1191,7 @@ struct test_rnn_reverse ...@@ -1187,6 +1191,7 @@ struct test_rnn_reverse
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1196,6 +1201,7 @@ struct test_rnn_reverse ...@@ -1196,6 +1201,7 @@ struct test_rnn_reverse
w, w,
r, r,
bias, bias,
und,
ih); ih);
return p; return p;
...@@ -1225,6 +1231,7 @@ struct test_rnn_reverse2 ...@@ -1225,6 +1231,7 @@ struct test_rnn_reverse2
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1234,6 +1241,7 @@ struct test_rnn_reverse2 ...@@ -1234,6 +1241,7 @@ struct test_rnn_reverse2
w, w,
r, r,
bias, bias,
und,
ih); ih);
return p; return p;
...@@ -1328,6 +1336,7 @@ struct test_rnn_5args ...@@ -1328,6 +1336,7 @@ struct test_rnn_5args
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1337,7 +1346,8 @@ struct test_rnn_5args ...@@ -1337,7 +1346,8 @@ struct test_rnn_5args
seq, seq,
w, w,
r, r,
bias); bias,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
...@@ -1367,6 +1377,7 @@ struct test_rnn_bidirectional ...@@ -1367,6 +1377,7 @@ struct test_rnn_bidirectional
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1377,6 +1388,7 @@ struct test_rnn_bidirectional ...@@ -1377,6 +1388,7 @@ struct test_rnn_bidirectional
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
...@@ -1407,7 +1419,7 @@ struct test_rnn_bidirectional10 ...@@ -1407,7 +1419,7 @@ struct test_rnn_bidirectional10
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1417,6 +1429,7 @@ struct test_rnn_bidirectional10 ...@@ -1417,6 +1429,7 @@ struct test_rnn_bidirectional10
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
......
...@@ -551,6 +551,7 @@ TEST_CASE(rnn_test) ...@@ -551,6 +551,7 @@ TEST_CASE(rnn_test)
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -559,7 +560,7 @@ TEST_CASE(rnn_test) ...@@ -559,7 +560,7 @@ TEST_CASE(rnn_test)
clip}, clip},
seq, seq,
w, w,
r); r, und, und, und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
...@@ -577,7 +578,9 @@ TEST_CASE(rnn_test) ...@@ -577,7 +578,9 @@ TEST_CASE(rnn_test)
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -588,7 +591,8 @@ TEST_CASE(rnn_test) ...@@ -588,7 +591,8 @@ TEST_CASE(rnn_test)
w, w,
r, r,
bias, bias,
ih); seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment