Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
20f89fcc
Commit
20f89fcc
authored
Jan 28, 2019
by
Shucai Xiao
Browse files
clang format
parent
0fe4c56b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
39 deletions
+34
-39
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+3
-3
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+31
-36
No files found.
src/onnx/onnx.cpp
View file @
20f89fcc
...
...
@@ -733,8 +733,8 @@ struct onnx_parser
result
.
push_back
(
hidden_states
);
// second out for the last hidden state
//auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//result.push_back(last_output);
//
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//
result.push_back(last_output);
return
result
;
}
...
...
src/rewrite_rnn.cpp
View file @
20f89fcc
...
...
@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref
ih_forward
,
ih_reverse
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
auto
arg_ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
arg_ih
);
...
...
@@ -84,7 +85,8 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
...
...
@@ -111,7 +113,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state
instruction_ref
ih
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
}
...
...
@@ -120,15 +123,8 @@ void rewrite_rnn::apply(program& prog) const
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
ret
[
1
];
// add the dimension of num_direction
...
...
@@ -140,7 +136,7 @@ void rewrite_rnn::apply(program& prog) const
// operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
//if (ins->name() == "rnn_last_output")
//
if (ins->name() == "rnn_last_output")
//{
// // if rnn operator is executed, the last_output != prog.end()
// if (last_output != prog.end())
...
...
@@ -175,13 +171,12 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
long
hs
=
r
->
get_shape
().
lens
()[
2
];
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment