Commit aaee3f1e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add tests for rnn operator.

parent 69102b29
...@@ -1068,7 +1068,7 @@ struct rnn ...@@ -1068,7 +1068,7 @@ struct rnn
}; };
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
operation actv_func = tanh{}; operation actv_func{tanh{}};
rnn_direction_t direction = forward; rnn_direction_t direction = forward;
float clip = 0.0f; float clip = 0.0f;
...@@ -1076,14 +1076,14 @@ struct rnn ...@@ -1076,14 +1076,14 @@ struct rnn
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto in_dims = inputs[0].lens(); auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[1].lens(); auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[1]) if(hidden_size != hidden_dims[2])
{ {
MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input"); MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
} }
std::size_t num_directions = 1; std::size_t num_directions = 1;
if(direction == rnn_direction_t::bidirectional) if(direction == bidirectional)
{ {
num_directions = 2; num_directions = 2;
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
float sigmoid(float x) { return 1 / (1 + expf(-x)); } float sigmoid(float x) { return 1 / (1 + expf(-x)); }
...@@ -1326,4 +1327,101 @@ TEST_CASE(min_test) ...@@ -1326,4 +1327,101 @@ TEST_CASE(min_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
/*
TEST_CASE(rnn_test)
{
{
migraphx::program p;
size_t hidden_size = 8;
size_t input_size = 6;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {1, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, input.data());
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
0.596363, -0.274248, 0.714484, 0.282515, 0.0938349,
0.185406, 0.283227, -0.482086, 0.265265, -0.523217,
0.50433, 0.400934, -0.34513, 0.114924, 0.0392658,
-0.0976029, 0.364322, -0.567117, 0.538775, 0.314859,
-0.478676, 0.51778, -0.286718, -0.0478341, 0.339601,
-0.380976, 0.628219, 0.222791, -0.271949, 0.490674,
-0.234456, -0.224984, 0.456527, -0.454559, 0.546034,
-0.0389027, -0.307475, 0.561003, -0.245673, -0.0776644,
0.447162, -0.52013, 0.511913, 0.0324621, -0.380515,
0.500777, -0.225695, -0.0193589, 0.458955, -0.531746,
0.448536, -0.087655, -0.430165, 0.551379, -0.161603,
-0.0165391, 0.447551, -0.491717, 0.484796, -0.0699652,
-0.3941, 0.561967, -0.168543, -0.0661258, 0.465925,
-0.499277, 0.45216, -0.103005, -0.392837, 0.584424,
-0.189044, -0.0388068, 0.468369, -0.512927, 0.449144,
-0.0900977, -0.400401, 0.573534, -0.19617, -0.0208253};
EXPECT(migraphx::verify_range(res, res_gold));
}
{
migraphx::program p;
size_t hidden_size = 6;
size_t input_size = 4;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {6, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, &input[0]);
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
-0.0890872, -0.0558751, 0.185233, 0.452857, 0.104082,
0.432953, 0.274236, 0.186055, -0.367716, 0.266761,
-0.28489, 0.498758, 0.0140574, -0.122377, 0.278067,
0.469699, 0.216743, 0.258926, 0.269785, 0.328379,
-0.576081, 0.11672, -0.452062, 0.603549, 0.472625,
0.120929, 0.350331, 0.502138, 0.103585, 0.128486,
0.0210318, 0.338759, -0.654448, 0.37656, -0.359715,
0.424365, 0.449677, 0.130903, 0.354359, 0.59317,
0.189543, 0.201865, 0.126288, 0.31099, -0.681538,
0.275407, -0.406133, 0.450767, 0.305638, 0.14942,
0.309857, 0.722745, 0.361199, -0.00963601, 0.397046,
0.264047, -0.539317, 0.0690505, -0.321901, 0.566638,
0.406511, 0.231472, 0.320225, 0.737927, 0.372938,
0.00762333, 0.349881, 0.280791, -0.541838, 0.128319,
-0.266702, 0.536205, 0.509004, 0.361068, 0.42431,
0.767474, 0.368881, 0.0753035, 0.141155, 0.219692,
-0.643801, 0.281643, -0.330984, 0.397033, 0.494424,
0.38013, 0.434627, 0.795404, 0.391589, 0.0102068,
0.166358, 0.226248, -0.608175, 0.302622, -0.349646,
0.375506, 0.546918, 0.22908, 0.40025, 0.806049,
0.424462, -0.0352604, 0.528827, -0.0372434, -0.573789,
-0.0541837, -0.194983, 0.552972, 0.553695, 0.263657,
0.432448, 0.815763, 0.412716, -0.0389366, 0.52391,
-0.0256845, -0.577296, -0.0570545, -0.219738, 0.561644};
EXPECT(migraphx::verify_range(res, res_gold));
}
}
*/
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment