Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2e51006e
"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "95390dfef7396e835e5dc2578afcf6cddd12f424"
Commit
2e51006e
authored
Jun 30, 2019
by
Paul
Browse files
Use matchers
parent
0fcf61e0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
366 additions
and
131 deletions
+366
-131
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+120
-45
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+120
-86
test/matcher.cpp
test/matcher.cpp
+126
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
2e51006e
...
...
@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -20,6 +21,12 @@ struct matcher_context
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
template
<
class
M
>
bool
matched
(
M
m
,
instruction_ref
ins
)
{
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
}
private:
instruction_ref
last
;
};
...
...
@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return
result
;
}
/// Find matches
in a
program
/// Find matches
for an instruction in the
program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
void
find_matches
(
program
&
p
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
p
))
{
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
...
...
@@ -223,46 +228,73 @@ void find_matches(program& p, Ms&&... ms)
match
=
true
;
},
ms
...);
}
/// Find matches in a program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
p
))
{
find_matches
(
p
,
ins
,
ms
...);
}
}
template
<
class
...
T
s
>
auto
all_of
(
Ts
...
ms
)
template
<
class
Op
,
bool
Start
,
bool
Matche
s
>
struct
folder
{
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
Op
op
;
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
m
atches
)
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
()
)
;
})(
Start
,
ms
...);
if
(
matches
==
M
atches
)
return
ins
;
return
ctx
.
not_found
();
});
}
}
template
<
class
...
Ts
>
auto
none_of
(
Ts
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
template
<
class
Selector
>
auto
operator
[](
Selector
select
)
const
{
return
[
=
](
auto
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
Op
op
;
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
matches
=
op
(
matches
,
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
());
})(
Start
,
ms
...));
});
if
(
matches
==
Matches
)
return
start
;
return
ctx
.
not_found
();
});
};
}
};
const
constexpr
auto
all_of
=
folder
<
std
::
logical_and
<
bool
>
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
false
>
{};
inline
auto
inputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
inputs
())
f
(
x
);
};
}
template
<
class
...
Ts
>
auto
any_of
(
Ts
...
ms
)
inline
auto
outputs
()
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
false
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
outputs
())
f
(
x
);
};
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
...
...
@@ -273,6 +305,21 @@ MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
return
ins
->
get_shape
().
broadcasted
();
}
MIGRAPHX_PRED_MATCHER
(
transpose_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
transposed
();
}
MIGRAPHX_PRED_MATCHER
(
same_shapes
,
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
empty
())
return
false
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
...
...
@@ -289,10 +336,38 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
}
inline
auto
name
(
std
::
string
name
)
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
{
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
instruction_ref
>
([
&
](
auto
self
,
auto
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
{
auto
next
=
ins
->
outputs
().
front
();
if
(
ctx
.
matched
(
m
,
next
))
{
auto
skipped_next
=
self
(
next
);
if
(
skipped_next
!=
ctx
.
not_found
())
return
skipped_next
;
}
return
next
;
}
return
ctx
.
not_found
();
})(
start
);
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
[
=
,
s
=
std
::
move
(
s
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
}
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
return
make_basic_pred_matcher
(
[
=
,
name
=
std
::
move
(
name
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
[
=
,
name
s
=
std
::
move
(
name
s
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
()
)
>
0
;
});
}
inline
auto
nargs
(
std
::
size_t
n
)
...
...
src/simplify_reshapes.cpp
View file @
2e51006e
...
...
@@ -6,12 +6,13 @@
#include <migraphx/op/concat.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <unordered_set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
is_reshaper
(
instruction_ref
ins
)
const
auto
&
reshaper_names
(
)
{
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
...
...
@@ -21,16 +22,12 @@ bool is_reshaper(instruction_ref ins)
"unsqueeze"
};
// clang-format on
return
contains
(
names
,
ins
->
name
())
;
return
names
;
}
bool
is_
transpose_output
(
instruction_ref
ins
)
bool
is_
reshaper
(
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
is_transpose_output
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
return
contains
(
reshaper_names
(),
ins
->
name
());
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
...
...
@@ -89,21 +86,16 @@ std::vector<int64_t> find_permutation(const shape& s)
return
sort_permutation
(
s
.
strides
(),
std
::
greater
<>
{});
}
void
simplify_reshapes
::
apply
(
program
&
p
)
const
struct
find_reshaper
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
auto
matcher
()
const
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
if
(
is_reshaper
(
ins
))
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
if
(
std
::
any_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
&
is_reshaper
))
continue
;
// Gather reshapes
auto
ins
=
mr
.
result
;
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
...
...
@@ -130,10 +122,29 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
else
if
(
ins
->
name
()
==
"transpose"
)
};
MIGRAPHX_PRED_MATCHER
(
is_transpose_output
,
instruction_ref
start
)
{
return
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
self
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
})(
start
);
}
struct
find_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
if
(
is_transpose_output
(
ins
))
continue
;
auto
ins
=
mr
.
result
;
auto
x
=
ins
;
auto
t
=
ins
;
std
::
vector
<
std
::
int64_t
>
dims
(
ins
->
get_shape
().
lens
().
size
());
...
...
@@ -145,7 +156,7 @@ void simplify_reshapes::apply(program& p) const
t
=
find_transpose_input
(
x
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
continue
;
return
;
if
(
is_no_transpose
(
dims
))
{
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
...
...
@@ -155,14 +166,20 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
ins
,
op
::
transpose
{{
dims
}},
t
->
inputs
().
front
());
}
}
else
if
(
ins
->
name
()
==
"concat"
)
};
struct
find_concat_transpose
{
auto
matcher
()
const
{
if
(
ins
->
inputs
().
empty
())
continue
;
return
match
::
name
(
"concat"
)(
match
::
same_shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
if
(
none_of
(
ins
->
inputs
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
().
transposed
();
})
or
none_of
(
ins
->
inputs
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
==
s
;
}))
continue
;
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
permutation
=
find_permutation
(
s
);
auto
ipermutaion
=
invert_permutation
(
permutation
);
...
...
@@ -178,6 +195,23 @@ void simplify_reshapes::apply(program& p) const
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
ipermutaion
},
concat
);
p
.
replace_instruction
(
ins
,
t
);
}
};
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
match
::
find_matches
(
p
,
ins
,
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{}
);
}
}
...
...
test/matcher.cpp
View file @
2e51006e
...
...
@@ -359,6 +359,132 @@ TEST_CASE(match_none_of2)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_skip_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass1
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
minus_pass2
=
p
.
add_instruction
(
pass_op
{},
minus_pass1
);
auto
minus_pass3
=
p
.
add_instruction
(
pass_op
{},
minus_pass2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass3
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
two
});
}
TEST_CASE
(
match_skip_output5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
p
.
add_instruction
(
pass_op
{},
sum3
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_skip_output6
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
p
.
add_instruction
(
pass_op
{},
sum3
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output7
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus1
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus2
=
p
.
add_instruction
(
minus_op
{},
two
,
minus1
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
minus2
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"minus"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus1
});
}
TEST_CASE
(
match_bind1
)
{
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment