Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d4594903
Commit
d4594903
authored
Jan 23, 2019
by
Shucai Xiao
Browse files
fixed build error.
parent
1596cf1f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
16 deletions
+25
-16
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraphx/rewrite_gru.hpp
src/include/migraphx/rewrite_gru.hpp
+4
-2
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+12
-13
src/targets/cpu/target.cpp
src/targets/cpu/target.cpp
+5
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-0
No files found.
src/CMakeLists.txt
View file @
d4594903
...
@@ -12,6 +12,7 @@ add_library(migraphx
...
@@ -12,6 +12,7 @@ add_library(migraphx
eliminate_concat.cpp
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp
rewrite_rnn.cpp
rewrite_gru.cpp
env.cpp
env.cpp
generate.cpp
generate.cpp
instruction.cpp
instruction.cpp
...
...
src/include/migraphx/rewrite_gru.hpp
View file @
d4594903
...
@@ -21,7 +21,7 @@ struct rewrite_gru
...
@@ -21,7 +21,7 @@ struct rewrite_gru
void
apply
(
program
&
prog
)
const
;
void
apply
(
program
&
prog
)
const
;
private:
private:
std
::
vector
<
instruction_ref
>
rnn_gru
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
gru_oper
(
bool
is_forward
,
program
&
prog
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
input
,
...
@@ -29,7 +29,9 @@ struct rewrite_gru
...
@@ -29,7 +29,9 @@ struct rewrite_gru
instruction_ref
wh
,
instruction_ref
wh
,
instruction_ref
ih
,
instruction_ref
ih
,
instruction_ref
bias
,
instruction_ref
bias
,
operation
&
actv_func
)
const
;
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/rewrite_gru.cpp
View file @
d4594903
...
@@ -33,17 +33,16 @@ void rewrite_gru::apply(program& prog) const
...
@@ -33,17 +33,16 @@ void rewrite_gru::apply(program& prog) const
op
::
gru
::
gru_direction_t
dicrt
=
gru_op
.
direction
;
op
::
gru
::
gru_direction_t
dicrt
=
gru_op
.
direction
;
if
(
dicrt
==
op
::
gru
::
bidirectional
)
if
(
dicrt
==
op
::
gru
::
bidirectional
)
{
{
long
hs
=
static_cast
<
long
>
(
hidden_size
);
// forward weight
// forward weight
auto
uw_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
uw_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
},
uw_forward
}
);
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}
}
,
uw_forward
);
auto
ur_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
ur_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_forward
);
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_forward
);
// reverse weight
// reverse weight
auto
uw_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
auto
uw_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
},
uw_reverse
}
);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}
}
,
uw_reverse
);
auto
ur_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
auto
ur_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_reverse
);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_reverse
);
...
@@ -92,7 +91,7 @@ void rewrite_gru::apply(program& prog) const
...
@@ -92,7 +91,7 @@ void rewrite_gru::apply(program& prog) const
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
gru_op
.
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
rnn
_oper
(
false
,
auto
ret_reverse
=
gru
_oper
(
false
,
prog
,
prog
,
ins
,
ins
,
args
[
0
],
args
[
0
],
...
@@ -136,7 +135,7 @@ void rewrite_gru::apply(program& prog) const
...
@@ -136,7 +135,7 @@ void rewrite_gru::apply(program& prog) const
}
}
else
else
{
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
}
auto
ret
=
gru_oper
(
is_forward
,
auto
ret
=
gru_oper
(
is_forward
,
...
@@ -175,7 +174,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -175,7 +174,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
migraphx
::
shape
s
(
input
->
get_shape
().
type
(),
{
1
});
migraphx
::
shape
s
(
input
->
get_shape
().
type
(),
{
1
});
auto
l1
=
prog
.
add_literal
(
migraphx
::
l
e
teral
{
s
,
{
1
}});
auto
l1
=
prog
.
add_literal
(
migraphx
::
l
i
teral
{
s
,
{
1
}});
// weight matrix
// weight matrix
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
...
@@ -199,12 +198,12 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -199,12 +198,12 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
{
{
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
bias
);
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
bias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
bias
);
br_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
wbh
);
br_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
wbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
bias
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
bias
);
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
bias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
bias
);
br_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
rbh
);
br_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
...
@@ -245,7 +244,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -245,7 +244,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
rt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih
);
auto
rt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
,
trh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
,
trh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_r
t
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_r
h
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_bh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_bh
);
...
@@ -267,13 +266,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -267,13 +266,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_wbh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_wbh
);
}
}
}
}
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xwhh_rt
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xwhh_rt
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
1
zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
z
1
t
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
1
ztht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
1
zt
,
ht
);
auto
z
1
tht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
z
1
t
,
ht
);
auto
ztht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
ih
);
auto
ztht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
ih
);
ih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
1
ztht
ztht1
);
ih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
z
1
tht
,
ztht1
);
final_out
=
ih
;
final_out
=
ih
;
if
(
is_forward
)
if
(
is_forward
)
...
...
src/targets/cpu/target.cpp
View file @
d4594903
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_gru.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -12,7 +13,10 @@ std::string target::name() const { return "cpu"; }
...
@@ -12,7 +13,10 @@ std::string target::name() const { return "cpu"; }
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
)
const
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
)
const
{
{
return
{
auto_contiguous
{},
rewrite_rnn
{},
lowering
{}};
return
{
auto_contiguous
{},
rewrite_rnn
{},
rewrite_gru
{},
lowering
{}};
}
}
}
// namespace cpu
}
// namespace cpu
...
...
src/targets/gpu/target.cpp
View file @
d4594903
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
...
@@ -36,6 +37,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
...
@@ -36,6 +37,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination
{},
dead_code_elimination
{},
rewrite_rnn
{},
rewrite_rnn
{},
dead_code_elimination
{},
dead_code_elimination
{},
rewrite_gru
{},
dead_code_elimination
{},
simplify_algebra
{},
simplify_algebra
{},
dead_code_elimination
{},
dead_code_elimination
{},
constant_propagate
{},
constant_propagate
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment