Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7113cdfa
Commit
7113cdfa
authored
Jan 28, 2019
by
Shucai Xiao
Browse files
fix a bug in rnn operator pass.
parent
20f89fcc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
12 deletions
+12
-12
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+2
-2
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+10
-10
No files found.
src/onnx/onnx.cpp
View file @
7113cdfa
...
...
@@ -733,8 +733,8 @@ struct onnx_parser
result
.
push_back
(
hidden_states
);
// second out for the last hidden state
//
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//
result.push_back(last_output);
auto
last_output
=
prog
.
add_instruction
(
op
::
rnn_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
return
result
;
}
...
...
src/rewrite_rnn.cpp
View file @
7113cdfa
...
...
@@ -136,15 +136,15 @@ void rewrite_rnn::apply(program& prog) const
// operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
//
if (ins->name() == "rnn_last_output")
//
{
//
// if rnn operator is executed, the last_output != prog.end()
//
if (last_output != prog.end())
//
{
//
prog.replace_instruction(ins, op::identity{}, last_output);
//
last_output = prog.end();
//
}
//
}
if
(
ins
->
name
()
==
"rnn_last_output"
)
{
// if rnn operator is executed, the last_output != prog.end()
if
(
last_output
!=
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
identity
{},
last_output
);
last_output
=
prog
.
end
();
}
}
}
}
...
...
@@ -161,7 +161,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
// squeeze and transpose w
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
tran_sw
=
prog
.
insert_instruction
(
s
w
,
op
::
transpose
{
perm
},
sw
);
auto
tran_sw
=
prog
.
insert_instruction
(
in
s
,
op
::
transpose
{
perm
},
sw
);
// squeeze and transpose r
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment