Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
28f8b1e8
Commit
28f8b1e8
authored
Jan 30, 2019
by
Shucai Xiao
Browse files
fixed issues for rnn operator.
parent
c2b69817
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
72 additions
and
34 deletions
+72
-34
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+0
-3
src/program.cpp
src/program.cpp
+12
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+54
-25
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+3
-1
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+1
-1
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+2
-4
No files found.
src/include/migraphx/operators.hpp
View file @
28f8b1e8
...
...
@@ -395,7 +395,6 @@ struct concat
}
return
result
;
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
slice
...
...
@@ -698,8 +697,6 @@ struct gather
return
result
;
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
dot
...
...
src/program.cpp
View file @
28f8b1e8
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
...
...
@@ -134,6 +135,17 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert
(
has_instruction
(
ins
));
assert
(
has_instruction
(
rep
));
assert
(
ins
!=
rep
);
if
(
ins
==
std
::
prev
(
this
->
end
()))
{
// additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output
if
(
ins
->
name
()
==
"rnn_last_output"
)
{
return
replace_instruction
(
ins
,
op
::
identity
{},
rep
);
}
}
// TODO: Should it be an error if the output is empty?
if
(
ins
->
outputs
().
empty
())
{
...
...
src/rewrite_rnn.cpp
View file @
28f8b1e8
...
...
@@ -85,16 +85,23 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]
);
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{
{
0
}
}
,
concat_output
);
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// concat the forward and reverse output
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if
(
ret_forward
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
else
{
bool
is_forward
=
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
forward
);
...
...
@@ -125,10 +132,21 @@ void rewrite_rnn::apply(program& prog) const
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
ret
[
1
];
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]
)
;
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if
(
ret
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
}
}
}
...
...
@@ -141,9 +159,13 @@ void rewrite_rnn::apply(program& prog) const
// if rnn operator is executed, the last_output != prog.end()
if
(
last_output
!=
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
identity
{},
last_output
);
prog
.
replace_instruction
(
ins
,
last_output
);
last_output
=
prog
.
end
();
}
else
{
MIGRAPHX_THROW
(
"RNN_LAST_OUTPUT: must put after rnn operator"
);
}
}
}
}
...
...
@@ -181,7 +203,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
}
instruction_ref
hidden_out
,
last_out
;
instruction_ref
hidden_out
=
prog
.
end
()
,
last_out
;
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
{
...
...
@@ -205,9 +227,15 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
sih
=
ht
;
// add the dimension of sequence length
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
}},
ht
);
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
,
1
}},
ht
);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
// output inserted
if
(
i
<
seq_len
-
1
)
{
if
(
is_forward
)
{
hidden_out
=
(
seq_index
==
0
)
...
...
@@ -221,6 +249,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
}
}
}
std
::
vector
<
instruction_ref
>
out_args
;
out_args
.
push_back
(
hidden_out
);
...
...
src/targets/gpu/lowering.cpp
View file @
28f8b1e8
...
...
@@ -54,6 +54,7 @@ struct miopen_apply
program
*
prog
=
nullptr
;
context
ctx
{};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
instruction_ref
last
{};
void
check_shape
(
shape
x
,
instruction_ref
i
)
{
...
...
@@ -64,6 +65,7 @@ struct miopen_apply
void
init
()
{
this
->
last
=
instruction
::
get_output_alias
(
std
::
prev
(
prog
->
end
()));
add_miopen_simple_op
<
miopen_relu
>
(
"relu"
,
make_relu
);
add_miopen_simple_op
<
miopen_sigmoid
>
(
"sigmoid"
,
make_sigmoid
);
add_miopen_simple_op
<
miopen_abs
>
(
"abs"
,
make_abs
);
...
...
@@ -112,7 +114,7 @@ struct miopen_apply
instruction_ref
insert_allocation
(
instruction_ref
ins
,
const
shape
&
s
,
std
::
string
tag
=
""
)
{
if
(
ins
==
--
prog
->
end
()
and
tag
.
empty
())
if
(
ins
==
last
and
tag
.
empty
())
{
return
prog
->
add_parameter
(
"output"
,
s
);
}
...
...
test/cpu_ops_test.cpp
View file @
28f8b1e8
...
...
@@ -1623,7 +1623,7 @@ TEST_CASE(rnn_reverse)
TEST_CASE
(
rnn_bidirectional
)
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
2
;
std
::
size_t
seq_len
=
1
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
...
...
test/gpu/miopen.cpp
View file @
28f8b1e8
...
...
@@ -1084,7 +1084,6 @@ struct test_rnn_forward
bias
,
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
}
...
...
@@ -1124,8 +1123,6 @@ struct test_rnn_reverse
r
,
bias
,
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
}
...
...
@@ -1166,7 +1163,6 @@ struct test_rnn_bidirectional
bias
,
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
}
...
...
@@ -1232,4 +1228,6 @@ int main()
verify_program
<
test_gather
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_reverse
>
();
verify_program
<
test_rnn_bidirectional
>
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment