Commit fb0fab94
authored Feb 04, 2019 by Shucai Xiao

Merge branch 'rnn_operator' into reshape_tests

Parents: badf9b9c, 2dfb4b0f

Showing 4 changed files with 66 additions and 37 deletions (+66, -37):
    src/include/migraphx/operators.hpp   (+2,  -2)
    src/include/migraphx/rewrite_rnn.hpp (+2,  -0)
    src/rewrite_rnn.cpp                  (+43, -6)
    test/cpu_ops_test.cpp                (+19, -29)
src/include/migraphx/operators.hpp

@@ -1140,7 +1140,7 @@ struct rnn
     };

     std::size_t hidden_size = 1;
-    std::vector<operation> actv_funcs{tanh{}};
+    std::vector<operation> actv_funcs{tanh{}, tanh{}};
     rnn_direction_t direction = forward;
     float clip = 0.0f;
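A plausible motivation for the two-element default (my inference; the commit message doesn't state it): with the old single-tanh default, a bidirectional rnn whose model left the activations unspecified would reach actv_funcs.at(1) in the rewrite pass below and throw. A minimal illustration, with hypothetical standalone types rather than the MIGraphX headers:

// Illustration only: reading index 1 of a one-element activation list
// throws, which is what the old one-tanh default risked on the
// bidirectional path.
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> actv_funcs{"tanh"}; // old default: one entry
    try
    {
        auto f = actv_funcs.at(1); // the bidirectional path needs a second one
        std::cout << f << "\n";
    }
    catch(const std::out_of_range&)
    {
        std::cout << "out_of_range: only one activation was provided\n";
    }
}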
@@ -1190,7 +1190,7 @@ struct rnn_last_output
 struct undefined
 {
     std::string name() const { return "undefined"; }
-    shape compute_shape(std::vector<shape> inputs) const
+    shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(0);
         return {};
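The compute_shape change is a routine pass-by-const-reference fix: the inputs vector is only read, so copying it on every call is wasted work. A generic sketch of the difference (not MIGraphX code):

// Taking the vector by value copies it per call; taking it by const
// reference reads the caller's vector in place.
#include <cstddef>
#include <vector>

struct shape {};

std::size_t count_by_value(std::vector<shape> inputs) { return inputs.size(); }       // copies
std::size_t count_by_cref(const std::vector<shape>& inputs) { return inputs.size(); } // no copy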
src/include/migraphx/rewrite_rnn.hpp

@@ -30,6 +30,8 @@ struct rewrite_rnn
                                           instruction_ref bias,
                                           instruction_ref ih,
                                           operation& actv_func) const;
+
+    std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
 };

 } // namespace MIGRAPHX_INLINE_NS
src/rewrite_rnn.cpp

@@ -29,9 +29,10 @@ void rewrite_rnn::apply(program& prog) const
         migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
         std::vector<float> data(ih_shape.elements(), 0);

+        auto actv_funcs = compute_actv_funcs(ins);
         auto rnn_op = any_cast<op::rnn>(ins->get_operator());
         op::rnn::rnn_direction_t dicrt = rnn_op.direction;
-        if(dicrt == op::rnn::rnn_direction_t::bidirectional)
+        if(dicrt == op::rnn::bidirectional)
         {
             // input weight matrix
             auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
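For readers new to the pass: op::slice{{0}, {0}, {1}} takes axis 0 of the weight tensor over [0, 1). Assuming the ONNX layout, where W is shaped {num_directions, hidden_size, input_size}, that keeps the forward direction's block. A flat-buffer mental model (my simplification with a hypothetical helper, not MIGraphX code):

// Slicing axis 0 over [start, end) of a {num_directions, H, W} tensor
// stored row-major just selects whole per-direction blocks.
#include <cstddef>
#include <vector>

std::vector<float> slice_axis0(const std::vector<float>& w,
                               std::size_t block, // elements per direction
                               std::size_t start, std::size_t end)
{
    return {w.begin() + static_cast<std::ptrdiff_t>(start * block),
            w.begin() + static_cast<std::ptrdiff_t>(end * block)};
}

int main()
{
    std::vector<float> w(2 * 3 * 4, 1.0f);        // {num_directions=2, 3, 4}
    auto w_forward = slice_axis0(w, 3 * 4, 0, 1); // first direction's 12 floats
    return w_forward.size() == 12 ? 0 : 1;
}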
@@ -72,7 +73,7 @@ void rewrite_rnn::apply(program& prog) const
                                         r_forward,
                                         bias_forward,
                                         ih_forward,
-                                        rnn_op.actv_funcs.at(0));
+                                        actv_funcs.at(0));

             auto ret_reverse = rnn_cell(false,
                                         prog,
                                         ins,
@@ -81,7 +82,7 @@ void rewrite_rnn::apply(program& prog) const
                                         r_reverse,
                                         bias_reverse,
                                         ih_reverse,
-                                        rnn_op.actv_funcs.at(1));
+                                        actv_funcs.at(1));

             auto concat_output =
                 prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
@@ -109,7 +110,7 @@ void rewrite_rnn::apply(program& prog) const
         }
         else
         {
-            bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward);
+            bool is_forward = (dicrt == op::rnn::forward);

             // input weight matrix
             auto w = args[1];
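Worth noting: the direction comparisons here and in the bidirectional branch above drop the redundant rnn_direction_t:: qualifier. Since rnn_direction_t appears to be a plain (unscoped) enum nested in op::rnn (the struct's own default, direction = forward, already uses the bare enumerator), op::rnn::forward and op::rnn::rnn_direction_t::forward name the same constant; the shorter spelling is purely cosmetic.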
@@ -134,8 +135,8 @@ void rewrite_rnn::apply(program& prog) const
                 ih = prog.add_literal(migraphx::literal{ih_shape, data});
             }

-            auto ret = rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
+            auto ret = rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
             auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

             // following logic is to ensure the last instruction is a
@@ -263,5 +264,41 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
     return {hidden_out, last_out};
 }

+std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
+{
+    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
+    // before rewriting the rnn operator, we need to ensure there are
+    // 2 actv funcs. If there are fewer than 2, use the algorithm in
+    // parse_rnn to make 2 actv functions
+    if(rnn_op.direction == op::rnn::bidirectional)
+    {
+        if(rnn_op.actv_funcs.empty())
+        {
+            // default is tanh
+            return {op::tanh{}, op::tanh{}};
+        }
+        else if(rnn_op.actv_funcs.size() == 1)
+        {
+            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
+        }
+        else
+        {
+            return rnn_op.actv_funcs;
+        }
+    }
+    else
+    {
+        if(rnn_op.actv_funcs.empty())
+        {
+            // default is tanh
+            return {op::tanh{}};
+        }
+        else
+        {
+            return rnn_op.actv_funcs;
+        }
+    }
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
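The new compute_actv_funcs centralizes the defaulting that both branches of apply() now rely on: bidirectional always yields two activation functions, single-direction yields at least one, and tanh is the fallback. A standalone model of the rule (my paraphrase, with std::string standing in for migraphx operations and a hypothetical helper name):

// Always hand the rewrite a fully populated activation list.
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

std::vector<std::string> normalize_actv_funcs(bool bidirectional,
                                              std::vector<std::string> actv)
{
    const std::size_t wanted = bidirectional ? 2 : 1;
    if(actv.empty())
        return std::vector<std::string>(wanted, "tanh"); // default is tanh
    if(actv.size() < wanted)
    {
        auto first = actv.front();
        actv.resize(wanted, first); // reuse the single given function
    }
    return actv;
}

int main()
{
    assert(normalize_actv_funcs(true, {}).size() == 2);            // two tanh
    assert(normalize_actv_funcs(true, {"sigmoid"}).at(1) == "sigmoid");
    assert(normalize_actv_funcs(false, {}).at(0) == "tanh");
}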
test/cpu_ops_test.cpp

@@ -1458,10 +1458,7 @@ TEST_CASE(rnn_forward)
     auto und = p.add_instruction(migraphx::op::undefined{});
-    auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size,
-                                                      {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                      migraphx::op::rnn::forward,
-                                                      clip},
+    auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
                                     seq,
                                     w,
                                     r,
@@ -1598,10 +1595,7 @@ TEST_CASE(rnn_reverse)
     auto und = p.add_instruction(migraphx::op::undefined{});
-    auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size,
-                                                      {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                      migraphx::op::rnn::reverse,
-                                                      clip},
+    auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::reverse, clip},
                                     seq,
                                     w,
                                     r,
@@ -1723,16 +1717,14 @@ TEST_CASE(rnn_bidirectional)
     auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
     auto und  = p.add_instruction(migraphx::op::undefined{});
-    p.add_instruction(migraphx::op::rnn{hidden_size,
-                                        {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                        migraphx::op::rnn::bidirectional,
-                                        clip},
+    p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::bidirectional, clip},
                       seq,
                       w,
                       r,
                       bias,
                       und,
                       ih);
     p.compile(migraphx::cpu::target{});
     auto hs_concat = p.eval({});
     std::vector<float> hs_data;
@@ -1774,17 +1766,15 @@ TEST_CASE(rnn_bidirectional)
     auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
     auto und  = p.add_instruction(migraphx::op::undefined{});
-    auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size,
-                                                      {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                      migraphx::op::rnn::bidirectional,
-                                                      clip},
+    auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size,
+                                                      {migraphx::op::tanh{}},
+                                                      migraphx::op::rnn::bidirectional,
+                                                      clip},
                                     seq,
                                     w,
                                     r,
                                     bias,
                                     und,
                                     ih);
     p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
     p.compile(migraphx::cpu::target{});
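Taken together, the test edits exercise the defaulting branches of compute_actv_funcs: rnn_forward and rnn_reverse now pass an empty activation list (single-direction default of one tanh), the first rnn_bidirectional program passes an empty list (bidirectional default of two tanh), and the second passes a single tanh that the rewrite pass must duplicate across both directions.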