Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
67491293
Commit
67491293
authored
Jan 23, 2019
by
Shucai Xiao
Browse files
add the gru operator
parent
69102b29
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
5 deletions
+49
-5
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+49
-5
No files found.
src/include/migraphx/operators.hpp
View file @
67491293
...
...
@@ -1067,8 +1067,8 @@ struct rnn
bidirectional
,
};
std
::
size_t
hidden_size
=
1
;
operation
actv_func
=
tanh
{};
std
::
size_t
hidden_size
=
1
;
operation
actv_func
{
tanh
{}
}
;
rnn_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
...
...
@@ -1076,14 +1076,14 @@ struct rnn
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
in_dims
=
inputs
[
0
].
lens
();
auto
hidden_dims
=
inputs
[
1
].
lens
();
if
(
hidden_size
!=
hidden_dims
[
1
])
auto
hidden_dims
=
inputs
[
2
].
lens
();
if
(
hidden_size
!=
hidden_dims
[
2
])
{
MIGRAPHX_THROW
(
"RNN: hidden size mismatch in attribute and input"
);
}
std
::
size_t
num_directions
=
1
;
if
(
direction
==
rnn_direction_t
::
bidirectional
)
if
(
direction
==
bidirectional
)
{
num_directions
=
2
;
}
...
...
@@ -1101,6 +1101,50 @@ struct rnn
}
};
struct
gru
{
enum
gru_direction_t
{
forward
,
reverse
,
bidirectional
,
};
std
::
size_t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{}};
gru_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
int
linear_before_reset
=
0
;
std
::
string
name
()
const
{
return
"gru"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
in_dims
=
inputs
[
0
].
lens
();
auto
hidden_dims
=
inputs
[
2
].
lens
();
if
(
hidden_size
!=
hidden_dims
[
2
])
{
MIGRAPHX_THROW
(
"GRU: hidden size mismatch in attribute and input"
);
}
std
::
size_t
num_directions
=
1
;
if
(
direction
==
bidirectional
)
{
num_directions
=
2
;
}
if
(
num_directions
!=
hidden_dims
[
0
])
{
MIGRAPHX_THROW
(
"GRU: num_direction does not match the direction attribute"
);
}
std
::
vector
<
std
::
size_t
>
out_dims
(
in_dims
);
out_dims
.
insert
(
out_dims
.
begin
()
+
1
,
num_directions
);
out_dims
.
back
()
=
hidden_size
;
return
{
inputs
[
0
].
type
(),
out_dims
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment