Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b7f5e9bd
Commit
b7f5e9bd
authored
Feb 15, 2019
by
Shucai Xiao
Browse files
optimized lstm_rewrite
parent
4702c17e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
29 deletions
+35
-29
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+35
-29
No files found.
src/rewrite_rnn.cpp
View file @
b7f5e9bd
...
@@ -738,18 +738,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -738,18 +738,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
}
// process weight of the peephole
// process weight of the peephole
instruction_ref
pph_forward
{}
;
instruction_ref
pph_forward
=
prog
.
end
()
;
instruction_ref
pph_reverse
{}
;
instruction_ref
pph_reverse
=
prog
.
end
()
;
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
{
{
pph_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
7
]);
pph_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
7
]);
pph_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
7
]);
pph_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
7
]);
}
}
else
{
pph_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
pph_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
}
auto
ret_forward
=
lstm_cell
(
auto
ret_forward
=
lstm_cell
(
true
,
true
,
...
@@ -830,15 +825,11 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -830,15 +825,11 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
}
// process weight of the peephole
// process weight of the peephole
instruction_ref
pph
{}
;
instruction_ref
pph
=
prog
.
end
()
;
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
{
{
pph
=
args
[
7
];
pph
=
args
[
7
];
}
}
else
{
pph
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
}
auto
ret
=
lstm_cell
(
is_forward
,
auto
ret
=
lstm_cell
(
is_forward
,
prog
,
prog
,
...
@@ -991,18 +982,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -991,18 +982,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
}
}
// peep hole
// peep hole
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
instruction_ref
pphi_brcst
{};
auto
pphi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
spph
);
instruction_ref
ppho_brcst
{};
auto
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphi
);
instruction_ref
pphf_brcst
{};
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphi_brcst
);
auto
ppho
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
spph
);
auto
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
ppho
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
ppho_brcst
);
auto
pphf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
spph
);
if
(
pph
!=
prog
.
end
())
auto
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphf
);
{
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphf_brcst
);
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
auto
pphi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
spph
);
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphi
);
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphi_brcst
);
auto
ppho
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
spph
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
ppho
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
ppho_brcst
);
auto
pphf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
spph
);
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphf
);
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphf_brcst
);
}
for
(
long
i
=
0
;
i
<
seq_len
;
++
i
)
for
(
long
i
=
0
;
i
<
seq_len
;
++
i
)
{
{
...
@@ -1013,9 +1011,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -1013,9 +1011,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
pphi_ct
);
if
(
pph
!=
prog
.
end
())
{
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
pphi_ct
);
}
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
bi_brcst
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
bi_brcst
);
...
@@ -1025,9 +1025,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -1025,9 +1025,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
auto
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
auto
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
auto
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
auto
pphf_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphf_brcst
,
sic
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
pphf_ct
);
if
(
pph
!=
prog
.
end
())
{
auto
pphf_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphf_brcst
,
sic
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
pphf_ct
);
}
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
bf_brcst
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
bf_brcst
);
...
@@ -1053,9 +1056,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -1053,9 +1056,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
auto
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
auto
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
auto
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ppho_brcst
,
cellt
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
ppho_cellt
);
if
(
pph
!=
prog
.
end
())
{
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ppho_brcst
,
cellt
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
ppho_cellt
);
}
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
bo_brcst
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
bo_brcst
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment