Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
15eb1987
Unverified
Commit
15eb1987
authored
Jun 21, 2019
by
mvermeulen
Committed by
GitHub
Jun 21, 2019
Browse files
Merge pull request #273 from ROCmSoftwarePlatform/rnn_optimization
Rnn optimization
parents
f93eeca3
67c6e634
Pipeline
#672
failed with stages
in 0 seconds
Changes
6
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
121 additions
and
182 deletions
+121
-182
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+3
-1
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+2
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+10
-3
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+7
-1
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+94
-173
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+5
-4
No files found.
src/include/migraphx/op/binary.hpp
View file @
15eb1987
...
...
@@ -28,8 +28,10 @@ struct binary : op_name<Derived>
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
s1
=
args
[
0
].
get_shape
();
auto
s2
=
args
[
1
].
get_shape
();
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
if
(
input1
.
get_shape
().
packed
()
and
input2
.
get_shape
().
packed
())
if
(
s1
==
s2
and
input1
.
get_shape
().
packed
()
and
input2
.
get_shape
().
packed
())
{
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
...
...
src/include/migraphx/stringutils.hpp
View file @
15eb1987
...
...
@@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)
inline
std
::
string
to_upper
(
std
::
string
s
)
{
return
transform_string
(
std
::
move
(
s
),
::
toupper
);
}
inline
std
::
string
to_lower
(
std
::
string
s
)
{
return
transform_string
(
std
::
move
(
s
),
::
tolower
);
}
inline
bool
starts_with
(
const
std
::
string
&
value
,
const
std
::
string
&
prefix
)
{
if
(
prefix
.
size
()
>
value
.
size
())
...
...
src/onnx/onnx.cpp
View file @
15eb1987
...
...
@@ -100,6 +100,7 @@ struct onnx_parser
void
init_actv_func
()
{
// Support name format of all lower case or the first letter capital
map_actv_funcs
.
insert
(
std
::
make_pair
(
"tanh"
,
op
::
tanh
{}));
map_actv_funcs
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
map_actv_funcs
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
...
...
@@ -871,7 +872,9 @@ struct onnx_parser
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
name
)
{
return
to_lower
(
name
);
});
}
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
...
...
@@ -962,7 +965,9 @@ struct onnx_parser
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
name
)
{
return
to_lower
(
name
);
});
}
// need 4 activation functions
...
...
@@ -1089,7 +1094,9 @@ struct onnx_parser
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
name
)
{
return
to_lower
(
name
);
});
}
// need 6 activation functions for bidirectional directions
...
...
src/py/migraphx_py.cpp
View file @
15eb1987
...
...
@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
...
...
@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
t
=
as
.
type_enum
();
n
=
sizeof
(
as
());
}
});
if
(
n
==
0
)
{
MIGRAPHX_THROW
(
"MIGRAPHX PYTHON: Unsupported data type"
+
info
.
format
);
}
auto
strides
=
info
.
strides
;
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
->
std
::
size_t
{
return
n
>
0
?
i
/
n
:
0
;
...
...
src/rewrite_rnn.cpp
View file @
15eb1987
...
...
@@ -204,17 +204,19 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto
tran_sr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sr
);
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
auto
sih_lens
=
sih
->
get_shape
().
lens
();
// bias
instruction_ref
bb
{};
if
(
bias
!=
prog
.
end
())
{
long
hs
=
r
->
get_shape
().
lens
()[
2
];
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]
)
;
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
b
ias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
().
lens
()
},
b
);
auto
wrb
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
b
b
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
_
lens
},
wr
b
);
}
instruction_ref
hidden_out
=
prog
.
end
();
...
...
@@ -228,20 +230,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
instruction_ref
ht
;
if
(
bias
!=
prog
.
end
())
{
ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_ht
,
bias
);
}
else
{
ht
=
xt_ht
;
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
bb
);
}
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
// apply activation function
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
sih
=
ht
;
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
xt_
ht
);
sih
=
ht
;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
...
...
@@ -485,62 +482,41 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
long
hs
=
static_cast
<
long
>
(
r_shape
.
lens
()[
2
]);
migraphx
::
shape
s
(
seq_shape
.
type
(),
{
seq_shape
.
lens
()[
1
],
r_shape
.
lens
()[
2
]});
std
::
vector
<
in
t
>
data
(
s
.
elements
(),
1
);
std
::
vector
<
floa
t
>
data
(
s
.
elements
(),
1
.0
f
);
auto
l1
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
// w
eight
matrix
// w matrix
squeeze to 2-dim and do a transpose
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sw
);
auto
tran_wz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
tw
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sw
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sw
);
auto
tran_wr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
// r slide to two part, zr and h
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
rzr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
2
*
hs
}},
sr
);
auto
trzr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rzr
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sw
);
auto
tran_wh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sr
);
auto
tran_rz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sr
);
auto
tran_rr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
tran_rh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
trh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
// initial states
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
size_t
bs
=
ih
->
get_shape
().
lens
()[
1
];
// bias
instruction_ref
brcst_bz
{};
instruction_ref
brcst_br
{};
instruction_ref
brcst_wbh
{};
instruction_ref
brcst_rbh
{};
instruction_ref
brcst_bh
{};
instruction_ref
bwb
{};
instruction_ref
brb_zr
{};
instruction_ref
brb_h
{};
if
(
bias
!=
prog
.
end
())
{
auto
broadcast_lens
=
sih
->
get_shape
().
lens
();
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
brcst_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
broadcast_lens
},
wbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
brcst_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
broadcast_lens
},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
brcst_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
broadcast_lens
},
bz
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
brcst_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
broadcast_lens
},
br
);
auto
bh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbh
,
rbh
);
brcst_bh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
broadcast_lens
},
bh
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
3
*
hs
}},
sbias
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
static_cast
<
size_t
>
(
3
*
hs
)}},
wb
);
auto
rb_zr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
5
*
hs
}},
sbias
);
auto
rb_h
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
brb_zr
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
static_cast
<
size_t
>
(
2
*
hs
)}},
rb_zr
);
brb_h
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
static_cast
<
size_t
>
(
hs
)}},
rb_h
);
}
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
...
...
@@ -549,56 +525,58 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
);
auto
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
);
auto
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wz
,
ht_rz
);
auto
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
);
auto
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trzr
);
if
(
bias
!=
prog
.
end
())
{
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_z
,
brcst_bz
);
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_w
,
bwb
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ih1_rzr
,
brb_zr
);
}
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_z
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
x
t_
wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran
_w
r
);
auto
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
);
auto
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wr
,
ht_rr
);
if
(
bias
!=
prog
.
end
())
{
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_r
,
brcst_br
);
}
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
ht_r
);
auto
xw_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_w
);
auto
xw
_
r
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
hs
},
{
2
*
hs
}},
xt
_w
);
auto
xw_h
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
2
*
hs
},
{
3
*
hs
}},
xt_w
);
auto
hr_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
ih1_rzr
);
auto
hr_r
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
hs
},
{
2
*
hs
}},
ih1_rzr
);
auto
xw_hr_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xw_z
,
hr_z
);
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
w_hr_z
);
instruction_ref
xht_h
;
auto
xw_hr_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xw_r
,
hr_r
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xw_hr_r
);
instruction_ref
hr_h
{};
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
sih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
{
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_h
,
brcst_bh
);
hr_h
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
trh
,
brb_h
);
}
else
{
hr_h
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
trh
);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
);
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
);
instruction_ref
ht1_rh
{};
if
(
bias
!=
prog
.
end
())
{
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ht1_rh
,
brcst_rb
h
);
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trh
,
brb_
h
);
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
else
{
x
ht
_
h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xht_h
,
brcst_wb
h
);
ht
1_r
h
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tr
h
);
}
hr_h
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
}
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xht_h
);
auto
xw_hr_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xw_h
,
hr_h
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xw_hr_h
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
one_minus_zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
...
...
@@ -913,35 +891,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
migraphx
::
shape
r_shape
=
r
->
get_shape
();
long
seq_len
=
static_cast
<
long
>
(
seq_shape
.
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r_shape
.
lens
()[
2
]);
auto
bs
=
ih
->
get_shape
().
lens
()[
1
];
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
// w matrix
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
wi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sw
);
auto
tran_wi
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wi
);
auto
wo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sw
);
auto
tran_wo
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wo
);
auto
wf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sw
);
auto
tran_wf
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wf
);
// w matrix, squeeze and transpose
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
tsw
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sw
);
auto
wc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sw
);
auto
tran_wc
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wc
);
// r matrix
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
ri
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sr
);
auto
tran_ri
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
ri
);
auto
ro
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sr
);
auto
tran_ro
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
ro
);
auto
rf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
tran_rf
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rf
);
auto
rc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sr
);
auto
tran_rc
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rc
);
// r matrix, squeeze and transpose
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
tsr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sr
);
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
...
...
@@ -951,40 +910,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
ic_lens
=
sic
->
get_shape
().
lens
();
// bias
instruction_ref
bi_brcst
{};
instruction_ref
bo_brcst
{};
instruction_ref
bf_brcst
{};
instruction_ref
bc_brcst
{};
instruction_ref
wrb
{};
if
(
bias
!=
prog
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
bxi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
bhi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
bi
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxi
,
bhi
);
bi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
bi
);
auto
bxo
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
bho
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
auto
bo
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxo
,
bho
);
bo_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
bo
);
auto
bxf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
bhf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
6
*
hs
},
{
7
*
hs
}},
sbias
);
auto
bf
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxf
,
bhf
);
bf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
bf
);
auto
bxc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
bhc
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
7
*
hs
},
{
8
*
hs
}},
sbias
);
auto
bc
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
bxc
,
bhc
);
bc_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_lens
},
bc
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
ub_wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
4
*
hs
}},
sbias
);
auto
ub_rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
8
*
hs
}},
sbias
);
auto
ub_wrb
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ub_wb
,
ub_rb
);
wrb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}},
ub_wrb
);
}
// peep hole
instruction_ref
pphi_brcst
{};
instruction_ref
ppho_brcst
{};
instruction_ref
pphf_brcst
{};
if
(
pph
!=
prog
.
end
())
{
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
...
...
@@ -1004,44 +946,31 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wi
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ri
);
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
if
(
pph
!=
prog
.
end
())
{
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
pphi_ct
);
}
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tsw
);
auto
sih_tsr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tsr
);
auto
xt_sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_tsw
,
sih_tsr
);
if
(
bias
!=
prog
.
end
())
{
i
t_
before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
i
t_
before_actv
,
bi_brcst
);
x
t_
sih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
sih
,
wrb
);
}
auto
it
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto
xt_wf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wf
);
auto
ht_rf
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rf
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wf
,
ht_rf
);
auto
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_sih
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
hs
},
{
2
*
hs
}},
xt_sih
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
2
*
hs
},
{
3
*
hs
}},
xt_sih
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
3
*
hs
},
{
4
*
hs
}},
xt_sih
);
if
(
pph
!=
prog
.
end
())
{
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphi_brcst
,
sic
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
it_before_actv
,
pphi_ct
);
auto
pphf_ct
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
pphf_brcst
,
sic
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
pphf_ct
);
}
if
(
bias
!=
prog
.
end
())
{
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_before_actv
,
bf_brcst
);
}
auto
it
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
auto
ft
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ft_before_actv
);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto
xt_wc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wc
);
auto
ht_rc
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rc
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wc
,
ht_rc
);
if
(
bias
!=
prog
.
end
())
{
ct_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ct_before_actv
,
bc_brcst
);
}
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
// equation Ct = ft (.) Ct-1 + it (.) ct
...
...
@@ -1050,19 +979,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
cellt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ft_cell
,
it_ct
);
last_cell_output
=
cellt
;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto
xt_wo
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wo
);
auto
ht_ro
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_ro
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wo
,
ht_ro
);
if
(
pph
!=
prog
.
end
())
{
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
ppho_brcst
,
cellt
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
ppho_cellt
);
}
if
(
bias
!=
prog
.
end
())
{
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ot_before_actv
,
bo_brcst
);
}
auto
ot
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ot_before_actv
);
// Ht = ot (.) h(Ct)
...
...
test/gpu/miopen.cpp
View file @
15eb1987
...
...
@@ -2666,10 +2666,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
output
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hidden_size
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
migraphx
::
op
::
lstm
{
hidden_size
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
seq
,
w
,
r
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment