Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
56daf147
Commit
56daf147
authored
Jan 24, 2019
by
Shucai Xiao
Browse files
fix a bug related to activation functions
parent
9ac6a4a8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
16 deletions
+53
-16
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+53
-16
No files found.
src/onnx/onnx.cpp
View file @
56daf147
...
@@ -31,7 +31,7 @@ struct onnx_parser
...
@@ -31,7 +31,7 @@ struct onnx_parser
bool
is_pytorch
=
false
;
bool
is_pytorch
=
false
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
operation
>
actv_funcs
;
std
::
unordered_map
<
std
::
string
,
operation
>
map_
actv_funcs
;
onnx_parser
()
onnx_parser
()
{
{
...
@@ -94,11 +94,11 @@ struct onnx_parser
...
@@ -94,11 +94,11 @@ struct onnx_parser
void
init_actv_func
()
void
init_actv_func
()
{
{
actv_funcs
.
insert
(
std
::
make_pair
(
"tanh"
,
op
::
tanh
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"tanh"
,
op
::
tanh
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"leakyrelu"
,
op
::
leaky_relu
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"leakyrelu"
,
op
::
leaky_relu
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"elu"
,
op
::
elu
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"elu"
,
op
::
elu
{}));
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -669,7 +669,7 @@ struct onnx_parser
...
@@ -669,7 +669,7 @@ struct onnx_parser
activation_func
=
attributes
.
at
(
"activations"
).
strings
(
0
);
activation_func
=
attributes
.
at
(
"activations"
).
strings
(
0
);
}
}
if
(
actv_funcs
.
count
(
activation_func
)
==
0
)
if
(
map_
actv_funcs
.
count
(
activation_func
)
==
0
)
{
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
activation_func
+
" not supported"
);
MIGRAPHX_THROW
(
"RNN: activation function "
+
activation_func
+
" not supported"
);
}
}
...
@@ -698,7 +698,7 @@ struct onnx_parser
...
@@ -698,7 +698,7 @@ struct onnx_parser
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
}
}
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
actv_funcs
[
activation_func
],
dirct
,
clip
},
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
map_
actv_funcs
[
activation_func
],
dirct
,
clip
},
std
::
move
(
args
));
std
::
move
(
args
));
}
}
...
@@ -734,25 +734,62 @@ struct onnx_parser
...
@@ -734,25 +734,62 @@ struct onnx_parser
dirct
=
op
::
gru
::
reverse
;
dirct
=
op
::
gru
::
reverse
;
}
}
std
::
vector
<
std
::
string
>
act_funcs
=
{
"sigmoid"
,
"tanh"
};
std
::
vector
<
std
::
string
>
act
v
_func
_name
s
=
{
"sigmoid"
,
"tanh"
};
if
(
contains
(
attributes
,
"activations"
))
if
(
contains
(
attributes
,
"activations"
))
{
{
act_funcs
[
0
]
=
attributes
.
at
(
"activations"
).
strings
(
0
);
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
act_funcs
[
1
]
=
attributes
.
at
(
"activations"
).
strings
(
1
);
actv_func_names
.
clear
();
for
(
auto
&
fn
:
names
)
{
actv_func_names
.
push_back
(
fn
);
}
}
}
if
(
act_funcs
.
size
()
!=
2
)
if
(
act
v
_func
_name
s
.
size
()
!=
2
)
{
{
MIGRAPHX_THROW
(
"GRU: wrong activation function attribute"
);
MIGRAPHX_THROW
(
"GRU: wrong activation function attribute"
);
}
}
for
(
std
::
size_t
i
=
0
;
i
<
act_funcs
.
size
();
++
i
)
// need 4 activation functions
if
(
dirct
==
op
::
gru
::
bidirectional
)
{
// one name is provided, need to repeat the function 3 times
if
(
actv_func_names
.
size
()
==
1
)
{
actv_func_names
.
resize
(
4
,
actv_func_names
.
at
(
0
));
}
else
if
(
actv_func_names
.
size
()
==
2
)
{
{
if
(
actv_funcs
.
count
(
act_funcs
.
at
(
i
))
==
0
)
actv_func_names
.
insert
(
actv_func_names
.
end
(),
actv_func_names
.
begin
(),
actv_func_names
.
end
());
}
else
if
(
actv_func_names
.
size
()
==
3
)
{
{
MIGRAPHX_THROW
(
"GRU: activation function "
+
act_funcs
.
at
(
i
)
+
" not supported"
);
MIGRAPHX_THROW
(
"GRU: birectional network cannot have 3 activation functions in attribute"
);
}
}
}
else
{
if
(
actv_func_names
.
size
()
==
1
)
{
actv_func_names
.
push_back
(
actv_func_names
.
at
(
0
));
}
}
}
for_each
(
actv_func_names
.
begin
(),
actv_func_names
.
end
(),
[
&
](
auto
&
name
)
{
if
(
map_actv_funcs
.
count
(
name
)
==
0
)
{
MIGRAPHX_THROW
(
"GRU: activation function "
+
name
+
" not supported"
);
}
});
std
::
vector
<
operation
>
vec_actv_funcs
;
for_each
(
actv_func_names
.
begin
(),
actv_func_names
.
end
(),
[
&
](
auto
&
name
)
{
vec_actv_funcs
.
push_back
(
map_actv_funcs
[
name
]);
});
// To be added later
// To be added later
float
clip
=
0.0
;
float
clip
=
0.0
;
...
@@ -769,7 +806,7 @@ struct onnx_parser
...
@@ -769,7 +806,7 @@ struct onnx_parser
return
prog
.
add_instruction
(
return
prog
.
add_instruction
(
op
::
gru
{
hidden_size
,
op
::
gru
{
hidden_size
,
{
actv_funcs
[
act_funcs
.
at
(
0
)],
actv_funcs
[
act_funcs
.
at
(
1
)]}
,
vec_actv_funcs
,
dirct
,
dirct
,
clip
,
clip
,
linear_before_reset
},
linear_before_reset
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment