gaoqiong / MIGraphX / Commits / 22f8a479

Commit 22f8a479, authored Feb 04, 2019 by Shucai Xiao
Parent: 0cc5b80e

Handle the cases where not enough actv functions are provided.
Showing 4 changed files with 50 additions and 10 deletions:
- src/include/migraphx/operators.hpp (+1, -1)
- src/include/migraphx/rewrite_rnn.hpp (+2, -0)
- src/rewrite_rnn.cpp (+43, -5)
- test/cpu_ops_test.cpp (+4, -4)
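In short: the default activation list of op::rnn grows from one tanh to two, the rewrite pass gains a compute_actv_funcs helper that pads an under-specified list (empty, or a single function for a bidirectional RNN) before the operator is rewritten, and the tests are updated to cover those under-specified cases.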
src/include/migraphx/operators.hpp
@@ -1140,7 +1140,7 @@ struct rnn
     };
     std::size_t hidden_size = 1;
-    std::vector<operation> actv_funcs{tanh{}};
+    std::vector<operation> actv_funcs{tanh{}, tanh{}};
     rnn_direction_t direction = forward;
     float clip = 0.0f;
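The one-line change above widens the default so that an rnn constructed without explicit activations already carries one function per direction of a bidirectional pass. A minimal sketch of the pattern, with a std::string stand-in for migraphx::operation (the stub type and names are illustrative, not the real headers):

    #include <iostream>
    #include <string>
    #include <vector>

    struct rnn_stub
    {
        std::size_t hidden_size = 1;
        // default member initializer: now two entries, one per direction
        std::vector<std::string> actv_funcs{"tanh", "tanh"};
    };

    int main()
    {
        rnn_stub r;                               // no activations supplied
        std::cout << r.actv_funcs.size() << '\n'; // prints 2
    }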
src/include/migraphx/rewrite_rnn.hpp
@@ -30,6 +30,8 @@ struct rewrite_rnn
                          instruction_ref bias,
                          instruction_ref ih,
                          operation& actv_func) const;
+
+    std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
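This added declaration is the only interface change: apply now asks compute_actv_funcs for a fully populated activation list instead of reading rnn_op.actv_funcs directly, as the hunks in src/rewrite_rnn.cpp below show.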
src/rewrite_rnn.cpp
@@ -29,9 +29,10 @@ void rewrite_rnn::apply(program& prog) const
         migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
         std::vector<float> data(ih_shape.elements(), 0);

+        auto actv_funcs = compute_actv_funcs(ins);
         auto rnn_op = any_cast<op::rnn>(ins->get_operator());
         op::rnn::rnn_direction_t dicrt = rnn_op.direction;
-        if(dicrt == op::rnn::rnn_direction_t::bidirectional)
+        if(dicrt == op::rnn::bidirectional)
         {
             // input weight matrix
             auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
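A side note on the rewritten condition: both spellings name the same enumerator. For the shorter op::rnn::bidirectional to compile, rnn_direction_t has to be a plain unscoped enum nested in rnn, whose enumerators are injected into the enclosing struct's scope; a self-contained illustration with stub types (not the migraphx definitions):

    struct rnn
    {
        enum rnn_direction_t
        {
            forward,
            reverse,
            bidirectional
        };
    };

    int main()
    {
        // unscoped enumerators are visible through the enum type and the struct alike
        static_assert(rnn::bidirectional == rnn::rnn_direction_t::bidirectional, "same value");
        return 0;
    }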
@@ -72,7 +73,7 @@ void rewrite_rnn::apply(program& prog) const
                                         r_forward,
                                         bias_forward,
                                         ih_forward,
-                                        rnn_op.actv_funcs.at(0));
+                                        actv_funcs.at(0));
             auto ret_reverse = rnn_cell(false,
                                         prog,
                                         ins,
@@ -81,7 +82,7 @@ void rewrite_rnn::apply(program& prog) const
                                         r_reverse,
                                         bias_reverse,
                                         ih_reverse,
-                                        rnn_op.actv_funcs.at(1));
+                                        actv_funcs.at(1));
             auto concat_output =
                 prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
@@ -109,7 +110,7 @@ void rewrite_rnn::apply(program& prog) const
         }
         else
         {
-            bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward);
+            bool is_forward = (dicrt == op::rnn::forward);
             // input weight matrix
             auto w = args[1];
@@ -135,7 +136,7 @@ void rewrite_rnn::apply(program& prog) const
             }
             auto ret = rnn_cell(
-                is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
+                is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
             auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

             // following logic is to ensure the last instruction is a
@@ -263,5 +264,42 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
     return {hidden_out, last_out};
 }

+std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
+{
+    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
+    // before rewrite the rnn operator, need to ensure
+    // we have 2 actv funcs. If less than 2, use the
+    // algorithm in parse_rnn to make 2 actv functions
+    if(rnn_op.direction == op::rnn::bidirectional)
+    {
+        if(rnn_op.actv_funcs.size() == 0)
+        {
+            // default is tanh
+            return {op::tanh{}, op::tanh{}};
+        }
+        else if(rnn_op.actv_funcs.size() == 1)
+        {
+            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
+        }
+        else
+        {
+            return rnn_op.actv_funcs;
+        }
+    }
+    else
+    {
+        if(rnn_op.actv_funcs.size() == 0)
+        {
+            // default is tanh
+            return {op::tanh{}};
+        }
+        else
+        {
+            return rnn_op.actv_funcs;
+        }
+    }
+}
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
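The new compute_actv_funcs is the core of the commit: a bidirectional RNN needs two activation functions and a unidirectional one needs one, so an empty list is filled with the tanh default and a single supplied function is reused for both directions, mirroring the algorithm in parse_rnn. A standalone sketch of the same rule, with std::string standing in for migraphx::operation and a hypothetical fill_actv_funcs helper:

    #include <cassert>
    #include <cstddef>
    #include <string>
    #include <vector>

    // Hypothetical free-function form of the filling rule above.
    std::vector<std::string> fill_actv_funcs(std::vector<std::string> funcs, bool bidirectional)
    {
        std::size_t needed = bidirectional ? 2 : 1;
        if(funcs.empty())
            return std::vector<std::string>(needed, "tanh"); // default is tanh
        if(funcs.size() < needed)
            return {funcs.front(), funcs.front()}; // duplicate the single function
        return funcs;
    }

    int main()
    {
        using v = std::vector<std::string>;
        assert(fill_actv_funcs({}, true) == (v{"tanh", "tanh"}));
        assert(fill_actv_funcs({"sigmoid"}, true) == (v{"sigmoid", "sigmoid"}));
        assert(fill_actv_funcs({}, false) == (v{"tanh"}));
        assert(fill_actv_funcs({"relu"}, false) == (v{"relu"}));
    }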
test/cpu_ops_test.cpp
@@ -1459,7 +1459,7 @@ TEST_CASE(rnn_forward)
     auto out_hs =
         p.add_instruction(migraphx::op::rnn{hidden_size,
-                                            {migraphx::op::tanh{}, migraphx::op::tanh{}},
+                                            {},
                                             migraphx::op::rnn::forward,
                                             clip},
                           seq,
@@ -1599,7 +1599,7 @@ TEST_CASE(rnn_reverse)
     auto out_hs =
         p.add_instruction(migraphx::op::rnn{hidden_size,
-                                            {migraphx::op::tanh{}, migraphx::op::tanh{}},
+                                            {},
                                             migraphx::op::rnn::reverse,
                                             clip},
                           seq,
@@ -1724,7 +1724,7 @@ TEST_CASE(rnn_bidirectional)
     auto und = p.add_instruction(migraphx::op::undefined{});
     p.add_instruction(migraphx::op::rnn{hidden_size,
-                                        {migraphx::op::tanh{}, migraphx::op::tanh{}},
+                                        {},
                                         migraphx::op::rnn::bidirectional,
                                         clip},
                       seq,
@@ -1776,7 +1776,7 @@ TEST_CASE(rnn_bidirectional)
     auto out_hs =
         p.add_instruction(migraphx::op::rnn{hidden_size,
-                                            {migraphx::op::tanh{}, migraphx::op::tanh{}},
+                                            {migraphx::op::tanh{}},
                                             migraphx::op::rnn::bidirectional,
                                             clip},
                           seq,
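These four test updates exercise the new path from both ends: rnn_forward, rnn_reverse, and the first rnn_bidirectional case now pass an empty activation list and rely on the pass to supply the tanh defaults, while the second rnn_bidirectional case passes a single tanh that must be duplicated for the reverse direction. The expected outputs are untouched, so the tests pass exactly when the filling logic reproduces the old explicit lists.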