Commit 11e79f52 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from branch rnn_operator

parents 1e7457cb 5f441536
...@@ -41,6 +41,14 @@ ...@@ -41,6 +41,14 @@
<summary>Macros must be prefixed with MIGRAPHX_</summary> <summary>Macros must be prefixed with MIGRAPHX_</summary>
</message> </message>
</rule> </rule>
<rule>
<pattern>mutable \w+</pattern>
<message>
<id>MutableVariable</id>
<severity>style</severity>
<summary>Do not create mutable variables.</summary>
</message>
</rule>
<rule> <rule>
<pattern>(memcpy|strcpy|strncpy|strcat|strncat) \(</pattern> <pattern>(memcpy|strcpy|strncpy|strcat|strncat) \(</pattern>
<message> <message>
......
pfultz2/rocm-recipes pfultz2/rocm-recipes
pcre danmar/cppcheck@681cb7dd909d1bfe41796b7616e43265177b9464 -DHAVE_RULES=1
danmar/cppcheck@575f62f39c1130f412d3cc11b0138c5057c451c0 -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@fc22ef991ce7cb15821c8ccb4f03cdfc3e7e43cf ROCm-Developer-Tools/HIP@fc22ef991ce7cb15821c8ccb4f03cdfc3e7e43cf
python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36 python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
...@@ -12,7 +12,7 @@ void auto_contiguous::apply(program& p) const ...@@ -12,7 +12,7 @@ void auto_contiguous::apply(program& p) const
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard()) if(not s.standard() and s.elements() != 0)
{ {
auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins); auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
p.replace_instruction(ins, c); p.replace_instruction(ins, c);
......
...@@ -1245,6 +1245,18 @@ struct gru_last_output ...@@ -1245,6 +1245,18 @@ struct gru_last_output
} }
}; };
struct undefined
{
std::string name() const { return "undefined"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(0);
return {};
}
argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -54,7 +54,6 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -54,7 +54,6 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
f(i); f(i);
} }
}); });
// cppcheck-suppress unreadVariable
work += grainsize; work += grainsize;
return result; return result;
}); });
......
...@@ -763,6 +763,14 @@ struct onnx_parser ...@@ -763,6 +763,14 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>(); clip = parse_value(attributes.at("clip")).at<float>();
} }
// if the number of arguments is less than 6, append
// undefined operator to have 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), (6 - args.size()), ins);
}
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
// first output for the concatenation of hidden states // first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
...@@ -939,6 +947,12 @@ struct onnx_parser ...@@ -939,6 +947,12 @@ struct onnx_parser
} }
} }
void parse_undefined(const std::string& name)
{
auto ins = prog.add_instruction(op::undefined{});
instructions[name] = ins;
}
void parse_node(const std::string& name) void parse_node(const std::string& name)
{ {
if(name.empty()) if(name.empty())
...@@ -949,25 +963,16 @@ struct onnx_parser ...@@ -949,25 +963,16 @@ struct onnx_parser
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
// For RNN, LSTM, and GRU operators, one of the input arguments
// is prim::Undefined, and it is ignored by protobuf. We use a
// hack to ignore this argument for these three operators
const std::string& op_type = node.op_type();
if((op_type == "RNN" || op_type == "LSTM" || op_type == "GRU") && input.empty())
{
continue;
}
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
{ {
assert(name != input); assert(name != input);
this->parse_node(input); this->parse_node(input);
args.push_back(instructions.at(input));
} }
else else if(input.empty())
{ {
args.push_back(instructions.at(input)); this->parse_undefined(input);
} }
args.push_back(instructions.at(input));
} }
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
......
...@@ -16,10 +16,10 @@ void rewrite_rnn::apply(program& prog) const ...@@ -16,10 +16,10 @@ void rewrite_rnn::apply(program& prog) const
// rewrite rnn operator // rewrite rnn operator
if(ins->name() == "rnn") if(ins->name() == "rnn")
{ {
// could be 3 to 6 inputs, but the 5th input is undefined in // could be 3 to 6 inputs, but the parse_rnn function will
// pytorch exported onnx, and it is ignored by protobuf. So // append undefined operators to make 6 arguments when parsing
// for input arguments 5 and 6, we need to check the shape, // an onnx file. Another case is user can have only 3 arguments
// then based on the shape to judge the specific input info // when writing their program.
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
...@@ -44,7 +44,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -44,7 +44,7 @@ void rewrite_rnn::apply(program& prog) const
// process bias // process bias
instruction_ref bias_forward, bias_reverse; instruction_ref bias_forward, bias_reverse;
bias_forward = bias_reverse = prog.end(); bias_forward = bias_reverse = prog.end();
if(args.size() >= 4) if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
...@@ -53,12 +53,10 @@ void rewrite_rnn::apply(program& prog) const ...@@ -53,12 +53,10 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument // process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored) // or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward, ih_reverse; instruction_ref ih_forward, ih_reverse;
if(args.size() == 6 || if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{ {
auto arg_ih = (args.size() == 6) ? args[5] : args[4]; ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih); ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
} }
else else
{ {
...@@ -120,17 +118,16 @@ void rewrite_rnn::apply(program& prog) const ...@@ -120,17 +118,16 @@ void rewrite_rnn::apply(program& prog) const
// process bias and initial hidden state // process bias and initial hidden state
instruction_ref bias = prog.end(); instruction_ref bias = prog.end();
if(args.size() >= 4) if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{ {
bias = args[3]; bias = args[3];
} }
// process intial hidden state // process intial hidden state
instruction_ref ih; instruction_ref ih;
if(args.size() == 6 || if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{ {
ih = (args.size() == 6) ? args[5] : args[4]; ih = args[5];
} }
else else
{ {
......
...@@ -1403,6 +1403,7 @@ TEST_CASE(rnn_forward) ...@@ -1403,6 +1403,7 @@ TEST_CASE(rnn_forward)
auto r = p.add_literal(migraphx::literal{r_shape, rf_data}); auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1412,6 +1413,7 @@ TEST_CASE(rnn_forward) ...@@ -1412,6 +1413,7 @@ TEST_CASE(rnn_forward)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -1453,6 +1455,7 @@ TEST_CASE(rnn_forward) ...@@ -1453,6 +1455,7 @@ TEST_CASE(rnn_forward)
auto r = p.add_literal(migraphx::literal{r_shape, rf_data}); auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1463,6 +1466,7 @@ TEST_CASE(rnn_forward) ...@@ -1463,6 +1466,7 @@ TEST_CASE(rnn_forward)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
...@@ -1539,6 +1543,7 @@ TEST_CASE(rnn_reverse) ...@@ -1539,6 +1543,7 @@ TEST_CASE(rnn_reverse)
auto r = p.add_literal(migraphx::literal{r_shape, rr_data}); auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1548,6 +1553,7 @@ TEST_CASE(rnn_reverse) ...@@ -1548,6 +1553,7 @@ TEST_CASE(rnn_reverse)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -1589,6 +1595,7 @@ TEST_CASE(rnn_reverse) ...@@ -1589,6 +1595,7 @@ TEST_CASE(rnn_reverse)
auto r = p.add_literal(migraphx::literal{r_shape, rr_data}); auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data}); auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1599,6 +1606,7 @@ TEST_CASE(rnn_reverse) ...@@ -1599,6 +1606,7 @@ TEST_CASE(rnn_reverse)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
...@@ -1713,6 +1721,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1713,6 +1721,7 @@ TEST_CASE(rnn_bidirectional)
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end()); bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1722,6 +1731,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1722,6 +1731,7 @@ TEST_CASE(rnn_bidirectional)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -1762,6 +1772,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1762,6 +1772,7 @@ TEST_CASE(rnn_bidirectional)
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end()); bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1772,6 +1783,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -1772,6 +1783,7 @@ TEST_CASE(rnn_bidirectional)
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
...@@ -1781,6 +1793,64 @@ TEST_CASE(rnn_bidirectional) ...@@ -1781,6 +1793,64 @@ TEST_CASE(rnn_bidirectional)
std::vector<float> last_output_data; std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704, std::vector<float> last_output_data_gold{0.03445704,
0.19167931, 0.19167931,
-0.3946827, -0.3946827,
......
...@@ -1107,6 +1107,7 @@ struct test_rnn_forward ...@@ -1107,6 +1107,7 @@ struct test_rnn_forward
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1117,6 +1118,7 @@ struct test_rnn_forward ...@@ -1117,6 +1118,7 @@ struct test_rnn_forward
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
...@@ -1147,6 +1149,7 @@ struct test_rnn_forward10 ...@@ -1147,6 +1149,7 @@ struct test_rnn_forward10
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1157,6 +1160,7 @@ struct test_rnn_forward10 ...@@ -1157,6 +1160,7 @@ struct test_rnn_forward10
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
...@@ -1187,6 +1191,7 @@ struct test_rnn_reverse ...@@ -1187,6 +1191,7 @@ struct test_rnn_reverse
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1196,6 +1201,7 @@ struct test_rnn_reverse ...@@ -1196,6 +1201,7 @@ struct test_rnn_reverse
w, w,
r, r,
bias, bias,
und,
ih); ih);
return p; return p;
...@@ -1225,6 +1231,7 @@ struct test_rnn_reverse2 ...@@ -1225,6 +1231,7 @@ struct test_rnn_reverse2
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1234,6 +1241,7 @@ struct test_rnn_reverse2 ...@@ -1234,6 +1241,7 @@ struct test_rnn_reverse2
w, w,
r, r,
bias, bias,
und,
ih); ih);
return p; return p;
...@@ -1328,6 +1336,7 @@ struct test_rnn_5args ...@@ -1328,6 +1336,7 @@ struct test_rnn_5args
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1337,7 +1346,8 @@ struct test_rnn_5args ...@@ -1337,7 +1346,8 @@ struct test_rnn_5args
seq, seq,
w, w,
r, r,
bias); bias,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
...@@ -1367,6 +1377,7 @@ struct test_rnn_bidirectional ...@@ -1367,6 +1377,7 @@ struct test_rnn_bidirectional
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
...@@ -1377,6 +1388,7 @@ struct test_rnn_bidirectional ...@@ -1377,6 +1388,7 @@ struct test_rnn_bidirectional
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
...@@ -1407,7 +1419,7 @@ struct test_rnn_bidirectional10 ...@@ -1407,7 +1419,7 @@ struct test_rnn_bidirectional10
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape); auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -1417,6 +1429,7 @@ struct test_rnn_bidirectional10 ...@@ -1417,6 +1429,7 @@ struct test_rnn_bidirectional10
w, w,
r, r,
bias, bias,
und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
...@@ -1424,6 +1437,39 @@ struct test_rnn_bidirectional10 ...@@ -1424,6 +1437,39 @@ struct test_rnn_bidirectional10
} }
}; };
struct test_rnn_bi_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
int main() int main()
{ {
verify_program<test_pooling_autopad>(); verify_program<test_pooling_autopad>();
......
...@@ -551,6 +551,7 @@ TEST_CASE(rnn_test) ...@@ -551,6 +551,7 @@ TEST_CASE(rnn_test)
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}}); auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -559,7 +560,10 @@ TEST_CASE(rnn_test) ...@@ -559,7 +560,10 @@ TEST_CASE(rnn_test)
clip}, clip},
seq, seq,
w, w,
r); r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
...@@ -577,7 +581,9 @@ TEST_CASE(rnn_test) ...@@ -577,7 +581,9 @@ TEST_CASE(rnn_test)
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}}); auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias = auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}}); p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
...@@ -588,7 +594,8 @@ TEST_CASE(rnn_test) ...@@ -588,7 +594,8 @@ TEST_CASE(rnn_test)
w, w,
r, r,
bias, bias,
ih); seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment