Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
18fc2362
Commit
18fc2362
authored
Jun 12, 2019
by
Shucai Xiao
Browse files
optimize the rewrite of gru operator to reduce the number of matrix multiplication calls.
parent
b8090620
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
87 deletions
+60
-87
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+59
-86
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+1
-1
No files found.
src/rewrite_rnn.cpp
View file @
18fc2362
...
@@ -489,58 +489,40 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
...
@@ -489,58 +489,40 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
long
hs
=
static_cast
<
long
>
(
r_shape
.
lens
()[
2
]);
long
hs
=
static_cast
<
long
>
(
r_shape
.
lens
()[
2
]);
migraphx
::
shape
s
(
seq_shape
.
type
(),
{
seq_shape
.
lens
()[
1
],
r_shape
.
lens
()[
2
]});
migraphx
::
shape
s
(
seq_shape
.
type
(),
{
seq_shape
.
lens
()[
1
],
r_shape
.
lens
()[
2
]});
std
::
vector
<
in
t
>
data
(
s
.
elements
(),
1
);
std
::
vector
<
floa
t
>
data
(
s
.
elements
(),
1
.0
f
);
auto
l1
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
auto
l1
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
// w
eight
matrix
// w matrix
squeeze to 2-dim and do a transpose
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sw
);
auto
tw
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sw
);
auto
tran_wz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sw
);
auto
tran_wr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sw
);
auto
tran_wh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
// r slide to two part, zr and h
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sr
);
auto
rzr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
2
*
hs
}},
sr
);
auto
tran_rz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
trzr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rzr
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sr
);
auto
tran_rr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
t
ran_
rh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
auto
trh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
// initial states
// initial states
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
auto
sih_lens
=
s
ih
->
get_shape
().
lens
();
size_t
bs
=
ih
->
get_shape
().
lens
()
[
1
]
;
// bias
// bias
instruction_ref
bwbz
{};
instruction_ref
bwb
{};
instruction_ref
brbz
{};
instruction_ref
brb_zr
{};
instruction_ref
bwbr
{};
instruction_ref
brb_h
{};
instruction_ref
brbr
{};
instruction_ref
bwbh
{};
instruction_ref
brbh
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
3
*
hs
}},
sbias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
static_cast
<
size_t
>
(
3
*
hs
)}},
wb
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
bwbz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
wbz
);
auto
rb_zr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
5
*
hs
}},
sbias
);
bwbr
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
wbr
);
auto
rb_h
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
bwbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
wbh
);
brb_zr
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
static_cast
<
size_t
>
(
2
*
hs
)}},
rb_zr
);
brb_h
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
static_cast
<
size_t
>
(
hs
)}},
rb_h
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
brbz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
rbz
);
brbr
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
rbr
);
brbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
rbh
);
}
}
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
...
@@ -549,73 +531,64 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
...
@@ -549,73 +531,64 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
zt
{};
instruction_ref
xt_w
{};
instruction_ref
ht
{};
instruction_ref
ih1_rzr
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
,
bwb
);
auto
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
,
bwbz
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trzr
,
brb_zr
);
auto
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
,
brbz
);
}
auto
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wz
,
ht_rz
);
else
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_z
);
{
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trzr
);
auto
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
,
bwbr
);
}
auto
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
,
brbr
);
auto
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wr
,
ht_rr
);
auto
xw_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_w
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_r
);
auto
xw_r
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
hs
},
{
2
*
hs
}},
xt_w
);
auto
xw_h
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
2
*
hs
},
{
3
*
hs
}},
xt_w
);
instruction_ref
xht_h
{};
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
,
bwbh
);
auto
hr_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
ih1_rzr
);
if
(
linear_before_reset
==
0
)
auto
hr_r
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
hs
},
{
2
*
hs
}},
ih1_rzr
);
auto
xw_hr_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xw_z
,
hr_z
);
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xw_hr_z
);
auto
xw_hr_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xw_r
,
hr_r
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xw_hr_r
);
instruction_ref
hr_h
{};
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
sih
);
if
(
bias
!=
prog
.
end
())
{
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
hr_h
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
trh
,
brb_h
);
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
sih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
,
brbh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
}
else
else
{
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
hr_h
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
trh
);
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
,
brbh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
}
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xht_h
);
}
}
else
else
{
{
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
);
instruction_ref
ht1_rh
{};
auto
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
);
if
(
bias
!=
prog
.
end
())
auto
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wz
,
ht_rz
);
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_z
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
);
auto
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
);
auto
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wr
,
ht_rr
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_r
);
instruction_ref
xht_h
;
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
if
(
linear_before_reset
==
0
)
{
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trh
,
brb_h
);
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
sih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
}
else
else
{
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trh
);
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
}
h
t
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xht_
h
);
h
r_h
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_r
h
);
}
}
auto
xw_hr_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xw_h
,
hr_h
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xw_hr_h
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
one_minus_zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
one_minus_zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
one_minus_zt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
one_minus_zt
,
ht
);
auto
one_minus_zt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
one_minus_zt
,
ht
);
...
...
test/gpu/miopen.cpp
View file @
18fc2362
...
@@ -2650,7 +2650,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
...
@@ -2650,7 +2650,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
output
=
p
.
add_instruction
(
auto
output
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hidden_size
,
migraphx
::
op
::
lstm
{
hidden_size
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
clip
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment