Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3c45f2ed
Commit
3c45f2ed
authored
Feb 05, 2019
by
Paul
Browse files
Formatting
parent
9b8d62d1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
10 deletions
+10
-10
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+8
-8
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+2
-2
No files found.
src/simplify_reshapes.cpp
View file @
3c45f2ed
...
@@ -22,18 +22,18 @@ bool is_reshaper(instruction_ref ins)
...
@@ -22,18 +22,18 @@ bool is_reshaper(instruction_ref ins)
bool
is_transpose_output
(
instruction_ref
ins
)
bool
is_transpose_output
(
instruction_ref
ins
)
{
{
if
(
ins
->
outputs
().
size
()
!=
1
)
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
is_transpose_output
(
ins
->
outputs
().
front
());
return
is_transpose_output
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
}
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
{
{
if
(
ins
->
inputs
().
size
()
!=
1
)
if
(
ins
->
inputs
().
size
()
!=
1
)
return
ins
;
return
ins
;
if
(
ins
->
inputs
().
front
()
->
name
()
==
"contiguous"
)
if
(
ins
->
inputs
().
front
()
->
name
()
==
"contiguous"
)
return
find_transpose_input
(
ins
->
inputs
().
front
());
return
find_transpose_input
(
ins
->
inputs
().
front
());
if
(
ins
->
inputs
().
front
()
->
name
()
==
"transpose"
)
if
(
ins
->
inputs
().
front
()
->
name
()
==
"transpose"
)
return
ins
->
inputs
().
front
();
return
ins
->
inputs
().
front
();
...
@@ -47,7 +47,7 @@ void simplify_reshapes::apply(program& p) const
...
@@ -47,7 +47,7 @@ void simplify_reshapes::apply(program& p) const
{
{
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
continue
;
if
(
is_reshaper
(
ins
))
if
(
is_reshaper
(
ins
))
{
{
if
(
std
::
any_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
&
is_reshaper
))
if
(
std
::
any_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
&
is_reshaper
))
continue
;
continue
;
...
@@ -78,9 +78,9 @@ void simplify_reshapes::apply(program& p) const
...
@@ -78,9 +78,9 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
}
}
else
if
(
ins
->
name
()
==
"transpose"
)
else
if
(
ins
->
name
()
==
"transpose"
)
{
{
if
(
is_transpose_output
(
ins
))
if
(
is_transpose_output
(
ins
))
continue
;
continue
;
auto
x
=
ins
;
auto
x
=
ins
;
auto
t
=
ins
;
auto
t
=
ins
;
...
@@ -89,7 +89,7 @@ void simplify_reshapes::apply(program& p) const
...
@@ -89,7 +89,7 @@ void simplify_reshapes::apply(program& p) const
x
=
t
;
x
=
t
;
t
=
find_transpose_input
(
x
);
t
=
find_transpose_input
(
x
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
continue
;
continue
;
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
}
...
...
test/simplify_reshapes_test.cpp
View file @
3c45f2ed
...
@@ -142,7 +142,7 @@ TEST_CASE(transpose_contiguous)
...
@@ -142,7 +142,7 @@ TEST_CASE(transpose_contiguous)
auto
c1
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
auto
c1
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_op
{},
c1
);
p
.
add_instruction
(
pass_op
{},
c1
);
auto
out_shape
=
p
.
get_shape
();
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
);
...
@@ -158,7 +158,7 @@ TEST_CASE(transpose_double_contiguous)
...
@@ -158,7 +158,7 @@ TEST_CASE(transpose_double_contiguous)
auto
c2
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
c1
);
auto
c2
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
c1
);
p
.
add_instruction
(
pass_op
{},
c2
);
p
.
add_instruction
(
pass_op
{},
c2
);
auto
out_shape
=
p
.
get_shape
();
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment