Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13122f11
Commit
13122f11
authored
Jan 22, 2019
by
Shucai Xiao
Browse files
Code refinement and another change to support the RNN operator.
parent
0427ce2b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
8 deletions
+24
-8
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+24
-8
No files found.
src/onnx/onnx.cpp
View file @
13122f11
...
@@ -31,6 +31,7 @@ struct onnx_parser
...
@@ -31,6 +31,7 @@ struct onnx_parser
bool
is_pytorch
=
false
;
bool
is_pytorch
=
false
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
operation
>
actv_funcs
;
onnx_parser
()
onnx_parser
()
{
{
...
@@ -85,6 +86,16 @@ struct onnx_parser
...
@@ -85,6 +86,16 @@ struct onnx_parser
add_mem_op
(
"ConstantFill"
,
&
onnx_parser
::
parse_constant_fill
);
add_mem_op
(
"ConstantFill"
,
&
onnx_parser
::
parse_constant_fill
);
add_mem_op
(
"Transpose"
,
&
onnx_parser
::
parse_transpose
);
add_mem_op
(
"Transpose"
,
&
onnx_parser
::
parse_transpose
);
add_mem_op
(
"RNN"
,
&
onnx_parser
::
parse_rnn
);
add_mem_op
(
"RNN"
,
&
onnx_parser
::
parse_rnn
);
// init the activation function map
init_actv_func
();
}
void
init_actv_func
()
{
actv_funcs
.
insert
(
std
::
make_pair
(
"tanh"
,
op
::
tanh
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -656,12 +667,7 @@ struct onnx_parser
...
@@ -656,12 +667,7 @@ struct onnx_parser
activation_func
=
attributes
.
at
(
"activations"
).
strings
(
0
);
activation_func
=
attributes
.
at
(
"activations"
).
strings
(
0
);
}
}
std
::
unordered_map
<
std
::
string
,
operation
>
actv_func_map
;
if
(
actv_funcs
.
count
(
activation_func
)
==
0
)
actv_func_map
.
insert
(
std
::
make_pair
(
"tanh"
,
op
::
tanh
{}));
actv_func_map
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
actv_func_map
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
if
(
actv_func_map
.
count
(
activation_func
)
==
0
)
{
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
activation_func
+
" not supported"
);
MIGRAPHX_THROW
(
"RNN: activation function "
+
activation_func
+
" not supported"
);
}
}
...
@@ -690,8 +696,8 @@ struct onnx_parser
...
@@ -690,8 +696,8 @@ struct onnx_parser
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
}
}
return
prog
.
add_instruction
(
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
actv_funcs
[
activation_func
],
dirct
,
clip
},
op
::
rnn
{
hidden_size
,
actv_func_map
[
activation_func
],
dirct
,
clip
},
std
::
move
(
args
));
std
::
move
(
args
));
}
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
...
@@ -750,6 +756,16 @@ struct onnx_parser
...
@@ -750,6 +756,16 @@ struct onnx_parser
std
::
vector
<
instruction_ref
>
args
;
std
::
vector
<
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
for
(
auto
&&
input
:
node
.
input
())
{
{
// For RNN, LSTM, and GRU operators, one of the input arguments
// is prim::Undefined, and it is ignored by protobuf. We use a
// hack to ignore this argument for these three operators
std
::
string
op_type
=
node
.
op_type
();
if
((
op_type
==
"RNN"
||
op_type
==
"LSTM"
||
op_type
==
"GRU"
)
&&
input
.
empty
()
==
true
)
{
continue
;
}
if
(
nodes
.
count
(
input
)
>
0
)
if
(
nodes
.
count
(
input
)
>
0
)
{
{
auto
&&
iname
=
get_name
(
nodes
.
at
(
input
));
auto
&&
iname
=
get_name
(
nodes
.
at
(
input
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment