Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
aaee3f1e
Commit
aaee3f1e
authored
Jan 23, 2019
by
Shucai Xiao
Browse files
add tests for rnn operator.
parent
69102b29
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
102 additions
and
4 deletions
+102
-4
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+4
-4
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+98
-0
test/onnx/rnn_rnn1layer.onnx
test/onnx/rnn_rnn1layer.onnx
+0
-0
test/onnx/rnn_rnnbi3layers.onnx
test/onnx/rnn_rnnbi3layers.onnx
+0
-0
No files found.
src/include/migraphx/operators.hpp
View file @
aaee3f1e
...
@@ -1068,7 +1068,7 @@ struct rnn
...
@@ -1068,7 +1068,7 @@ struct rnn
};
};
std
::
size_t
hidden_size
=
1
;
std
::
size_t
hidden_size
=
1
;
operation
actv_func
=
tanh
{};
operation
actv_func
{
tanh
{}
}
;
rnn_direction_t
direction
=
forward
;
rnn_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
...
@@ -1076,14 +1076,14 @@ struct rnn
...
@@ -1076,14 +1076,14 @@ struct rnn
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
auto
in_dims
=
inputs
[
0
].
lens
();
auto
in_dims
=
inputs
[
0
].
lens
();
auto
hidden_dims
=
inputs
[
1
].
lens
();
auto
hidden_dims
=
inputs
[
2
].
lens
();
if
(
hidden_size
!=
hidden_dims
[
1
])
if
(
hidden_size
!=
hidden_dims
[
2
])
{
{
MIGRAPHX_THROW
(
"RNN: hidden size mismatch in attribute and input"
);
MIGRAPHX_THROW
(
"RNN: hidden size mismatch in attribute and input"
);
}
}
std
::
size_t
num_directions
=
1
;
std
::
size_t
num_directions
=
1
;
if
(
direction
==
rnn_direction_t
::
bidirectional
)
if
(
direction
==
bidirectional
)
{
{
num_directions
=
2
;
num_directions
=
2
;
}
}
...
...
test/cpu_ops_test.cpp
View file @
aaee3f1e
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
#include "test.hpp"
float
sigmoid
(
float
x
)
{
return
1
/
(
1
+
expf
(
-
x
));
}
float
sigmoid
(
float
x
)
{
return
1
/
(
1
+
expf
(
-
x
));
}
...
@@ -1326,4 +1327,101 @@ TEST_CASE(min_test)
...
@@ -1326,4 +1327,101 @@ TEST_CASE(min_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
/*
TEST_CASE(rnn_test)
{
{
migraphx::program p;
size_t hidden_size = 8;
size_t input_size = 6;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {1, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, input.data());
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
0.596363, -0.274248, 0.714484, 0.282515, 0.0938349,
0.185406, 0.283227, -0.482086, 0.265265, -0.523217,
0.50433, 0.400934, -0.34513, 0.114924, 0.0392658,
-0.0976029, 0.364322, -0.567117, 0.538775, 0.314859,
-0.478676, 0.51778, -0.286718, -0.0478341, 0.339601,
-0.380976, 0.628219, 0.222791, -0.271949, 0.490674,
-0.234456, -0.224984, 0.456527, -0.454559, 0.546034,
-0.0389027, -0.307475, 0.561003, -0.245673, -0.0776644,
0.447162, -0.52013, 0.511913, 0.0324621, -0.380515,
0.500777, -0.225695, -0.0193589, 0.458955, -0.531746,
0.448536, -0.087655, -0.430165, 0.551379, -0.161603,
-0.0165391, 0.447551, -0.491717, 0.484796, -0.0699652,
-0.3941, 0.561967, -0.168543, -0.0661258, 0.465925,
-0.499277, 0.45216, -0.103005, -0.392837, 0.584424,
-0.189044, -0.0388068, 0.468369, -0.512927, 0.449144,
-0.0900977, -0.400401, 0.573534, -0.19617, -0.0208253};
EXPECT(migraphx::verify_range(res, res_gold));
}
{
migraphx::program p;
size_t hidden_size = 6;
size_t input_size = 4;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {6, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, &input[0]);
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
-0.0890872, -0.0558751, 0.185233, 0.452857, 0.104082,
0.432953, 0.274236, 0.186055, -0.367716, 0.266761,
-0.28489, 0.498758, 0.0140574, -0.122377, 0.278067,
0.469699, 0.216743, 0.258926, 0.269785, 0.328379,
-0.576081, 0.11672, -0.452062, 0.603549, 0.472625,
0.120929, 0.350331, 0.502138, 0.103585, 0.128486,
0.0210318, 0.338759, -0.654448, 0.37656, -0.359715,
0.424365, 0.449677, 0.130903, 0.354359, 0.59317,
0.189543, 0.201865, 0.126288, 0.31099, -0.681538,
0.275407, -0.406133, 0.450767, 0.305638, 0.14942,
0.309857, 0.722745, 0.361199, -0.00963601, 0.397046,
0.264047, -0.539317, 0.0690505, -0.321901, 0.566638,
0.406511, 0.231472, 0.320225, 0.737927, 0.372938,
0.00762333, 0.349881, 0.280791, -0.541838, 0.128319,
-0.266702, 0.536205, 0.509004, 0.361068, 0.42431,
0.767474, 0.368881, 0.0753035, 0.141155, 0.219692,
-0.643801, 0.281643, -0.330984, 0.397033, 0.494424,
0.38013, 0.434627, 0.795404, 0.391589, 0.0102068,
0.166358, 0.226248, -0.608175, 0.302622, -0.349646,
0.375506, 0.546918, 0.22908, 0.40025, 0.806049,
0.424462, -0.0352604, 0.528827, -0.0372434, -0.573789,
-0.0541837, -0.194983, 0.552972, 0.553695, 0.263657,
0.432448, 0.815763, 0.412716, -0.0389366, 0.52391,
-0.0256845, -0.577296, -0.0570545, -0.219738, 0.561644};
EXPECT(migraphx::verify_range(res, res_gold));
}
}
*/
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/onnx/rnn_rnn1layer.onnx
0 → 100644
View file @
aaee3f1e
File added
test/onnx/rnn_rnnbi3layers.onnx
0 → 100644
View file @
aaee3f1e
File added
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment