Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8b3027ba
Commit
8b3027ba
authored
Jun 19, 2019
by
Shucai Xiao
Browse files
more optimization of the rnn operators.
parent
93ca0082
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
22 deletions
+10
-22
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+10
-22
No files found.
src/rewrite_rnn.cpp
View file @
8b3027ba
...
@@ -208,16 +208,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
...
@@ -208,16 +208,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto
sih_lens
=
sih
->
get_shape
().
lens
();
auto
sih_lens
=
sih
->
get_shape
().
lens
();
// bias
// bias
instruction_ref
bwb
{};
instruction_ref
bb
{};
instruction_ref
brb
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
}
,
w
b
);
auto
wrb
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
r
b
);
b
r
b
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
rb
);
bb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
w
rb
);
}
}
instruction_ref
hidden_out
=
prog
.
end
();
instruction_ref
hidden_out
=
prog
.
end
();
...
@@ -229,17 +228,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
...
@@ -229,17 +228,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
xt_wi
{}
;
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
)
;
instruction_ref
ht_ri
{}
;
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
)
;
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
,
bwb
);
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
bb
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
,
brb
);
}
else
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
}
}
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
...
@@ -532,17 +525,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
...
@@ -532,17 +525,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
xt_w
{}
;
auto
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
)
;
instruction_ref
ih1_
rzr
{}
;
auto
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
t
rzr
)
;
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
,
bwb
);
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_w
,
bwb
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trzr
,
brb_zr
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ih1_rzr
,
brb_zr
);
}
else
{
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trzr
);
}
}
auto
xw_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_w
);
auto
xw_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_w
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment