Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
28f8b1e8
Commit
28f8b1e8
authored
Jan 30, 2019
by
Shucai Xiao
Browse files
fixed issues for rnn operator.
parent
c2b69817
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
72 additions
and
34 deletions
+72
-34
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+0
-3
src/program.cpp
src/program.cpp
+12
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+54
-25
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+3
-1
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+1
-1
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+2
-4
No files found.
src/include/migraphx/operators.hpp
View file @
28f8b1e8
...
@@ -395,7 +395,6 @@ struct concat
...
@@ -395,7 +395,6 @@ struct concat
}
}
return
result
;
return
result
;
}
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
struct
slice
struct
slice
...
@@ -698,8 +697,6 @@ struct gather
...
@@ -698,8 +697,6 @@ struct gather
return
result
;
return
result
;
}
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
struct
dot
struct
dot
...
...
src/program.cpp
View file @
28f8b1e8
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/time.hpp>
...
@@ -134,6 +135,17 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -134,6 +135,17 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert
(
has_instruction
(
ins
));
assert
(
has_instruction
(
ins
));
assert
(
has_instruction
(
rep
));
assert
(
has_instruction
(
rep
));
assert
(
ins
!=
rep
);
assert
(
ins
!=
rep
);
if
(
ins
==
std
::
prev
(
this
->
end
()))
{
// additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output
if
(
ins
->
name
()
==
"rnn_last_output"
)
{
return
replace_instruction
(
ins
,
op
::
identity
{},
rep
);
}
}
// TODO: Should it be an error if the output is empty?
// TODO: Should it be an error if the output is empty?
if
(
ins
->
outputs
().
empty
())
if
(
ins
->
outputs
().
empty
())
{
{
...
...
src/rewrite_rnn.cpp
View file @
28f8b1e8
...
@@ -85,15 +85,22 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -85,15 +85,22 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse
,
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
rnn_op
.
actv_funcs
.
at
(
1
));
last_output
=
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]
);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{
{
0
}
}
,
concat_output
);
// add the dimension of num_direction
// The following logic is to ensure the last instruction rewritten from
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
// rnn operator is a concat instruction
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// sequence len is 1
if
(
ret_forward
[
0
]
==
prog
.
end
())
// concat the forward and reverse output
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
}
else
else
{
{
...
@@ -125,10 +132,21 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -125,10 +132,21 @@ void rewrite_rnn::apply(program& prog) const
auto
ret
=
rnn_cell
(
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
ret
[
1
];
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]
)
;
// add the dimension of num_direction
// following logic is to ensure the last instruction is a
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
// concat instruction
// sequence len is 1
if
(
ret
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
}
}
}
}
}
...
@@ -141,9 +159,13 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -141,9 +159,13 @@ void rewrite_rnn::apply(program& prog) const
// if rnn operator is executed, the last_output != prog.end()
// if rnn operator is executed, the last_output != prog.end()
if
(
last_output
!=
prog
.
end
())
if
(
last_output
!=
prog
.
end
())
{
{
prog
.
replace_instruction
(
ins
,
op
::
identity
{},
last_output
);
prog
.
replace_instruction
(
ins
,
last_output
);
last_output
=
prog
.
end
();
last_output
=
prog
.
end
();
}
}
else
{
MIGRAPHX_THROW
(
"RNN_LAST_OUTPUT: must put after rnn operator"
);
}
}
}
}
}
}
}
...
@@ -181,7 +203,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
...
@@ -181,7 +203,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
}
}
instruction_ref
hidden_out
,
last_out
;
instruction_ref
hidden_out
=
prog
.
end
()
,
last_out
;
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
{
{
...
@@ -205,20 +227,27 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
...
@@ -205,20 +227,27 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
sih
=
ht
;
sih
=
ht
;
// add the dimension of sequence length
// add the dimensions of sequence length (axis 0 for sequence length,
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
}},
ht
);
// axis 1 for num_directions
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
,
1
}},
ht
);
if
(
is_forward
)
// concatenation for the last last_out is performed in the apply()
{
// function to ensure the last instruction is concat, then we have
hidden_out
=
(
seq_index
==
0
)
// output inserted
?
last_out
if
(
i
<
seq_len
-
1
)
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last_out
);
}
else
{
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
if
(
is_forward
)
?
last_out
{
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
hidden_out
=
(
seq_index
==
0
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last_out
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
}
}
}
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
28f8b1e8
...
@@ -54,6 +54,7 @@ struct miopen_apply
...
@@ -54,6 +54,7 @@ struct miopen_apply
program
*
prog
=
nullptr
;
program
*
prog
=
nullptr
;
context
ctx
{};
context
ctx
{};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
instruction_ref
last
{};
void
check_shape
(
shape
x
,
instruction_ref
i
)
void
check_shape
(
shape
x
,
instruction_ref
i
)
{
{
...
@@ -64,6 +65,7 @@ struct miopen_apply
...
@@ -64,6 +65,7 @@ struct miopen_apply
void
init
()
void
init
()
{
{
this
->
last
=
instruction
::
get_output_alias
(
std
::
prev
(
prog
->
end
()));
add_miopen_simple_op
<
miopen_relu
>
(
"relu"
,
make_relu
);
add_miopen_simple_op
<
miopen_relu
>
(
"relu"
,
make_relu
);
add_miopen_simple_op
<
miopen_sigmoid
>
(
"sigmoid"
,
make_sigmoid
);
add_miopen_simple_op
<
miopen_sigmoid
>
(
"sigmoid"
,
make_sigmoid
);
add_miopen_simple_op
<
miopen_abs
>
(
"abs"
,
make_abs
);
add_miopen_simple_op
<
miopen_abs
>
(
"abs"
,
make_abs
);
...
@@ -112,7 +114,7 @@ struct miopen_apply
...
@@ -112,7 +114,7 @@ struct miopen_apply
instruction_ref
insert_allocation
(
instruction_ref
ins
,
const
shape
&
s
,
std
::
string
tag
=
""
)
instruction_ref
insert_allocation
(
instruction_ref
ins
,
const
shape
&
s
,
std
::
string
tag
=
""
)
{
{
if
(
ins
==
--
prog
->
end
()
and
tag
.
empty
())
if
(
ins
==
last
and
tag
.
empty
())
{
{
return
prog
->
add_parameter
(
"output"
,
s
);
return
prog
->
add_parameter
(
"output"
,
s
);
}
}
...
...
test/cpu_ops_test.cpp
View file @
28f8b1e8
...
@@ -1623,7 +1623,7 @@ TEST_CASE(rnn_reverse)
...
@@ -1623,7 +1623,7 @@ TEST_CASE(rnn_reverse)
TEST_CASE
(
rnn_bidirectional
)
TEST_CASE
(
rnn_bidirectional
)
{
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
2
;
std
::
size_t
seq_len
=
1
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
std
::
size_t
num_dirct
=
2
;
...
...
test/gpu/miopen.cpp
View file @
28f8b1e8
...
@@ -1084,7 +1084,6 @@ struct test_rnn_forward
...
@@ -1084,7 +1084,6 @@ struct test_rnn_forward
bias
,
bias
,
ih
);
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
return
p
;
}
}
...
@@ -1124,8 +1123,6 @@ struct test_rnn_reverse
...
@@ -1124,8 +1123,6 @@ struct test_rnn_reverse
r
,
r
,
bias
,
bias
,
ih
);
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
return
p
;
}
}
...
@@ -1166,7 +1163,6 @@ struct test_rnn_bidirectional
...
@@ -1166,7 +1163,6 @@ struct test_rnn_bidirectional
bias
,
bias
,
ih
);
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
return
p
;
}
}
...
@@ -1232,4 +1228,6 @@ int main()
...
@@ -1232,4 +1228,6 @@ int main()
verify_program
<
test_gather
>
();
verify_program
<
test_gather
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_reverse
>
();
verify_program
<
test_rnn_bidirectional
>
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment