Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3b64f602
Commit
3b64f602
authored
Jul 01, 2019
by
Paul
Browse files
Fix tf test
parent
ea2f0cf4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
132 additions
and
9 deletions
+132
-9
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+19
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+28
-6
test/shape_test.cpp
test/shape_test.cpp
+10
-1
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+75
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
3b64f602
...
@@ -266,7 +266,7 @@ struct folder
...
@@ -266,7 +266,7 @@ struct folder
bool
matches
=
Start
;
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
select
(
start
,
[
&
](
auto
ins
)
{
matches
=
op
(
matches
,
fold
([
&
](
auto
x
,
auto
y
)
{
matches
=
op
(
matches
,
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
=
=
ctx
.
not_found
());
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
!
=
ctx
.
not_found
());
})(
Start
,
ms
...));
})(
Start
,
ms
...));
});
});
if
(
matches
==
Matches
)
if
(
matches
==
Matches
)
...
@@ -310,7 +310,7 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
...
@@ -310,7 +310,7 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
return
ins
->
get_shape
().
transposed
();
return
ins
->
get_shape
().
transposed
();
}
}
MIGRAPHX_PRED_MATCHER
(
same_shapes
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
same_
input_
shapes
,
instruction_ref
ins
)
{
{
if
(
ins
->
inputs
().
empty
())
if
(
ins
->
inputs
().
empty
())
return
false
;
return
false
;
...
@@ -413,6 +413,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
...
@@ -413,6 +413,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
};
}
}
template
<
class
M
>
auto
same_shape
(
M
m
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
i
=
m
.
match
(
ctx
,
ins
);
if
(
i
!=
ctx
.
not_found
()
and
i
->
get_shape
()
==
ins
->
get_shape
())
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
...
Ms
>
auto
same_shape
(
Ms
...
ms
)
{
return
all_of
(
same_shape
(
ms
)...);
}
}
// namespace match
}
// namespace match
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/simplify_reshapes.cpp
View file @
3b64f602
...
@@ -122,6 +122,25 @@ struct find_reshaper
...
@@ -122,6 +122,25 @@ struct find_reshaper
}
}
};
};
struct
find_nop_reshapes
{
auto
matcher
()
const
{
auto
reshapes
=
reshaper_names
();
reshapes
.
insert
(
"transpose"
);
reshapes
.
insert
(
"slice"
);
return
match
::
name
(
reshapes
)(
match
::
same_shape
(
match
::
arg
(
0
))
);
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
auto
ins
=
mr
.
result
;
p
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
}
};
struct
find_transpose
struct
find_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -145,6 +164,7 @@ struct find_transpose
...
@@ -145,6 +164,7 @@ struct find_transpose
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
return
;
return
;
p
.
debug_print
();
if
(
is_no_transpose
(
dims
))
if
(
is_no_transpose
(
dims
))
{
{
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
...
@@ -160,7 +180,7 @@ struct find_concat_transpose
...
@@ -160,7 +180,7 @@ struct find_concat_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
same_shapes
(),
return
match
::
name
(
"concat"
)(
match
::
same_
input_
shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
}
...
@@ -168,19 +188,21 @@ struct find_concat_transpose
...
@@ -168,19 +188,21 @@ struct find_concat_transpose
{
{
auto
ins
=
mr
.
result
;
auto
ins
=
mr
.
result
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
assert
(
s
.
transposed
());
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
permutation
=
find_permutation
(
s
);
auto
permutation
=
find_permutation
(
s
);
auto
ipermutaion
=
invert_permutation
(
permutation
);
auto
ipermuta
t
ion
=
invert_permutation
(
permutation
);
op
.
axis
=
i
permutaion
[
op
.
axis
];
op
.
axis
=
permuta
t
ion
[
op
.
axis
];
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
if
(
i
->
name
()
==
"transpose"
and
i
->
inputs
().
front
()
->
get_shape
().
standard
())
return
i
->
inputs
().
front
();
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
});
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
i
permutaion
},
concat
);
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permuta
t
ion
},
concat
);
p
.
replace_instruction
(
ins
,
t
);
p
.
replace_instruction
(
ins
,
t
);
}
}
};
};
...
@@ -195,7 +217,7 @@ void simplify_reshapes::apply(program& p) const
...
@@ -195,7 +217,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
continue
;
match
::
find_matches
(
p
,
ins
,
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
match
::
find_matches
(
p
,
ins
,
find_nop_reshapes
{},
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
}
}
}
}
...
...
test/shape_test.cpp
View file @
3b64f602
...
@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
...
@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
not
s
.
broadcasted
());
}
}
TEST_CASE
(
test_shape_transposed
)
TEST_CASE
(
test_shape_transposed
1
)
{
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
standard
());
...
@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
...
@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
not
s
.
broadcasted
());
}
}
TEST_CASE
(
test_shape_transposed2
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
,
2
},
{
2
,
2
,
2
,
2
,
1
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_broadcasted
)
TEST_CASE
(
test_shape_broadcasted
)
{
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
0
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
0
}};
...
...
test/simplify_reshapes_test.cpp
View file @
3b64f602
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
#include <test.hpp>
...
@@ -213,4 +214,78 @@ TEST_CASE(transpose_partial3)
...
@@ -213,4 +214,78 @@ TEST_CASE(transpose_partial3)
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
3
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
3
);
}
}
TEST_CASE
(
nop_transpose1
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
x
);
p
.
add_instruction
(
pass_op
{},
t
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
}
TEST_CASE
(
nop_transpose2
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
x
);
auto
t2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
t1
);
auto
t3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
t2
);
auto
t4
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
t3
);
p
.
add_instruction
(
pass_op
{},
t4
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
4
);
}
TEST_CASE
(
nop_transpose3
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
3
},
x
,
y
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
,
3
}},
concat
);
auto
t2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
t1
);
p
.
add_instruction
(
pass_op
{},
t2
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
}
TEST_CASE
(
concat_transpose1
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
xt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
x
);
auto
yt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
y
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
2
},
xt
,
yt
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
concat
);
p
.
add_instruction
(
pass_op
{},
t
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
3
);
auto
new_concat
=
std
::
find_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
p
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
3
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment