Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
254abcfa
Commit
254abcfa
authored
Sep 22, 2018
by
Paul
Browse files
Formatting
parent
e2e3606f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
99 additions
and
112 deletions
+99
-112
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+4
-6
src/include/migraph/matcher.hpp
src/include/migraph/matcher.hpp
+55
-74
test/matcher.cpp
test/matcher.cpp
+40
-32
No files found.
src/include/migraph/functional.hpp
View file @
254abcfa
...
...
@@ -100,24 +100,22 @@ auto pack(Ts... xs)
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
template
<
class
F
,
class
T
>
template
<
class
F
,
class
T
>
auto
fold_impl
(
F
&&
,
T
&&
x
)
{
return
x
;
}
template
<
class
F
,
class
T
,
class
U
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
U
,
class
...
Ts
>
auto
fold_impl
(
F
&&
f
,
T
&&
x
,
U
&&
y
,
Ts
&&
...
xs
)
{
return
fold_impl
(
f
,
f
(
std
::
forward
<
T
>
(
x
),
std
::
forward
<
U
>
(
y
)),
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
F
>
template
<
class
F
>
auto
fold
(
F
f
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
};
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
};
}
}
// namespace migraph
...
...
src/include/migraph/matcher.hpp
View file @
254abcfa
...
...
@@ -12,19 +12,15 @@ namespace migraph {
struct
matcher_context
{
matcher_context
(
instruction_ref
i
)
:
last
(
i
)
{}
matcher_context
(
instruction_ref
i
)
:
last
(
i
)
{}
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
instruction_ref
not_found
()
const
{
return
last
;
}
private:
instruction_ref
last
;
};
template
<
class
P
>
template
<
class
P
>
struct
predicate_matcher
{
P
p
;
...
...
@@ -38,7 +34,7 @@ struct predicate_matcher
}
};
template
<
class
F
>
template
<
class
F
>
struct
function_matcher
{
F
f
;
...
...
@@ -50,13 +46,13 @@ struct function_matcher
}
};
template
<
class
F
>
template
<
class
F
>
function_matcher
<
F
>
make_function_matcher
(
F
f
)
{
return
{
f
};
}
template
<
class
M
>
template
<
class
M
>
auto
bind_match
(
M
m
,
std
::
string
name
)
{
return
make_function_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
...
...
@@ -67,15 +63,12 @@ auto bind_match(M m, std::string name)
});
}
template
<
class
M
>
template
<
class
M
>
struct
bindable_matcher
{
M
m
;
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
...
...
@@ -83,19 +76,19 @@ struct bindable_matcher
}
};
template
<
class
M
>
template
<
class
M
>
bindable_matcher
<
M
>
make_bindable_matcher
(
M
m
)
{
return
{
m
};
}
template
<
class
F
>
template
<
class
F
>
bindable_matcher
<
function_matcher
<
F
>>
make_bf_matcher
(
F
f
)
{
return
{{
f
}};
}
template
<
class
F
>
template
<
class
F
>
bindable_matcher
<
predicate_matcher
<
F
>>
make_bp_matcher
(
F
f
)
{
return
{{
f
}};
...
...
@@ -105,18 +98,15 @@ using bool_list = std::initializer_list<bool>;
struct
id_matcher
{
instruction_ref
match
(
matcher_context
&
,
instruction_ref
ins
)
const
{
return
ins
;
}
instruction_ref
match
(
matcher_context
&
,
instruction_ref
ins
)
const
{
return
ins
;
}
};
template
<
class
M
>
template
<
class
M
>
struct
basic_matcher
{
M
m
;
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
// Copy m because we cant capture `this` by value
...
...
@@ -135,10 +125,7 @@ struct basic_matcher
});
}
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
...
...
@@ -146,41 +133,39 @@ struct basic_matcher
}
};
template
<
class
M
>
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
)
{
return
{
m
};
}
template
<
class
F
>
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
)
{
return
{{
f
}};
}
template
<
class
P
>
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
)
{
return
{{
p
}};
}
#define MIGRAPH_BASIC_MATCHER(name, ...) \
struct name
##
_m \
{
\
struct name##_m
\
{
\
instruction_ref match(__VA_ARGS__) const; \
};
\
const constexpr auto name = migraph::basic_matcher<name
##
_m>{{}}; \
inline instruction_ref name
##
_m::match(__VA_ARGS__) const
};
\
const constexpr auto name = migraph::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name
##
_m \
{
\
struct name##_m
\
{
\
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::basic_matcher<predicate_matcher<name ## _m>>{{}}; \
inline bool name ## _m::operator()(__VA_ARGS__) const
}; \
const constexpr auto name = migraph::basic_matcher<predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const
struct
matcher_result
{
...
...
@@ -188,7 +173,7 @@ struct matcher_result
instruction_ref
result
;
};
template
<
class
M
>
template
<
class
M
>
matcher_result
match_instruction
(
program
&
p
,
instruction_ref
ins
,
M
&&
m
)
{
assert
(
ins
!=
p
.
end
());
...
...
@@ -198,13 +183,13 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return
result
;
}
template
<
class
T
,
class
...
Ts
>
std
::
array
<
T
,
sizeof
...(
Ts
)
+
1
>
make_array
(
T
x
,
Ts
...
xs
)
template
<
class
T
,
class
...
Ts
>
std
::
array
<
T
,
sizeof
...(
Ts
)
+
1
>
make_array
(
T
x
,
Ts
...
xs
)
{
return
{
x
,
xs
...};
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
bool
all_of_eager
(
Ts
...
xs
)
{
return
make_array
((
xs
,
true
)...)
==
make_array
(
static_cast
<
bool
>
(
xs
)...);
...
...
@@ -212,7 +197,7 @@ bool all_of_eager(Ts... xs)
namespace
matchers
{
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
all_of
(
Ts
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
...
...
@@ -225,7 +210,7 @@ auto all_of(Ts... ms)
});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
none_of
(
Ts
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
...
...
@@ -238,29 +223,23 @@ auto none_of(Ts... ms)
});
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
any_of
(
Ts
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
bool
matches
=
fold
(
[
&
](
auto
x
,
auto
y
)
{
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
MIGRAPH_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPH_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
inline
auto
name
(
std
::
string
name
)
{
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
}
inline
auto
arg
(
std
::
size_t
i
)
...
...
@@ -273,16 +252,18 @@ inline auto arg(std::size_t i)
}
// Workaround for bugs in clang
template
<
std
::
size_t
...>
struct
args_impl_ints
{};
template
<
std
::
size_t
...>
struct
args_impl_ints
{
};
template
<
std
::
size_t
...
Ns
,
class
...
Ms
>
template
<
std
::
size_t
...
Ns
,
class
...
Ms
>
auto
args_impl
(
args_impl_ints
<
Ns
...
>
,
Ms
...
ms
)
{
return
matchers
::
all_of
(
arg
(
Ns
)(
ms
)...);
}
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
args
(
Ms
...
ms
)
{
return
sequence_c
<
sizeof
...(
Ms
)
>
([
=
](
auto
...
is
)
{
...
...
test/matcher.cpp
View file @
254abcfa
...
...
@@ -5,11 +5,11 @@
namespace
matchers
=
migraph
::
matchers
;
template
<
class
M
>
template
<
class
M
>
migraph
::
matcher_result
find_match
(
migraph
::
program
&
p
,
M
&&
m
)
{
migraph
::
matcher_result
result
;
for
(
auto
ins
:
migraph
::
iterator_for
(
p
))
for
(
auto
ins
:
migraph
::
iterator_for
(
p
))
{
result
=
migraph
::
match_instruction
(
p
,
ins
,
m
);
if
(
result
.
result
!=
p
.
end
())
...
...
@@ -70,7 +70,8 @@ void match_arg1()
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
...
...
@@ -82,7 +83,8 @@ void match_arg2()
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
...
...
@@ -94,7 +96,8 @@ void match_arg3()
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
1
)(
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
1
)(
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
...
...
@@ -106,7 +109,8 @@ void match_arg4()
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"pass"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
m
=
matchers
::
name
(
"pass"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
pass
});
}
...
...
@@ -118,7 +122,8 @@ void match_arg5()
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"pass"
)(
matchers
::
arg
(
1
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
m
=
matchers
::
name
(
"pass"
)(
matchers
::
arg
(
1
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
...
...
@@ -142,7 +147,8 @@ void match_arg7()
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"@literal"
)),
matchers
::
arg
(
1
)(
matchers
::
name
(
"@literal"
)));
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"@literal"
)),
matchers
::
arg
(
1
)(
matchers
::
name
(
"@literal"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
...
...
@@ -154,12 +160,15 @@ void match_args1()
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
args
(
matchers
::
name
(
"@literal"
),
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
args
(
matchers
::
name
(
"@literal"
),
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
int
main
()
{
int
main
()
{
match1
();
match_name1
();
match_name2
();
...
...
@@ -174,5 +183,4 @@ int main() {
match_arg7
();
match_args1
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment