Commit 4d0c7238 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'rnn_operator' into reshape_tests

parents 4fc35cf8 55f8e435
...@@ -771,17 +771,14 @@ struct onnx_parser ...@@ -771,17 +771,14 @@ struct onnx_parser
args.insert(args.end(), (6 - args.size()), ins); args.insert(args.end(), (6 - args.size()), ins);
} }
std::vector<instruction_ref> result;
// first output for the concatenation of hidden states // first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args)); std::move(args));
result.push_back(hidden_states);
// second output for the last hidden state // second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states); auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
result.push_back(last_output);
return result; return {hidden_states, last_output};
} }
void parse_from(std::istream& is) void parse_from(std::istream& is)
......
...@@ -1452,6 +1452,8 @@ struct test_rnn_bi_3args ...@@ -1452,6 +1452,8 @@ struct test_rnn_bi_3args
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape); auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
...@@ -1540,4 +1542,5 @@ int main() ...@@ -1540,4 +1542,5 @@ int main()
verify_program<test_rnn_5args>(); verify_program<test_rnn_5args>();
verify_program<test_rnn_bidirectional>(); verify_program<test_rnn_bidirectional>();
verify_program<test_rnn_bidirectional10>(); verify_program<test_rnn_bidirectional10>();
// verify_program<test_rnn_bi_3args>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment