Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f2436c3d
Commit
f2436c3d
authored
Jan 25, 2019
by
Shucai Xiao
Browse files
refine the processing of activation function attribute.
parent
e7e82505
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
22 deletions
+42
-22
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+38
-18
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+3
-3
No files found.
src/include/migraphx/operators.hpp
View file @
f2436c3d
...
@@ -1068,7 +1068,7 @@ struct rnn
...
@@ -1068,7 +1068,7 @@ struct rnn
};
};
std
::
size_t
hidden_size
=
1
;
std
::
size_t
hidden_size
=
1
;
operation
actv_func
{
tanh
{}};
std
::
vector
<
operation
>
actv_func
s
{
tanh
{}};
rnn_direction_t
direction
=
forward
;
rnn_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
...
...
src/onnx/onnx.cpp
View file @
f2436c3d
...
@@ -31,7 +31,7 @@ struct onnx_parser
...
@@ -31,7 +31,7 @@ struct onnx_parser
bool
is_pytorch
=
false
;
bool
is_pytorch
=
false
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
operation
>
actv_funcs
;
std
::
unordered_map
<
std
::
string
,
operation
>
map_
actv_funcs
;
onnx_parser
()
onnx_parser
()
{
{
...
@@ -93,11 +93,11 @@ struct onnx_parser
...
@@ -93,11 +93,11 @@ struct onnx_parser
void
init_actv_func
()
void
init_actv_func
()
{
{
actv_funcs
.
insert
(
std
::
make_pair
(
"tanh"
,
op
::
tanh
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"tanh"
,
op
::
tanh
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"leakyrelu"
,
op
::
leaky_relu
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"leakyrelu"
,
op
::
leaky_relu
{}));
actv_funcs
.
insert
(
std
::
make_pair
(
"elu"
,
op
::
elu
{}));
map_
actv_funcs
.
insert
(
std
::
make_pair
(
"elu"
,
op
::
elu
{}));
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -663,17 +663,6 @@ struct onnx_parser
...
@@ -663,17 +663,6 @@ struct onnx_parser
MIGRAPHX_THROW
(
"RNN: hidden size attribute missing"
);
MIGRAPHX_THROW
(
"RNN: hidden size attribute missing"
);
}
}
std
::
string
activation_func
=
{
"tanh"
};
if
(
contains
(
attributes
,
"activations"
))
{
activation_func
=
attributes
.
at
(
"activations"
).
strings
(
0
);
}
if
(
actv_funcs
.
count
(
activation_func
)
==
0
)
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
activation_func
+
" not supported"
);
}
// Handling of direction to be added later
// Handling of direction to be added later
std
::
string
direction
{
"forward"
};
std
::
string
direction
{
"forward"
};
if
(
contains
(
attributes
,
"direction"
))
if
(
contains
(
attributes
,
"direction"
))
...
@@ -691,6 +680,37 @@ struct onnx_parser
...
@@ -691,6 +680,37 @@ struct onnx_parser
dirct
=
op
::
rnn
::
reverse
;
dirct
=
op
::
rnn
::
reverse
;
}
}
std
::
vector
<
std
::
string
>
vec_names
{
"tanh"
};
if
(
contains
(
attributes
,
"activations"
))
{
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
for_each
(
names
.
begin
(),
names
.
end
(),
[
&
](
auto
&
fn
)
{
vec_names
.
push_back
(
fn
);
}
);
}
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
]
(
auto
&
fn
)
{
if
(
map_actv_funcs
.
count
(
fn
)
==
0
)
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
fn
+
" not supported"
);
}
});
// bidirectional should have two activation functions
// if only one actv function is provides, we use it in both
// forward and reverse direction
if
(
dirct
==
op
::
rnn
::
bidirectional
)
{
if
(
vec_names
.
size
()
==
1
)
{
vec_names
.
push_back
(
vec_names
.
at
(
0
));
}
}
std
::
vector
<
operation
>
vec_actv_funcs
;
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
]
(
auto
&
fn
)
{
vec_actv_funcs
.
push_back
(
map_actv_funcs
[
fn
]);
});
// To be added later
// To be added later
float
clip
=
0.0
;
float
clip
=
0.0
;
if
(
contains
(
attributes
,
"clip"
))
if
(
contains
(
attributes
,
"clip"
))
...
@@ -698,7 +718,7 @@ struct onnx_parser
...
@@ -698,7 +718,7 @@ struct onnx_parser
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
}
}
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
actv
_funcs
[
activation
_func
]
,
dirct
,
clip
},
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
vec_
actv_func
s
,
dirct
,
clip
},
std
::
move
(
args
));
std
::
move
(
args
));
}
}
...
...
src/rewrite_rnn.cpp
View file @
f2436c3d
...
@@ -106,7 +106,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -106,7 +106,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_forward
,
trans_hw_forward
,
ih_forward
,
ih_forward
,
bias_forward
,
bias_forward
,
rnn_op
.
actv_func
);
rnn_op
.
actv_func
s
.
at
(
0
)
);
auto
ret_reverse
=
rnn_oper
(
false
,
auto
ret_reverse
=
rnn_oper
(
false
,
prog
,
prog
,
ins
,
ins
,
...
@@ -115,7 +115,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -115,7 +115,7 @@ void rewrite_rnn::apply(program& prog) const
trans_hw_reverse
,
trans_hw_reverse
,
ih_reverse
,
ih_reverse
,
bias_reverse
,
bias_reverse
,
rnn_op
.
actv_func
);
rnn_op
.
actv_func
s
.
at
(
1
)
);
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
...
@@ -161,7 +161,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -161,7 +161,7 @@ void rewrite_rnn::apply(program& prog) const
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
}
}
auto
ret
=
rnn_oper
(
auto
ret
=
rnn_oper
(
is_forward
,
prog
,
ins
,
args
[
0
],
trans_xw
,
trans_hw
,
ih
,
bias
,
rnn_op
.
actv_func
);
is_forward
,
prog
,
ins
,
args
[
0
],
trans_xw
,
trans_hw
,
ih
,
bias
,
rnn_op
.
actv_func
s
.
at
(
0
)
);
// add the dimension of num_direction
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment