Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8b3027ba
Commit
8b3027ba
authored
Jun 19, 2019
by
Shucai Xiao
Browse files
more optimization of the rnn operators.
parent
93ca0082
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
22 deletions
+10
-22
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+10
-22
No files found.
src/rewrite_rnn.cpp
View file @
8b3027ba
...
@@ -208,16 +208,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
...
@@ -208,16 +208,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto
sih_lens
=
sih
->
get_shape
().
lens
();
auto
sih_lens
=
sih
->
get_shape
().
lens
();
// bias
// bias
instruction_ref
bwb
{};
instruction_ref
bb
{};
instruction_ref
brb
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
bwb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
}
,
w
b
);
auto
wrb
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
r
b
);
b
r
b
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
rb
);
bb
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih_lens
},
w
rb
);
}
}
instruction_ref
hidden_out
=
prog
.
end
();
instruction_ref
hidden_out
=
prog
.
end
();
...
@@ -229,17 +228,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
...
@@ -229,17 +228,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
xt_wi
{}
;
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
)
;
instruction_ref
ht_ri
{}
;
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
)
;
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
,
bwb
);
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
bb
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
,
brb
);
}
else
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tran_sw
);
ht_ri
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
tran_sr
);
}
}
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_wi
,
ht_ri
);
...
@@ -532,17 +525,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
...
@@ -532,17 +525,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
seq
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
instruction_ref
xt_w
{}
;
auto
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
)
;
instruction_ref
ih1_
rzr
{}
;
auto
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
t
rzr
)
;
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
,
bwb
);
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xt_w
,
bwb
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trzr
,
brb_zr
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ih1_rzr
,
brb_zr
);
}
else
{
xt_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
tw
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
sih
,
trzr
);
}
}
auto
xw_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_w
);
auto
xw_z
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
1
},
{
0
},
{
hs
}},
xt_w
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment