Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fe30f007
Commit
fe30f007
authored
Apr 08, 2019
by
Shucai Xiao
Browse files
clang format
parent
60bbf654
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
42 deletions
+42
-42
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+42
-42
No files found.
src/rewrite_rnn.cpp
View file @
fe30f007
...
...
@@ -214,8 +214,8 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wb
);
brb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
rb
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wb
);
brb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
rb
);
}
instruction_ref
hidden_out
=
prog
.
end
();
...
...
@@ -229,7 +229,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
xt_wi
{};
instruction_ref
ht_ri
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
,
bwb
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
,
brb
);
...
...
@@ -237,13 +237,13 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
else
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
}
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
// apply activation function
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
xt_ht
);
sih
=
ht
;
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
xt_ht
);
sih
=
ht
;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
...
...
@@ -970,23 +970,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rbi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
wbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbi
);
rbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbi
);
auto
wbo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rbo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
wbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbo
);
rbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbo
);
auto
wbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
rbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
wbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbf
);
rbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbf
);
auto
wbc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
wbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbc
);
rbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbc
);
wbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbi
);
rbi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbi
);
auto
wbo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rbo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
wbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbo
);
rbo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbo
);
auto
wbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
rbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
wbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbf
);
rbf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbf
);
auto
wbc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
wbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
wbc
);
rbc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
rbc
);
}
// peep hole
...
...
@@ -1014,15 +1014,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
instruction_ref
xt_wi
{},
ht_ri
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
,
wbi_brcst
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
,
rbi_brcst
);
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
,
wbi_brcst
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
,
rbi_brcst
);
}
else
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
}
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
if
(
pph
!=
prog
.
end
())
...
...
@@ -1034,15 +1034,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
instruction_ref
xt_wf
{},
ht_rf
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
,
wbf_brcst
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
,
rbf_brcst
);
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
,
wbf_brcst
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
,
rbf_brcst
);
}
else
{
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
}
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
if
(
pph
!=
prog
.
end
())
...
...
@@ -1056,16 +1056,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref
xt_wc
{},
ht_rc
{};
if
(
bias
!=
prog
.
end
())
{
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
,
wbc_brcst
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
,
rbc_brcst
);
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
,
wbc_brcst
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
,
rbc_brcst
);
}
else
{
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
}
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wc
,
ht_rc
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto
ft_cell
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ft
,
sic
);
...
...
@@ -1077,13 +1077,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref
xt_wo
{},
ht_ro
{};
if
(
bias
!=
prog
.
end
())
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
,
wbo_brcst
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
,
rbo_brcst
);
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
,
wbo_brcst
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
,
rbo_brcst
);
}
else
{
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
}
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
if
(
pph
!=
prog
.
end
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment