Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
249f5024
Commit
249f5024
authored
Jun 14, 2019
by
Shucai Xiao
Browse files
optimize the lstm operator rewrite.
parent
a5941134
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
102 deletions
+36
-102
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+8
-1
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+28
-101
No files found.
src/py/migraphx_py.cpp
View file @
249f5024
...
...
@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
...
...
@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
t
=
as
.
type_enum
();
n
=
sizeof
(
as
());
}
});
if
(
n
==
0
)
{
MIGRAPHX_THROW
(
"MIGRAPHX PYTHON: Unsupported data type"
+
info
.
format
);
}
auto
strides
=
info
.
strides
;
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
->
std
::
size_t
{
return
n
>
0
?
i
/
n
:
0
;
...
...
@@ -134,6 +140,7 @@ PYBIND11_MODULE(migraphx, m)
.
def
(
"__init__"
,
[](
migraphx
::
argument
&
x
,
py
::
buffer
b
)
{
py
::
buffer_info
info
=
b
.
request
();
auto
s
=
to_shape
(
info
);
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
})
.
def
(
"get_shape"
,
&
migraphx
::
argument
::
get_shape
)
...
...
src/rewrite_rnn.cpp
View file @
249f5024
...
...
@@ -903,36 +903,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
migraphx
::
shape
r_shape
=
r
->
get_shape
();
long
seq_len
=
static_cast
<
long
>
(
seq_shape
.
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r_shape
.
lens
()[
2
]);
auto
bs
=
ih
->
get_shape
().
lens
()[
1
];
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
// w matrix
// w matrix
, squeeze and transpose
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
wi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sw
);
auto
tran_wi
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wi
);
auto
tsw
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sw
);
auto
wo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sw
);
auto
tran_wo
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wo
);
auto
wf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sw
);
auto
tran_wf
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wf
);
auto
wc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sw
);
auto
tran_wc
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wc
);
// r matrix
// r matrix, squeeze and transpose
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
ri
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sr
);
auto
tran_ri
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
ri
);
auto
ro
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sr
);
auto
tran_ro
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
ro
);
auto
rf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
tran_rf
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rf
);
auto
rc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sr
);
auto
tran_rc
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rc
);
auto
tsr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sr
);
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
...
...
@@ -941,37 +922,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
ic_lens
=
sic
->
get_shape
().
lens
();
// bias
instruction_ref
wbi_brcst
{};
instruction_ref
rbi_brcst
{};
instruction_ref
wbo_brcst
{};
instruction_ref
rbo_brcst
{};
instruction_ref
wbf_brcst
{};
instruction_ref
rbf_brcst
{};
instruction_ref
wbc_brcst
{};
instruction_ref
rbc_brcst
{};
instruction_ref
wb
{};
instruction_ref
rb
{};
if
(
bias
!=
prog
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rbi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
wbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
wbi
);
rbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
rbi
);
auto
wbo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rbo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
wbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
wbo
);
rbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
rbo
);
auto
wbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
rbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
wbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
wbf
);
rbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
rbf
);
auto
wbc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
wbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
wbc
);
rbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
rbc
);
auto
ub_wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
4
*
hs
}},
sbias
);
auto
ub_rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
8
*
hs
}},
sbias
);
wb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}},
ub_wb
);
rb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}},
ub_rb
);
}
// peep hole
...
...
@@ -997,30 +958,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
xt_wi
{};
instruction_ref
ht_ri
{};
instruction_ref
xt_wf
{};
instruction_ref
ht_rf
{};
if
(
bias
!=
prog
.
end
())
instruction_ref
xt_sih
{};
if
(
bias
!=
prog
.
end
())
{
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
,
wbi_brcst
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
,
rbi_brcst
);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
,
wbf_brcst
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
,
rbf_brcst
);
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tsw
,
wb
);
auto
sih_tsr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tsr
,
rb
);
xt_sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_tsw
,
sih_tsr
);
}
else
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tsw
);
auto
sih_tsr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tsr
);
xt_sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_tsw
,
sih_tsr
);
}
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_sih
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
hs
},
{
2
*
hs
}},
xt_sih
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
2
*
hs
},
{
3
*
hs
}},
xt_sih
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
3
*
hs
},
{
4
*
hs
}},
xt_sih
);
if
(
pph
!=
prog
.
end
())
{
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
...
...
@@ -1031,22 +987,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
}
auto
it
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
auto
ft
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ft_before_actv
);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
instruction_ref
xt_wc
{};
instruction_ref
ht_rc
{};
if
(
bias
!=
prog
.
end
())
{
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
,
wbc_brcst
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
,
rbc_brcst
);
}
else
{
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
}
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wc
,
ht_rc
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto
ft_cell
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ft
,
sic
);
...
...
@@ -1054,20 +995,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
cellt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_cell
,
it_ct
);
last_cell_output
=
cellt
;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
instruction_ref
xt_wo
{};
instruction_ref
ht_ro
{};
if
(
bias
!=
prog
.
end
())
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
,
wbo_brcst
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
,
rbo_brcst
);
}
else
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
}
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
if
(
pph
!=
prog
.
end
())
{
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ppho_brcst
,
cellt
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment