Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
87697905
Commit
87697905
authored
Mar 25, 2019
by
Shucai Xiao
Browse files
optimize the gru operator with the enhanced dot operator.
parent
fa4f5244
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
44 deletions
+52
-44
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+52
-44
No files found.
src/rewrite_rnn.cpp
View file @
87697905
...
...
@@ -513,32 +513,19 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
instruction_ref
brcst_bz
{};
instruction_ref
brcst_br
{};
instruction_ref
brcst_wbh
{};
instruction_ref
brcst_rbh
{};
instruction_ref
brcst_bh
{};
instruction_ref
wbz
{},
rbz
{};
instruction_ref
wbr
{},
rbr
{};
instruction_ref
wbh
{},
rbh
{};
if
(
bias
!=
prog
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
brcst_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wbh
);
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
brcst_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
brcst_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
bz
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
brcst_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
br
);
auto
bh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbh
,
rbh
);
brcst_bh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
bh
);
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
}
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
...
...
@@ -548,53 +535,74 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
);
auto
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
);
auto
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wz
,
ht_rz
);
if
(
bias
!=
prog
.
end
())
instruction_ref
xt_wz
{};
instruction_ref
ht_rz
{};
if
(
bias
!=
prog
.
end
())
{
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
,
wbz
);
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
,
rbz
);
}
else
{
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_z
,
brcst_bz
);
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
);
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
);
}
auto
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wz
,
ht_rz
);
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_z
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
);
auto
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
);
auto
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wr
,
ht_rr
);
if
(
bias
!=
prog
.
end
())
instruction_ref
xt_wr
{};
instruction_ref
ht_rr
{};
if
(
bias
!=
prog
.
end
())
{
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
,
wbr
);
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
,
rbr
);
}
else
{
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_r
,
brcst_br
);
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
);
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
);
}
auto
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wr
,
ht_rr
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_r
);
instruction_ref
xht_h
;
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
instruction_ref
xt_wh
{};
instruction_ref
rt_rh
{};
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
sih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_h
,
brcst_bh
);
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
,
wbh
);
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
,
rbh
);
}
else
{
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
);
}
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
)
;
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
)
;
instruction_ref
ht1_rh
{}
;
instruction_ref
xt_wh
{}
;
if
(
bias
!=
prog
.
end
())
{
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ht1_rh
,
brcst_rbh
);
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
,
wbh
);
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
,
rbh
);
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
else
{
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_h
,
brcst_wbh
);
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
);
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xht_h
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment