Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
22f8a479
"vscode:/vscode.git/clone" did not exist on "948333240864dbf2f55bfb8c5d213d8d1cc55465"
Commit
22f8a479
authored
Feb 04, 2019
by
Shucai Xiao
Browse files
handling the cases that not enough actv functions are provided.
parent
0cc5b80e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
10 deletions
+50
-10
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+2
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+43
-5
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+4
-4
No files found.
src/include/migraphx/operators.hpp
View file @
22f8a479
...
...
@@ -1140,7 +1140,7 @@ struct rnn
};
std
::
size_t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
tanh
{}};
std
::
vector
<
operation
>
actv_funcs
{
tanh
{},
tanh
{}};
rnn_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
22f8a479
...
...
@@ -30,6 +30,8 @@ struct rewrite_rnn
instruction_ref
bias
,
instruction_ref
ih
,
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
compute_actv_funcs
(
instruction_ref
ins
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/rewrite_rnn.cpp
View file @
22f8a479
...
...
@@ -29,9 +29,10 @@ void rewrite_rnn::apply(program& prog) const
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
data
(
ih_shape
.
elements
(),
0
);
auto
actv_funcs
=
compute_actv_funcs
(
ins
);
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
op
::
rnn
::
rnn_direction_t
dicrt
=
rnn_op
.
direction
;
if
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
bidirectional
)
if
(
dicrt
==
op
::
rnn
::
bidirectional
)
{
// input weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
...
...
@@ -72,7 +73,7 @@ void rewrite_rnn::apply(program& prog) const
r_forward
,
bias_forward
,
ih_forward
,
rnn_op
.
actv_funcs
.
at
(
0
));
actv_funcs
.
at
(
0
));
auto
ret_reverse
=
rnn_cell
(
false
,
prog
,
ins
,
...
...
@@ -81,7 +82,7 @@ void rewrite_rnn::apply(program& prog) const
r_reverse
,
bias_reverse
,
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
actv_funcs
.
at
(
1
));
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
...
...
@@ -109,7 +110,7 @@ void rewrite_rnn::apply(program& prog) const
}
else
{
bool
is_forward
=
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
forward
);
bool
is_forward
=
(
dicrt
==
op
::
rnn
::
forward
);
// input weight matrix
auto
w
=
args
[
1
];
...
...
@@ -135,7 +136,7 @@ void rewrite_rnn::apply(program& prog) const
}
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
actv_funcs
.
at
(
0
));
auto
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
// following logic is to ensure the last instruction is a
...
...
@@ -263,5 +264,42 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
return
{
hidden_out
,
last_out
};
}
std
::
vector
<
operation
>
rewrite_rnn
::
compute_actv_funcs
(
instruction_ref
ins
)
const
{
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
// before rewrite the rnn operator, need to ensure
// we have 2 actv funcs. If less than 2, use the
// algorithm in parse_rnn to make 2 actv functions
if
(
rnn_op
.
direction
==
op
::
rnn
::
bidirectional
)
{
if
(
rnn_op
.
actv_funcs
.
size
()
==
0
)
{
// default is tanh
return
{
op
::
tanh
{},
op
::
tanh
{}};
}
else
if
(
rnn_op
.
actv_funcs
.
size
()
==
1
)
{
return
{
rnn_op
.
actv_funcs
.
at
(
0
),
rnn_op
.
actv_funcs
.
at
(
0
)};
}
else
{
return
rnn_op
.
actv_funcs
;
}
}
else
{
if
(
rnn_op
.
actv_funcs
.
size
()
==
0
)
{
// default is tanh
return
{
op
::
tanh
{}};
}
else
{
return
rnn_op
.
actv_funcs
;
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
test/cpu_ops_test.cpp
View file @
22f8a479
...
...
@@ -1459,7 +1459,7 @@ TEST_CASE(rnn_forward)
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hidden_size
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}
},
{},
migraphx
::
op
::
rnn
::
forward
,
clip
},
seq
,
...
...
@@ -1599,7 +1599,7 @@ TEST_CASE(rnn_reverse)
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hidden_size
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}
},
{},
migraphx
::
op
::
rnn
::
reverse
,
clip
},
seq
,
...
...
@@ -1724,7 +1724,7 @@ TEST_CASE(rnn_bidirectional)
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hidden_size
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}
},
{},
migraphx
::
op
::
rnn
::
bidirectional
,
clip
},
seq
,
...
...
@@ -1776,7 +1776,7 @@ TEST_CASE(rnn_bidirectional)
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hidden_size
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
{
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn
::
bidirectional
,
clip
},
seq
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment