Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60bbf654
Commit
60bbf654
authored
Apr 08, 2019
by
Shucai Xiao
Browse files
simplify the implementation of rnn operators with the enhanced gemm operator.
parent
6c77eae1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
79 additions
and
57 deletions
+79
-57
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+79
-57
No files found.
src/rewrite_rnn.cpp
View file @
60bbf654
...
...
@@ -206,14 +206,16 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
instruction_ref
bwb
{};
instruction_ref
brb
{};
if
(
bias
!=
prog
.
end
())
{
long
hs
=
r
->
get_shape
().
lens
()[
2
];
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]
)
;
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
r
b
);
b
ias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
w
b
);
b
rb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
r
b
);
}
instruction_ref
hidden_out
=
prog
.
end
();
...
...
@@ -225,21 +227,22 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
instruction_ref
ht
;
if
(
bias
!=
prog
.
end
())
instruction_ref
xt_wi
{};
instruction_ref
ht_ri
{};
if
(
bias
!=
prog
.
end
())
{
ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_ht
,
bias
);
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
,
bwb
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
,
brb
);
}
else
{
ht
=
xt_ht
;
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
}
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
// apply activation function
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
xt_
ht
);
sih
=
ht
;
// add the dimensions of sequence length (axis 0 for sequence length,
...
...
@@ -958,39 +961,38 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
ic_shape
=
sic
->
get_shape
();
// bias
instruction_ref
bi_brcst
{};
instruction_ref
bo_brcst
{};
instruction_ref
bf_brcst
{};
instruction_ref
bc_brcst
{};
instruction_ref
wbi_brcst
{},
r
bi_brcst
{};
instruction_ref
wbo_brcst
{},
r
bo_brcst
{};
instruction_ref
wbf_brcst
{},
r
bf_brcst
{};
instruction_ref
wbc_brcst
{},
r
bc_brcst
{};
if
(
bias
!=
prog
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
b
x
i
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
b
h
i
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
bi
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxi
,
bh
i
);
bi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bi
);
auto
w
bi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
r
bi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
wbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
i
);
r
bi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bi
);
auto
b
x
o
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
h
o
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
auto
bo
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxo
,
bh
o
);
bo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bo
);
auto
w
bo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
r
bo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
wbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
o
);
r
bo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bo
);
auto
b
x
f
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
b
h
f
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
auto
bf
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxf
,
bh
f
);
bf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bf
);
auto
w
bf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
r
bf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
wbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
f
);
r
bf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bf
);
auto
b
x
c
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
b
h
c
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
auto
bc
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxc
,
bh
c
);
bc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
bc
);
auto
w
bc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
r
bc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
wbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wb
c
);
r
bc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
r
bc
);
}
// peep hole
instruction_ref
pphi_brcst
{};
instruction_ref
ppho_brcst
{};
instruction_ref
pphf_brcst
{};
if
(
pph
!=
prog
.
end
())
{
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
...
...
@@ -1011,43 +1013,58 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
instruction_ref
xt_wi
{},
ht_ri
{};
if
(
bias
!=
prog
.
end
())
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
,
wbi_brcst
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
,
rbi_brcst
);
}
else
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
}
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
if
(
pph
!=
prog
.
end
())
{
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
pphi_ct
);
}
if
(
bias
!=
prog
.
end
())
{
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
bi_brcst
);
}
auto
it
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
auto
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
instruction_ref
xt_wf
{},
ht_rf
{};
if
(
bias
!=
prog
.
end
())
{
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
,
wbf_brcst
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
,
rbf_brcst
);
}
else
{
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
}
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
if
(
pph
!=
prog
.
end
())
{
auto
pphf_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphf_brcst
,
sic
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
pphf_ct
);
}
if
(
bias
!=
prog
.
end
())
{
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
bf_brcst
);
}
auto
ft
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ft_before_actv
);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
auto
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wc
,
ht_rc
);
instruction_ref
xt_wc
{},
ht_rc
{};
if
(
bias
!=
prog
.
end
())
{
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ct_before_actv
,
bc_brcst
);
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
,
wbc_brcst
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
,
rbc_brcst
);
}
else
{
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
}
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wc
,
ht_rc
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
// equation Ct = ft (.) Ct-1 + it (.) ct
...
...
@@ -1057,18 +1074,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
last_cell_output
=
cellt
;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
auto
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
instruction_ref
xt_wo
{},
ht_ro
{};
if
(
bias
!=
prog
.
end
())
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
,
wbo_brcst
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
,
rbo_brcst
);
}
else
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
}
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
if
(
pph
!=
prog
.
end
())
{
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ppho_brcst
,
cellt
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
ppho_cellt
);
}
if
(
bias
!=
prog
.
end
())
{
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
bo_brcst
);
}
auto
ot
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ot_before_actv
);
// Ht = ot (.) h(Ct)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment