Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
28f8b1e8
Commit
28f8b1e8
authored
Jan 30, 2019
by
Shucai Xiao
Browse files
fixed issues for rnn operator.
parent
c2b69817
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
72 additions
and
34 deletions
+72
-34
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+0
-3
src/program.cpp
src/program.cpp
+12
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+54
-25
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+3
-1
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+1
-1
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+2
-4
No files found.
src/include/migraphx/operators.hpp
View file @
28f8b1e8
...
@@ -395,7 +395,6 @@ struct concat
...
@@ -395,7 +395,6 @@ struct concat
}
}
return
result
;
return
result
;
}
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
struct
slice
struct
slice
...
@@ -698,8 +697,6 @@ struct gather
...
@@ -698,8 +697,6 @@ struct gather
return
result
;
return
result
;
}
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
struct
dot
struct
dot
...
...
src/program.cpp
View file @
28f8b1e8
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/time.hpp>
...
@@ -134,6 +135,17 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -134,6 +135,17 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert
(
has_instruction
(
ins
));
assert
(
has_instruction
(
ins
));
assert
(
has_instruction
(
rep
));
assert
(
has_instruction
(
rep
));
assert
(
ins
!=
rep
);
assert
(
ins
!=
rep
);
if
(
ins
==
std
::
prev
(
this
->
end
()))
{
// additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output
if
(
ins
->
name
()
==
"rnn_last_output"
)
{
return
replace_instruction
(
ins
,
op
::
identity
{},
rep
);
}
}
// TODO: Should it be an error if the output is empty?
// TODO: Should it be an error if the output is empty?
if
(
ins
->
outputs
().
empty
())
if
(
ins
->
outputs
().
empty
())
{
{
...
...
src/rewrite_rnn.cpp
View file @
28f8b1e8
...
@@ -85,15 +85,22 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -85,15 +85,22 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse
,
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
rnn_op
.
actv_funcs
.
at
(
1
));
last_output
=
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]
);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{
{
0
}
}
,
concat_output
);
// add the dimension of num_direction
// The following logic is to ensure the last instruction rewritten from
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
// rnn operator is a concat instruction
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// sequence len is 1
if
(
ret_forward
[
0
]
==
prog
.
end
())
// concat the forward and reverse output
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
}
else
else
{
{
...
@@ -125,10 +132,21 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -125,10 +132,21 @@ void rewrite_rnn::apply(program& prog) const
auto
ret
=
rnn_cell
(
auto
ret
=
rnn_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
ret
[
1
];
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]
)
;
// add the dimension of num_direction
// following logic is to ensure the last instruction is a
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
// concat instruction
// sequence len is 1
if
(
ret
[
0
]
==
prog
.
end
())
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
}
}
}
}
}
...
@@ -141,9 +159,13 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -141,9 +159,13 @@ void rewrite_rnn::apply(program& prog) const
// if rnn operator is executed, the last_output != prog.end()
// if rnn operator is executed, the last_output != prog.end()
if
(
last_output
!=
prog
.
end
())
if
(
last_output
!=
prog
.
end
())
{
{
prog
.
replace_instruction
(
ins
,
op
::
identity
{},
last_output
);
prog
.
replace_instruction
(
ins
,
last_output
);
last_output
=
prog
.
end
();
last_output
=
prog
.
end
();
}
}
else
{
MIGRAPHX_THROW
(
"RNN_LAST_OUTPUT: must put after rnn operator"
);
}
}
}
}
}
}
}
...
@@ -181,7 +203,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
...
@@ -181,7 +203,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
}
}
instruction_ref
hidden_out
,
last_out
;
instruction_ref
hidden_out
=
prog
.
end
()
,
last_out
;
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
{
{
...
@@ -205,20 +227,27 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
...
@@ -205,20 +227,27 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
sih
=
ht
;
sih
=
ht
;
// add the dimension of sequence length
// add the dimensions of sequence length (axis 0 for sequence length,
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
}},
ht
);
// axis 1 for num_directions
last_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
,
1
}},
ht
);
if
(
is_forward
)
// concatenation for the last last_out is performed in the apply()
{
// function to ensure the last instruction is concat, then we have
hidden_out
=
(
seq_index
==
0
)
// output inserted
?
last_out
if
(
i
<
seq_len
-
1
)
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last_out
);
}
else
{
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
if
(
is_forward
)
?
last_out
{
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
hidden_out
=
(
seq_index
==
0
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last_out
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
last_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last_out
,
hidden_out
);
}
}
}
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
28f8b1e8
...
@@ -54,6 +54,7 @@ struct miopen_apply
...
@@ -54,6 +54,7 @@ struct miopen_apply
program
*
prog
=
nullptr
;
program
*
prog
=
nullptr
;
context
ctx
{};
context
ctx
{};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
instruction_ref
last
{};
void
check_shape
(
shape
x
,
instruction_ref
i
)
void
check_shape
(
shape
x
,
instruction_ref
i
)
{
{
...
@@ -64,6 +65,7 @@ struct miopen_apply
...
@@ -64,6 +65,7 @@ struct miopen_apply
void
init
()
void
init
()
{
{
this
->
last
=
instruction
::
get_output_alias
(
std
::
prev
(
prog
->
end
()));
add_miopen_simple_op
<
miopen_relu
>
(
"relu"
,
make_relu
);
add_miopen_simple_op
<
miopen_relu
>
(
"relu"
,
make_relu
);
add_miopen_simple_op
<
miopen_sigmoid
>
(
"sigmoid"
,
make_sigmoid
);
add_miopen_simple_op
<
miopen_sigmoid
>
(
"sigmoid"
,
make_sigmoid
);
add_miopen_simple_op
<
miopen_abs
>
(
"abs"
,
make_abs
);
add_miopen_simple_op
<
miopen_abs
>
(
"abs"
,
make_abs
);
...
@@ -112,7 +114,7 @@ struct miopen_apply
...
@@ -112,7 +114,7 @@ struct miopen_apply
instruction_ref
insert_allocation
(
instruction_ref
ins
,
const
shape
&
s
,
std
::
string
tag
=
""
)
instruction_ref
insert_allocation
(
instruction_ref
ins
,
const
shape
&
s
,
std
::
string
tag
=
""
)
{
{
if
(
ins
==
--
prog
->
end
()
and
tag
.
empty
())
if
(
ins
==
last
and
tag
.
empty
())
{
{
return
prog
->
add_parameter
(
"output"
,
s
);
return
prog
->
add_parameter
(
"output"
,
s
);
}
}
...
...
test/cpu_ops_test.cpp
View file @
28f8b1e8
...
@@ -1623,7 +1623,7 @@ TEST_CASE(rnn_reverse)
...
@@ -1623,7 +1623,7 @@ TEST_CASE(rnn_reverse)
TEST_CASE
(
rnn_bidirectional
)
TEST_CASE
(
rnn_bidirectional
)
{
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
2
;
std
::
size_t
seq_len
=
1
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
std
::
size_t
num_dirct
=
2
;
...
...
test/gpu/miopen.cpp
View file @
28f8b1e8
...
@@ -1084,7 +1084,6 @@ struct test_rnn_forward
...
@@ -1084,7 +1084,6 @@ struct test_rnn_forward
bias
,
bias
,
ih
);
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
return
p
;
}
}
...
@@ -1124,8 +1123,6 @@ struct test_rnn_reverse
...
@@ -1124,8 +1123,6 @@ struct test_rnn_reverse
r
,
r
,
bias
,
bias
,
ih
);
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
return
p
;
}
}
...
@@ -1166,7 +1163,6 @@ struct test_rnn_bidirectional
...
@@ -1166,7 +1163,6 @@ struct test_rnn_bidirectional
bias
,
bias
,
ih
);
ih
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
auto
last
=
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
output
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
last
,
last
);
return
p
;
return
p
;
}
}
...
@@ -1232,4 +1228,6 @@ int main()
...
@@ -1232,4 +1228,6 @@ int main()
verify_program
<
test_gather
>
();
verify_program
<
test_gather
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_reverse
>
();
verify_program
<
test_rnn_bidirectional
>
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment