Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2d7f3523
Commit
2d7f3523
authored
Jan 28, 2019
by
Shucai Xiao
Browse files
rewrite the gru operator to support two outputs.
parent
1fbe8c48
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
235 additions
and
201 deletions
+235
-201
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+14
-0
src/include/migraphx/rewrite_gru.hpp
src/include/migraphx/rewrite_gru.hpp
+3
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+12
-3
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+204
-193
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+2
-2
No files found.
src/include/migraphx/operators.hpp
View file @
2d7f3523
...
@@ -1167,6 +1167,20 @@ struct rnn_last_output
...
@@ -1167,6 +1167,20 @@ struct rnn_last_output
}
}
};
};
struct
gru_last_output
{
std
::
string
name
()
const
{
return
"gru_last_output"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
dims
=
inputs
[
0
].
lens
();
// remove the first dimension, remaing are output shape
dims
.
erase
(
dims
.
begin
());
return
{
inputs
[
0
].
type
(),
dims
};
}
};
}
// namespace op
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/rewrite_gru.hpp
View file @
2d7f3523
...
@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program
;
struct
program
;
/**
/**
* Rewrite
rnn
to gemm and add.
* Rewrite
gru
to gemm
, mul,
and add.
*/
*/
struct
rewrite_gru
struct
rewrite_gru
{
{
...
@@ -21,14 +21,14 @@ struct rewrite_gru
...
@@ -21,14 +21,14 @@ struct rewrite_gru
void
apply
(
program
&
prog
)
const
;
void
apply
(
program
&
prog
)
const
;
private:
private:
std
::
vector
<
instruction_ref
>
gru_
oper
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
gru_
cell
(
bool
is_forward
,
program
&
prog
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
input
,
instruction_ref
wx
,
instruction_ref
wx
,
instruction_ref
wh
,
instruction_ref
wh
,
instruction_ref
ih
,
instruction_ref
bias
,
instruction_ref
bias
,
instruction_ref
ih
,
int
linear_before_reset
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
;
operation
&
actv_func2
)
const
;
...
...
src/onnx/onnx.cpp
View file @
2d7f3523
...
@@ -732,14 +732,14 @@ struct onnx_parser
...
@@ -732,14 +732,14 @@ struct onnx_parser
std
::
move
(
args
));
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
result
.
push_back
(
hidden_states
);
// second out for the last hidden state
// second out
put
for the last hidden state
auto
last_output
=
prog
.
add_instruction
(
op
::
rnn_last_output
{},
hidden_states
);
auto
last_output
=
prog
.
add_instruction
(
op
::
rnn_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
result
.
push_back
(
last_output
);
return
result
;
return
result
;
}
}
instruction_ref
std
::
vector
<
instruction_ref
>
parse_gru
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_gru
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
...
@@ -842,9 +842,18 @@ struct onnx_parser
...
@@ -842,9 +842,18 @@ struct onnx_parser
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
}
}
return
prog
.
add_instruction
(
std
::
vector
<
instruction_ref
>
result
;
// first output for concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
gru
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
linear_before_reset
},
op
::
gru
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
linear_before_reset
},
std
::
move
(
args
));
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
// second output for last gru output
auto
last_output
=
prog
.
add_instruction
(
op
::
gru_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
return
result
;
}
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
...
...
src/rewrite_gru.cpp
View file @
2d7f3523
This diff is collapsed.
Click to expand it.
src/rewrite_rnn.cpp
View file @
2d7f3523
...
@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
batch_size
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
...
@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const
}
}
// rewrite the rnn_last_output operator that right after the rnn
// rewrite the rnn_last_output operator that right after the rnn
// operator. Intuitively, we can do a slice on
the
input to get
// operator. Intuitively, we can do a slice on
its
input to get
// the last output, but it is already existed in the rnn operator,
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
// so we can just use it as the output here
if
(
ins
->
name
()
==
"rnn_last_output"
)
if
(
ins
->
name
()
==
"rnn_last_output"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment