Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
71273501
Commit
71273501
authored
Jan 30, 2019
by
Shucai Xiao
Browse files
clang format
parent
28f8b1e8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
17 deletions
+22
-17
src/program.cpp
src/program.cpp
+2
-2
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+20
-15
No files found.
src/program.cpp
View file @
71273501
...
...
@@ -136,11 +136,11 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert
(
has_instruction
(
rep
));
assert
(
ins
!=
rep
);
if
(
ins
==
std
::
prev
(
this
->
end
()))
if
(
ins
==
std
::
prev
(
this
->
end
()))
{
// additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output
if
(
ins
->
name
()
==
"rnn_last_output"
)
if
(
ins
->
name
()
==
"rnn_last_output"
)
{
return
replace_instruction
(
ins
,
op
::
identity
{},
rep
);
}
...
...
src/rewrite_rnn.cpp
View file @
71273501
...
...
@@ -85,20 +85,23 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
concat_output
);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if
(
ret_forward
[
0
]
==
prog
.
end
())
if
(
ret_forward
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
...
...
@@ -134,10 +137,10 @@ void rewrite_rnn::apply(program& prog) const
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
// following logic is to ensure the last instruction is a
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if
(
ret
[
0
]
==
prog
.
end
())
if
(
ret
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
...
...
@@ -204,7 +207,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
}
instruction_ref
hidden_out
=
prog
.
end
(),
last_out
;
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
{
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
...
...
@@ -234,19 +237,21 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
// output inserted
if
(
i
<
seq_len
-
1
)
if
(
i
<
seq_len
-
1
)
{
if
(
is_forward
)
{
hidden_out
=
(
seq_index
==
0
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last_out
);
hidden_out
=
(
seq_index
==
0
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last_out
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment