Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7f4c7809
Commit
7f4c7809
authored
Jun 30, 2019
by
Paul
Browse files
Formatting
parent
2e51006e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
67 deletions
+62
-67
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+18
-18
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+14
-19
test/matcher.cpp
test/matcher.cpp
+30
-30
No files found.
src/include/migraphx/matcher.hpp
View file @
7f4c7809
...
...
@@ -21,7 +21,7 @@ struct matcher_context
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
template
<
class
M
>
template
<
class
M
>
bool
matched
(
M
m
,
instruction_ref
ins
)
{
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
...
...
@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms)
}
}
template
<
class
Op
,
bool
Start
,
bool
Matches
>
template
<
class
Op
,
bool
Start
,
bool
Matches
>
struct
folder
{
template
<
class
...
Ts
>
...
...
@@ -257,7 +257,7 @@ struct folder
});
}
template
<
class
Selector
>
template
<
class
Selector
>
auto
operator
[](
Selector
select
)
const
{
return
[
=
](
auto
...
ms
)
{
...
...
@@ -266,8 +266,8 @@ struct folder
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
matches
=
op
(
matches
,
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
());
})(
Start
,
ms
...));
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
());
})(
Start
,
ms
...));
});
if
(
matches
==
Matches
)
return
start
;
...
...
@@ -277,14 +277,14 @@ struct folder
}
};
const
constexpr
auto
all_of
=
folder
<
std
::
logical_and
<
bool
>
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
true
>
{};
const
constexpr
auto
all_of
=
folder
<
std
::
logical_and
<
bool
>
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
false
>
{};
inline
auto
inputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
inputs
())
for
(
auto
&&
x
:
ins
->
inputs
())
f
(
x
);
};
}
...
...
@@ -292,7 +292,7 @@ inline auto inputs()
inline
auto
outputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
outputs
())
for
(
auto
&&
x
:
ins
->
outputs
())
f
(
x
);
};
}
...
...
@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER
(
same_shapes
,
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
empty
())
if
(
ins
->
inputs
().
empty
())
return
false
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
...
...
@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
}
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
{
auto
m
=
any_of
(
ms
...);
...
...
@@ -345,10 +344,10 @@ auto skip_output(Ms... ms)
if
(
ins
->
outputs
().
size
()
==
1
)
{
auto
next
=
ins
->
outputs
().
front
();
if
(
ctx
.
matched
(
m
,
next
))
if
(
ctx
.
matched
(
m
,
next
))
{
auto
skipped_next
=
self
(
next
);
if
(
skipped_next
!=
ctx
.
not_found
())
if
(
skipped_next
!=
ctx
.
not_found
())
return
skipped_next
;
}
return
next
;
...
...
@@ -366,8 +365,9 @@ inline auto name(std::string s)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
return
make_basic_pred_matcher
(
[
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
});
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
});
}
inline
auto
nargs
(
std
::
size_t
n
)
...
...
src/simplify_reshapes.cpp
View file @
7f4c7809
...
...
@@ -25,10 +25,7 @@ const auto& reshaper_names()
return
names
;
}
bool
is_reshaper
(
instruction_ref
ins
)
{
return
contains
(
reshaper_names
(),
ins
->
name
());
}
bool
is_reshaper
(
instruction_ref
ins
)
{
return
contains
(
reshaper_names
(),
ins
->
name
());
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
{
...
...
@@ -90,7 +87,8 @@ struct find_reshaper
{
auto
matcher
()
const
{
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
...
...
@@ -139,14 +137,15 @@ struct find_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
x
=
ins
;
auto
t
=
ins
;
auto
x
=
ins
;
auto
t
=
ins
;
std
::
vector
<
std
::
int64_t
>
dims
(
ins
->
get_shape
().
lens
().
size
());
std
::
iota
(
dims
.
begin
(),
dims
.
end
(),
0
);
do
...
...
@@ -172,13 +171,14 @@ struct find_concat_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"concat"
)(
match
::
same_shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
return
match
::
name
(
"concat"
)(
match
::
same_shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
permutation
=
find_permutation
(
s
);
...
...
@@ -187,10 +187,9 @@ struct find_concat_transpose
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
});
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
ipermutaion
},
concat
);
p
.
replace_instruction
(
ins
,
t
);
...
...
@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
match
::
find_matches
(
p
,
ins
,
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{}
);
match
::
find_matches
(
p
,
ins
,
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
}
}
...
...
test/matcher.cpp
View file @
7f4c7809
...
...
@@ -362,10 +362,10 @@ TEST_CASE(match_none_of2)
TEST_CASE
(
match_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
@@ -375,10 +375,10 @@ TEST_CASE(match_output1)
TEST_CASE
(
match_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
@@ -388,10 +388,10 @@ TEST_CASE(match_output2)
TEST_CASE
(
match_skip_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
@@ -401,11 +401,11 @@ TEST_CASE(match_skip_output1)
TEST_CASE
(
match_skip_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass
,
two
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
@@ -415,13 +415,13 @@ TEST_CASE(match_skip_output2)
TEST_CASE
(
match_skip_output3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass1
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
minus_pass2
=
p
.
add_instruction
(
pass_op
{},
minus_pass1
);
auto
minus_pass3
=
p
.
add_instruction
(
pass_op
{},
minus_pass2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass3
,
two
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass3
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
@@ -431,10 +431,10 @@ TEST_CASE(match_skip_output3)
TEST_CASE
(
match_skip_output4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
@@ -444,8 +444,8 @@ TEST_CASE(match_skip_output4)
TEST_CASE
(
match_skip_output5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
...
...
@@ -459,12 +459,12 @@ TEST_CASE(match_skip_output5)
TEST_CASE
(
match_skip_output6
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
p
.
add_instruction
(
pass_op
{},
sum3
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
@@ -474,11 +474,11 @@ TEST_CASE(match_skip_output6)
TEST_CASE
(
match_skip_output7
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus1
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus2
=
p
.
add_instruction
(
minus_op
{},
two
,
minus1
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
minus2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
minus2
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"minus"
)));
auto
r
=
find_match
(
p
,
m
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment