"library/vscode:/vscode.git/clone" did not exist on "cd51732690641ae0ac76f90641246214f4a95bf9"
Commit 661f5b54 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add tests for rnn operator

parent 62044b86
...@@ -25,10 +25,10 @@ struct rewrite_rnn ...@@ -25,10 +25,10 @@ struct rewrite_rnn
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref input, instruction_ref input,
instruction_ref wx, instruction_ref w,
instruction_ref wh, instruction_ref r,
instruction_ref ih,
instruction_ref bias, instruction_ref bias,
instruction_ref ih,
operation& actv_func) const; operation& actv_func) const;
}; };
......
...@@ -664,11 +664,11 @@ struct onnx_parser ...@@ -664,11 +664,11 @@ struct onnx_parser
if(contains(attributes, "hidden_size")) if(contains(attributes, "hidden_size"))
{ {
hidden_size = parse_value(attributes.at("hidden_size")).at<int>(); std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
} if (hidden_size != hidden_size_att)
else {
{ MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
MIGRAPHX_THROW("RNN: hidden size attribute missing"); }
} }
// Handling of direction to be added later // Handling of direction to be added later
...@@ -699,7 +699,7 @@ struct onnx_parser ...@@ -699,7 +699,7 @@ struct onnx_parser
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) { for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
if(map_actv_funcs.count(fn) == 0) if(map_actv_funcs.count(fn) == 0)
{ {
MIGRAPHX_THROW("RNN: activation function " + fn + " not supported"); MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
} }
}); });
......
...@@ -97,7 +97,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -97,7 +97,7 @@ void rewrite_rnn::apply(program& prog) const
} }
else else
{ {
bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward) ? true : false; bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward);
// input weight matrix // input weight matrix
auto w = args[1]; auto w = args[1];
......
...@@ -439,6 +439,163 @@ TEST_CASE(shape_gather_test) ...@@ -439,6 +439,163 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(rnn_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirectional
{
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
}
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
}
// 3 argumments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
}
// 5 argumments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(flatten_test) TEST_CASE(flatten_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment