Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60b3056e
Commit
60b3056e
authored
Jan 25, 2019
by
Shucai Xiao
Browse files
merge rnn operator changes.
parents
250a0243
128b0b65
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
144 additions
and
19 deletions
+144
-19
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+33
-13
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+12
-5
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+98
-0
test/onnx/rnn_rnn1layer.onnx
test/onnx/rnn_rnn1layer.onnx
+0
-0
test/onnx/rnn_rnnbi3layers.onnx
test/onnx/rnn_rnnbi3layers.onnx
+0
-0
No files found.
src/include/migraphx/operators.hpp
View file @
60b3056e
...
@@ -1075,7 +1075,7 @@ struct rnn
...
@@ -1075,7 +1075,7 @@ struct rnn
};
};
std
::
size_t
hidden_size
=
1
;
std
::
size_t
hidden_size
=
1
;
operation
actv_func
{
tanh
{}};
std
::
vector
<
operation
>
actv_func
s
{
tanh
{}};
rnn_direction_t
direction
=
forward
;
rnn_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
...
...
src/onnx/onnx.cpp
View file @
60b3056e
...
@@ -663,17 +663,6 @@ struct onnx_parser
...
@@ -663,17 +663,6 @@ struct onnx_parser
MIGRAPHX_THROW
(
"RNN: hidden size attribute missing"
);
MIGRAPHX_THROW
(
"RNN: hidden size attribute missing"
);
}
}
std
::
string
activation_func
=
{
"tanh"
};
if
(
contains
(
attributes
,
"activations"
))
{
activation_func
=
attributes
.
at
(
"activations"
).
strings
(
0
);
}
if
(
map_actv_funcs
.
count
(
activation_func
)
==
0
)
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
activation_func
+
" not supported"
);
}
// Handling of direction to be added later
// Handling of direction to be added later
std
::
string
direction
{
"forward"
};
std
::
string
direction
{
"forward"
};
if
(
contains
(
attributes
,
"direction"
))
if
(
contains
(
attributes
,
"direction"
))
...
@@ -691,6 +680,37 @@ struct onnx_parser
...
@@ -691,6 +680,37 @@ struct onnx_parser
dirct
=
op
::
rnn
::
reverse
;
dirct
=
op
::
rnn
::
reverse
;
}
}
std
::
vector
<
std
::
string
>
vec_names
{
"tanh"
};
if
(
contains
(
attributes
,
"activations"
))
{
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
for_each
(
names
.
begin
(),
names
.
end
(),
[
&
](
auto
&
fn
)
{
vec_names
.
push_back
(
fn
);
});
}
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
fn
)
{
if
(
map_actv_funcs
.
count
(
fn
)
==
0
)
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
fn
+
" not supported"
);
}
});
// bidirectional should have two activation functions
// if only one actv function is provides, we use it in both
// forward and reverse direction
if
(
dirct
==
op
::
rnn
::
bidirectional
)
{
if
(
vec_names
.
size
()
==
1
)
{
vec_names
.
push_back
(
vec_names
.
at
(
0
));
}
}
std
::
vector
<
operation
>
vec_actv_funcs
;
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
fn
)
{
vec_actv_funcs
.
push_back
(
map_actv_funcs
[
fn
]);
});
// To be added later
// To be added later
float
clip
=
0.0
;
float
clip
=
0.0
;
if
(
contains
(
attributes
,
"clip"
))
if
(
contains
(
attributes
,
"clip"
))
...
@@ -698,8 +718,8 @@ struct onnx_parser
...
@@ -698,8 +718,8 @@ struct onnx_parser
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
}
}
return
prog
.
add_instruction
(
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
},
op
::
rnn
{
hidden_size
,
map_actv_funcs
[
activation_func
],
dirct
,
clip
},
std
::
move
(
args
));
std
::
move
(
args
));
}
}
instruction_ref
instruction_ref
...
...
src/rewrite_rnn.cpp
View file @
60b3056e
...
@@ -106,7 +106,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -106,7 +106,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_forward
,
trans_hw_forward
,
ih_forward
,
ih_forward
,
bias_forward
,
bias_forward
,
rnn_op
.
actv_func
);
rnn_op
.
actv_func
s
.
at
(
0
)
);
auto
ret_reverse
=
rnn_oper
(
false
,
auto
ret_reverse
=
rnn_oper
(
false
,
prog
,
prog
,
ins
,
ins
,
...
@@ -115,7 +115,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -115,7 +115,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_reverse
,
trans_hw_reverse
,
ih_reverse
,
ih_reverse
,
bias_reverse
,
bias_reverse
,
rnn_op
.
actv_func
);
rnn_op
.
actv_func
s
.
at
(
1
)
);
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
...
@@ -128,7 +128,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -128,7 +128,7 @@ void rewrite_rnn::apply(program& prog) const
}
}
else
else
{
{
bool
is_forward
=
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
forward
)
?
true
:
false
;
bool
is_forward
=
(
dicrt
==
op
::
rnn
::
forward
)
?
true
:
false
;
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
// process input weight matrix
// process input weight matrix
auto
sxw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
1
]);
auto
sxw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
1
]);
...
@@ -160,8 +160,15 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -160,8 +160,15 @@ void rewrite_rnn::apply(program& prog) const
{
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
}
}
auto
ret
=
rnn_oper
(
auto
ret
=
rnn_oper
(
is_forward
,
is_forward
,
prog
,
ins
,
args
[
0
],
trans_xw
,
trans_hw
,
ih
,
bias
,
rnn_op
.
actv_func
);
prog
,
ins
,
args
[
0
],
trans_xw
,
trans_hw
,
ih
,
bias
,
rnn_op
.
actv_funcs
.
at
(
0
));
// add the dimension of num_direction
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
...
...
test/cpu_ops_test.cpp
View file @
60b3056e
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
#include "test.hpp"
float
sigmoid
(
float
x
)
{
return
1
/
(
1
+
expf
(
-
x
));
}
float
sigmoid
(
float
x
)
{
return
1
/
(
1
+
expf
(
-
x
));
}
...
@@ -1346,4 +1347,101 @@ TEST_CASE(min_test)
...
@@ -1346,4 +1347,101 @@ TEST_CASE(min_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
/*
TEST_CASE(rnn_test)
{
{
migraphx::program p;
size_t hidden_size = 8;
size_t input_size = 6;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {1, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, input.data());
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
0.596363, -0.274248, 0.714484, 0.282515, 0.0938349,
0.185406, 0.283227, -0.482086, 0.265265, -0.523217,
0.50433, 0.400934, -0.34513, 0.114924, 0.0392658,
-0.0976029, 0.364322, -0.567117, 0.538775, 0.314859,
-0.478676, 0.51778, -0.286718, -0.0478341, 0.339601,
-0.380976, 0.628219, 0.222791, -0.271949, 0.490674,
-0.234456, -0.224984, 0.456527, -0.454559, 0.546034,
-0.0389027, -0.307475, 0.561003, -0.245673, -0.0776644,
0.447162, -0.52013, 0.511913, 0.0324621, -0.380515,
0.500777, -0.225695, -0.0193589, 0.458955, -0.531746,
0.448536, -0.087655, -0.430165, 0.551379, -0.161603,
-0.0165391, 0.447551, -0.491717, 0.484796, -0.0699652,
-0.3941, 0.561967, -0.168543, -0.0661258, 0.465925,
-0.499277, 0.45216, -0.103005, -0.392837, 0.584424,
-0.189044, -0.0388068, 0.468369, -0.512927, 0.449144,
-0.0900977, -0.400401, 0.573534, -0.19617, -0.0208253};
EXPECT(migraphx::verify_range(res, res_gold));
}
{
migraphx::program p;
size_t hidden_size = 6;
size_t input_size = 4;
size_t batch_size = 2;
size_t seq_len = 5;
migraphx::shape hidden_shape{migraphx::shape::float_type, {6, batch_size, hidden_size}};
migraphx::shape input_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input(input_shape.elements(), 0.0);
input[0] = input[1] = 1.0;
std::vector<float> init_hidden(hidden_shape.elements(), 0.0);
p.compile(migraphx::cpu::target{});
migraphx::program::parameter_map m;
m["input"] = migraphx::argument(input_shape, &input[0]);
auto resarg = p.eval(m);
std::vector<float> res;
resarg.visit([&](auto output) { res.assign(output.begin(), output.end()); } );
std::vector<float> res_gold{
-0.0890872, -0.0558751, 0.185233, 0.452857, 0.104082,
0.432953, 0.274236, 0.186055, -0.367716, 0.266761,
-0.28489, 0.498758, 0.0140574, -0.122377, 0.278067,
0.469699, 0.216743, 0.258926, 0.269785, 0.328379,
-0.576081, 0.11672, -0.452062, 0.603549, 0.472625,
0.120929, 0.350331, 0.502138, 0.103585, 0.128486,
0.0210318, 0.338759, -0.654448, 0.37656, -0.359715,
0.424365, 0.449677, 0.130903, 0.354359, 0.59317,
0.189543, 0.201865, 0.126288, 0.31099, -0.681538,
0.275407, -0.406133, 0.450767, 0.305638, 0.14942,
0.309857, 0.722745, 0.361199, -0.00963601, 0.397046,
0.264047, -0.539317, 0.0690505, -0.321901, 0.566638,
0.406511, 0.231472, 0.320225, 0.737927, 0.372938,
0.00762333, 0.349881, 0.280791, -0.541838, 0.128319,
-0.266702, 0.536205, 0.509004, 0.361068, 0.42431,
0.767474, 0.368881, 0.0753035, 0.141155, 0.219692,
-0.643801, 0.281643, -0.330984, 0.397033, 0.494424,
0.38013, 0.434627, 0.795404, 0.391589, 0.0102068,
0.166358, 0.226248, -0.608175, 0.302622, -0.349646,
0.375506, 0.546918, 0.22908, 0.40025, 0.806049,
0.424462, -0.0352604, 0.528827, -0.0372434, -0.573789,
-0.0541837, -0.194983, 0.552972, 0.553695, 0.263657,
0.432448, 0.815763, 0.412716, -0.0389366, 0.52391,
-0.0256845, -0.577296, -0.0570545, -0.219738, 0.561644};
EXPECT(migraphx::verify_range(res, res_gold));
}
}
*/
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/onnx/rnn_rnn1layer.onnx
0 → 100644
View file @
60b3056e
File added
test/onnx/rnn_rnnbi3layers.onnx
0 → 100644
View file @
60b3056e
File added
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment