Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60bbf654
Commit
60bbf654
authored
Apr 08, 2019
by
Shucai Xiao
Browse files
simplify the implementation of rnn operators with the enhanced gemm operator.
parent
6c77eae1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
79 additions
and
57 deletions
+79
-57
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+79
-57
No files found.
src/rewrite_rnn.cpp
View file @
60bbf654
...
@@ -206,14 +206,16 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
...
@@ -206,14 +206,16 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
// bias
instruction_ref
bwb
{};
instruction_ref
brb
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
long
hs
=
r
->
get_shape
().
lens
()[
2
];
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]
)
;
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
r
b
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
w
b
);
b
ias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
b
rb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
r
b
);
}
}
instruction_ref
hidden_out
=
prog
.
end
();
instruction_ref
hidden_out
=
prog
.
end
();
...
@@ -225,21 +227,22 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
...
@@ -225,21 +227,22 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
instruction_ref
xt_wi
{};
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
instruction_ref
ht_ri
{};
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
if
(
bias
!=
prog
.
end
())
instruction_ref
ht
;
if
(
bias
!=
prog
.
end
())
{
{
ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_ht
,
bias
);
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
,
bwb
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
,
brb
);
}
}
else
else
{
{
ht
=
xt_ht
;
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
}
}
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
// apply activation function
// apply activation function
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
xt_
ht
);
sih
=
ht
;
sih
=
ht
;
// add the dimensions of sequence length (axis 0 for sequence length,
// add the dimensions of sequence length (axis 0 for sequence length,
...
@@ -958,39 +961,38 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -958,39 +961,38 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
ic_shape
=
sic
->
get_shape
();
auto
ic_shape
=
sic
->
get_shape
();
// bias
// bias
instruction_ref
bi_brcst
{};
instruction_ref
wbi_brcst
{},
r
bi_brcst
{};
instruction_ref
bo_brcst
{};
instruction_ref
wbo_brcst
{},
r
bo_brcst
{};
instruction_ref
bf_brcst
{};
instruction_ref
wbf_brcst
{},
r
bf_brcst
{};
instruction_ref
bc_brcst
{};
instruction_ref
wbc_brcst
{},
r
bc_brcst
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
b
x
i
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
w
bi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
b
h
i
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
r
bi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
bi
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxi
,
bh
i
);
wbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
i
);
bi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bi
);
r
bi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bi
);
auto
b
x
o
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
w
bo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
h
o
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
auto
r
bo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
auto
bo
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxo
,
bh
o
);
wbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
o
);
bo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bo
);
r
bo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bo
);
auto
b
x
f
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
w
bf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
b
h
f
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
auto
r
bf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
auto
bf
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxf
,
bh
f
);
wbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
f
);
bf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bf
);
r
bf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bf
);
auto
b
x
c
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
w
bc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
b
h
c
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
auto
r
bc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
auto
bc
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxc
,
bh
c
);
wbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
c
);
bc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bc
);
r
bc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bc
);
}
}
// peep hole
// peep hole
instruction_ref
pphi_brcst
{};
instruction_ref
pphi_brcst
{};
instruction_ref
ppho_brcst
{};
instruction_ref
ppho_brcst
{};
instruction_ref
pphf_brcst
{};
instruction_ref
pphf_brcst
{};
if
(
pph
!=
prog
.
end
())
if
(
pph
!=
prog
.
end
())
{
{
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
...
@@ -1011,43 +1013,58 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -1011,43 +1013,58 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
instruction_ref
xt_wi
{},
ht_ri
{};
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
if
(
bias
!=
prog
.
end
())
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
,
wbi_brcst
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
,
rbi_brcst
);
}
else
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
}
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
if
(
pph
!=
prog
.
end
())
if
(
pph
!=
prog
.
end
())
{
{
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
pphi_ct
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
pphi_ct
);
}
}
if
(
bias
!=
prog
.
end
())
{
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
bi_brcst
);
}
auto
it
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
auto
it
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
instruction_ref
xt_wf
{},
ht_rf
{};
auto
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
if
(
bias
!=
prog
.
end
())
{
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
,
wbf_brcst
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
,
rbf_brcst
);
}
else
{
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
}
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
if
(
pph
!=
prog
.
end
())
if
(
pph
!=
prog
.
end
())
{
{
auto
pphf_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphf_brcst
,
sic
);
auto
pphf_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphf_brcst
,
sic
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
pphf_ct
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
pphf_ct
);
}
}
if
(
bias
!=
prog
.
end
())
{
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
bf_brcst
);
}
auto
ft
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ft_before_actv
);
auto
ft
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ft_before_actv
);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
instruction_ref
xt_wc
{},
ht_rc
{};
auto
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wc
,
ht_rc
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ct_before_actv
,
bc_brcst
);
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
,
wbc_brcst
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
,
rbc_brcst
);
}
}
else
{
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
}
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wc
,
ht_rc
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
// equation Ct = ft (.) Ct-1 + it (.) ct
// equation Ct = ft (.) Ct-1 + it (.) ct
...
@@ -1057,18 +1074,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -1057,18 +1074,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
last_cell_output
=
cellt
;
last_cell_output
=
cellt
;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
instruction_ref
xt_wo
{},
ht_ro
{};
auto
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
if
(
bias
!=
prog
.
end
())
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
,
wbo_brcst
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
,
rbo_brcst
);
}
else
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
}
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
if
(
pph
!=
prog
.
end
())
if
(
pph
!=
prog
.
end
())
{
{
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ppho_brcst
,
cellt
);
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ppho_brcst
,
cellt
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
ppho_cellt
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
ppho_cellt
);
}
}
if
(
bias
!=
prog
.
end
())
{
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
bo_brcst
);
}
auto
ot
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ot_before_actv
);
auto
ot
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ot_before_actv
);
// Ht = ot (.) h(Ct)
// Ht = ot (.) h(Ct)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment