Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60d8b962
Commit
60d8b962
authored
Apr 08, 2019
by
Shucai Xiao
Browse files
fix errors for rnn optimization.
parent
dd26f1aa
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
187 deletions
+24
-187
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+23
-17
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+1
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+0
-169
No files found.
src/rewrite_rnn.cpp
View file @
60d8b962
...
...
@@ -513,19 +513,25 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
instruction_ref
wbz
{},
rbz
{};
instruction_ref
wbr
{},
rbr
{};
instruction_ref
wbh
{},
rbh
{};
instruction_ref
b
wbz
{},
b
rbz
{};
instruction_ref
b
wbr
{},
b
rbr
{};
instruction_ref
b
wbh
{},
b
rbh
{};
if
(
bias
!=
prog
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
bwbz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wbz
);
bwbr
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wbr
);
bwbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wbh
);
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
sbias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
sbias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
sbias
);
brbz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
rbz
);
brbr
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
rbr
);
brbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
rbh
);
}
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
...
...
@@ -539,30 +545,30 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
if
(
bias
!=
prog
.
end
())
{
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
,
wbz
);
auto
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
,
rbz
);
auto
xt_wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wz
,
b
wbz
);
auto
ht_rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rz
,
b
rbz
);
auto
xht_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wz
,
ht_rz
);
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_z
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
,
wbr
);
auto
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
,
rbr
);
auto
xt_wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wr
,
b
wbr
);
auto
ht_rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rr
,
b
rbr
);
auto
xht_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wr
,
ht_rr
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xht_r
);
instruction_ref
xht_h
{};
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
,
wbh
);
auto
xt_wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_wh
,
b
wbh
);
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
sih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
,
rbh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht1
,
tran_rh
,
b
rbh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
,
rbh
);
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_rh
,
b
rbh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ht1_rh
);
xht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wh
,
rt_rh
);
}
...
...
test/onnx/onnx_test.cpp
View file @
60d8b962
...
...
@@ -566,7 +566,7 @@ TEST_CASE(gemm_test)
auto
t0
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l0
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l1
);
auto
alpha
=
2.
f
;
auto
beta
=
2
.0
f
;
auto
beta
=
1
.0
f
;
p
.
add_instruction
(
migraphx
::
op
::
dot
{
alpha
,
beta
},
t0
,
t1
);
auto
prog
=
migraphx
::
parse_onnx
(
"gemm_test.onnx"
);
...
...
test/op_shape_test.cpp
View file @
60d8b962
...
...
@@ -554,175 +554,6 @@ TEST_CASE(gemm)
}
}
// 3 input arguments
TEST_CASE
(
gemm
)
{
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
4
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
4
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
4
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
3
,
8
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
3
,
7
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}},
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
,
8
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
1
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
5
,
8
}};
migraphx
::
shape
s_m3
{
migraphx
::
shape
::
float_type
};
throws_shape
(
migraphx
::
op
::
dot
{},
s_m1
,
s_m2
,
s_m3
);
}
}
TEST_CASE
(
rnn
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment