Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9025504f
Commit
9025504f
authored
Feb 07, 2019
by
Shucai Xiao
Browse files
rename rnn to vanilla_rnn
parent
657c6996
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
12 deletions
+12
-12
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+4
-4
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+8
-8
No files found.
src/include/migraphx/rewrite_rnn.hpp
View file @
9025504f
...
...
@@ -21,9 +21,9 @@ struct rewrite_rnn
void
apply
(
program
&
prog
)
const
;
private:
// for vall
in
a rnn operators
void
apply_vall
in
a_rnn
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
rnn_cell
(
bool
is_forward
,
// for va
ni
lla rnn operators
void
apply_va
ni
lla_rnn
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_
rnn_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
...
...
@@ -32,7 +32,7 @@ struct rewrite_rnn
instruction_ref
bias
,
instruction_ref
ih
,
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
rnn_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
vanilla_
rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
void
apply_gru
(
program
&
prog
,
instruction_ref
ins
)
const
;
...
...
src/rewrite_rnn.cpp
View file @
9025504f
...
...
@@ -14,7 +14,7 @@ void rewrite_rnn::apply(program& prog) const
{
if
(
ins
->
name
()
==
"rnn"
)
{
apply_vall
in
a_rnn
(
prog
,
ins
);
apply_va
ni
lla_rnn
(
prog
,
ins
);
}
if
(
ins
->
name
()
==
"gru"
)
...
...
@@ -24,7 +24,7 @@ void rewrite_rnn::apply(program& prog) const
}
}
void
rewrite_rnn
::
apply_vall
in
a_rnn
(
program
&
prog
,
instruction_ref
ins
)
const
void
rewrite_rnn
::
apply_va
ni
lla_rnn
(
program
&
prog
,
instruction_ref
ins
)
const
{
assert
(
ins
->
name
()
==
"rnn"
);
// could be 3 to 6 inputs, but the parse_rnn function will
...
...
@@ -40,7 +40,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
data
(
ih_shape
.
elements
(),
0
);
auto
actv_funcs
=
rnn_actv_funcs
(
ins
);
auto
actv_funcs
=
vanilla_
rnn_actv_funcs
(
ins
);
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
op
::
rnn
::
rnn_direction_t
dicrt
=
rnn_op
.
direction
;
instruction_ref
last_output
{};
...
...
@@ -78,7 +78,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret_forward
=
rnn_cell
(
true
,
auto
ret_forward
=
vanilla_
rnn_cell
(
true
,
prog
,
ins
,
args
[
0
],
...
...
@@ -87,7 +87,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
bias_forward
,
ih_forward
,
actv_funcs
.
at
(
0
));
auto
ret_reverse
=
rnn_cell
(
false
,
auto
ret_reverse
=
vanilla_
rnn_cell
(
false
,
prog
,
ins
,
args
[
0
],
...
...
@@ -147,7 +147,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
actv_funcs
.
at
(
0
));
auto
ret
=
vanilla_
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
actv_funcs
.
at
(
0
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
// following logic is to ensure the last instruction is a
...
...
@@ -183,7 +183,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
}
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
rnn_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
vanilla_
rnn_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
...
...
@@ -271,7 +271,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
return
{
hidden_out
,
last_out
};
}
std
::
vector
<
operation
>
rewrite_rnn
::
rnn_actv_funcs
(
instruction_ref
ins
)
const
std
::
vector
<
operation
>
rewrite_rnn
::
vanilla_
rnn_actv_funcs
(
instruction_ref
ins
)
const
{
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
// before rewrite the rnn operator, need to ensure
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment