Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
20f89fcc
Commit
20f89fcc
authored
Jan 28, 2019
by
Shucai Xiao
Browse files
clang format
parent
0fe4c56b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
39 deletions
+34
-39
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+3
-3
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+31
-36
No files found.
src/onnx/onnx.cpp
View file @
20f89fcc
...
...
@@ -733,8 +733,8 @@ struct onnx_parser
result
.
push_back
(
hidden_states
);
// second out for the last hidden state
//auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//result.push_back(last_output);
//
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//
result.push_back(last_output);
return
result
;
}
...
...
src/rewrite_rnn.cpp
View file @
20f89fcc
...
...
@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref
ih_forward
,
ih_reverse
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
auto
arg_ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
arg_ih
);
...
...
@@ -84,7 +85,8 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
...
...
@@ -111,7 +113,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state
instruction_ref
ih
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
}
...
...
@@ -120,15 +123,8 @@ void rewrite_rnn::apply(program& prog) const
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
ret
[
1
];
// add the dimension of num_direction
...
...
@@ -140,7 +136,7 @@ void rewrite_rnn::apply(program& prog) const
// operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
//if (ins->name() == "rnn_last_output")
//
if (ins->name() == "rnn_last_output")
//{
// // if rnn operator is executed, the last_output != prog.end()
// if (last_output != prog.end())
...
...
@@ -175,13 +171,12 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
long
hs
=
r
->
get_shape
().
lens
()[
2
];
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment