Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
015631a1
Unverified
Commit
015631a1
authored
Aug 28, 2019
by
mvermeulen
Committed by
GitHub
Aug 28, 2019
Browse files
Merge pull request #338 from ROCmSoftwarePlatform/eliminate-more-contiguous
Eliminate more contiguous
parents
a1c7e7a5
f1de9bc1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
3 deletions
+48
-3
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+10
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+1
-0
test/eliminate_contiguous_test.cpp
test/eliminate_contiguous_test.cpp
+37
-3
No files found.
src/eliminate_contiguous.cpp
View file @
015631a1
...
...
@@ -4,6 +4,8 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <utility>
namespace
migraphx
{
...
...
@@ -82,6 +84,14 @@ void eliminate_contiguous::apply(program& p) const
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
else
if
(
prev
->
can_eval
())
{
auto
c
=
op
::
contiguous
{};
auto
r
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
p
.
replace_instruction
(
arg
,
l
);
}
}
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
015631a1
...
...
@@ -87,6 +87,7 @@ struct miopen_apply
void
init
()
{
this
->
last
=
instruction
::
get_output_alias
(
std
::
prev
(
prog
->
end
()));
add_miopen_simple_op
<
miopen_abs
>
(
"abs"
,
make_abs
);
add_miopen_extend_op
<
miopen_leaky_relu
,
op
::
leaky_relu
>
(
"leaky_relu"
,
make_leaky_relu
);
...
...
test/eliminate_contiguous_test.cpp
View file @
015631a1
...
...
@@ -22,7 +22,7 @@ struct eliminate_contiguous_target
TEST_CASE
(
standard_op
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_
literal
(
get_2x2
()
);
auto
l
=
p
.
add_
parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}}
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_standard_op
{},
c
);
...
...
@@ -31,18 +31,40 @@ TEST_CASE(standard_op)
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
TEST_CASE
(
non_
standard_op
)
TEST_CASE
(
standard_op
_const
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_standard_op
{},
c
);
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
}
TEST_CASE
(
non_standard_op
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}});
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_op
{},
c
);
auto
count
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
TEST_CASE
(
non_standard_op_const
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_op
{},
c
);
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
}
TEST_CASE
(
transpose_gemm
)
{
migraphx
::
program
p
;
...
...
@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm)
TEST_CASE
(
transpose_standard_op
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_
literal
(
get_2x2
()
);
auto
l
=
p
.
add_
parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}}
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
auto
sn
=
p
.
add_instruction
(
migraphx
::
op
::
sin
{},
c
);
...
...
@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op)
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
TEST_CASE
(
transpose_standard_op_const
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
auto
sn
=
p
.
add_instruction
(
migraphx
::
op
::
sin
{},
c
);
p
.
add_instruction
(
pass_standard_op
{},
sn
);
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
3
);
}
TEST_CASE
(
no_packed_unary_op
)
{
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment