Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e2e3606f
Commit
e2e3606f
authored
Sep 22, 2018
by
Paul
Browse files
Add initial matchers
parent
fd6f1cc6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
508 additions
and
0 deletions
+508
-0
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+32
-0
src/include/migraph/matcher.hpp
src/include/migraph/matcher.hpp
+298
-0
test/matcher.cpp
test/matcher.cpp
+178
-0
No files found.
src/include/migraph/functional.hpp
View file @
e2e3606f
...
...
@@ -61,6 +61,12 @@ constexpr void repeat_c_impl(F f, seq<Ns...>)
swallow
{(
f
(
std
::
integral_constant
<
std
::
size_t
,
Ns
>
{}),
0
)...};
}
template
<
class
F
,
std
::
size_t
...
Ns
>
constexpr
auto
sequence_c_impl
(
F
&&
f
,
seq
<
Ns
...
>
)
{
return
f
(
std
::
integral_constant
<
std
::
size_t
,
Ns
>
{}...);
}
}
// namespace detail
template
<
std
::
size_t
N
,
class
F
>
...
...
@@ -69,6 +75,12 @@ constexpr void repeat_c(F f)
detail
::
repeat_c_impl
(
f
,
detail
::
gens
<
N
>
{});
}
template
<
std
::
size_t
N
,
class
F
>
constexpr
auto
sequence_c
(
F
&&
f
)
{
return
detail
::
sequence_c_impl
(
f
,
detail
::
gens
<
N
>
{});
}
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
...
...
@@ -88,6 +100,26 @@ auto pack(Ts... xs)
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
template
<
class
F
,
class
T
>
auto
fold_impl
(
F
&&
,
T
&&
x
)
{
return
x
;
}
template
<
class
F
,
class
T
,
class
U
,
class
...
Ts
>
auto
fold_impl
(
F
&&
f
,
T
&&
x
,
U
&&
y
,
Ts
&&
...
xs
)
{
return
fold_impl
(
f
,
f
(
std
::
forward
<
T
>
(
x
),
std
::
forward
<
U
>
(
y
)),
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
F
>
auto
fold
(
F
f
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
};
}
}
// namespace migraph
#endif
src/include/migraph/matcher.hpp
0 → 100644
View file @
e2e3606f
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
#include <migraph/program.hpp>
#include <migraph/type_name.hpp>
#include <unordered_map>
namespace
migraph
{
struct
matcher_context
{
matcher_context
(
instruction_ref
i
)
:
last
(
i
)
{}
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
private:
instruction_ref
last
;
};
template
<
class
P
>
struct
predicate_matcher
{
P
p
;
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
assert
(
ins
!=
ctx
.
not_found
());
if
(
p
(
ins
))
return
ins
;
return
ctx
.
not_found
();
}
};
template
<
class
F
>
struct
function_matcher
{
F
f
;
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
assert
(
ins
!=
ctx
.
not_found
());
return
f
(
ctx
,
ins
);
}
};
template
<
class
F
>
function_matcher
<
F
>
make_function_matcher
(
F
f
)
{
return
{
f
};
}
template
<
class
M
>
auto
bind_match
(
M
m
,
std
::
string
name
)
{
return
make_function_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
!=
ctx
.
not_found
())
ctx
.
instructions
.
emplace
(
name
,
ins
);
return
result
;
});
}
template
<
class
M
>
struct
bindable_matcher
{
M
m
;
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
};
template
<
class
M
>
bindable_matcher
<
M
>
make_bindable_matcher
(
M
m
)
{
return
{
m
};
}
template
<
class
F
>
bindable_matcher
<
function_matcher
<
F
>>
make_bf_matcher
(
F
f
)
{
return
{{
f
}};
}
template
<
class
F
>
bindable_matcher
<
predicate_matcher
<
F
>>
make_bp_matcher
(
F
f
)
{
return
{{
f
}};
}
using
bool_list
=
std
::
initializer_list
<
bool
>
;
struct
id_matcher
{
instruction_ref
match
(
matcher_context
&
,
instruction_ref
ins
)
const
{
return
ins
;
}
};
template
<
class
M
>
struct
basic_matcher
{
M
m
;
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
!=
ctx
.
not_found
())
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
result
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
result
;
}
return
ctx
.
not_found
();
});
}
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
};
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
)
{
return
{
m
};
}
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
)
{
return
{{
f
}};
}
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
)
{
return
{{
p
}};
}
#define MIGRAPH_BASIC_MATCHER(name, ...) \
struct name ## _m \
{ \
instruction_ref match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::basic_matcher<name ## _m>{{}}; \
inline instruction_ref name ## _m::match(__VA_ARGS__) const
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name ## _m \
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::basic_matcher<predicate_matcher<name ## _m>>{{}}; \
inline bool name ## _m::operator()(__VA_ARGS__) const
struct
matcher_result
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
result
;
};
template
<
class
M
>
matcher_result
match_instruction
(
program
&
p
,
instruction_ref
ins
,
M
&&
m
)
{
assert
(
ins
!=
p
.
end
());
matcher_result
result
;
matcher_context
ctx
{
p
.
end
()};
result
.
result
=
m
.
match
(
ctx
,
ins
);
return
result
;
}
template
<
class
T
,
class
...
Ts
>
std
::
array
<
T
,
sizeof
...(
Ts
)
+
1
>
make_array
(
T
x
,
Ts
...
xs
)
{
return
{
x
,
xs
...};
}
template
<
class
...
Ts
>
bool
all_of_eager
(
Ts
...
xs
)
{
return
make_array
((
xs
,
true
)...)
==
make_array
(
static_cast
<
bool
>
(
xs
)...);
}
namespace
matchers
{
template
<
class
...
Ts
>
auto
all_of
(
Ts
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
...
Ts
>
auto
none_of
(
Ts
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
...
Ts
>
auto
any_of
(
Ts
...
ms
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
MIGRAPH_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
inline
auto
name
(
std
::
string
name
)
{
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
}
inline
auto
arg
(
std
::
size_t
i
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
i
<
ins
->
inputs
().
size
())
return
ins
->
inputs
()[
i
];
return
ctx
.
not_found
();
});
}
// Workaround for bugs in clang
template
<
std
::
size_t
...>
struct
args_impl_ints
{};
template
<
std
::
size_t
...
Ns
,
class
...
Ms
>
auto
args_impl
(
args_impl_ints
<
Ns
...
>
,
Ms
...
ms
)
{
return
matchers
::
all_of
(
arg
(
Ns
)(
ms
)...);
}
template
<
class
...
Ms
>
auto
args
(
Ms
...
ms
)
{
return
sequence_c
<
sizeof
...(
Ms
)
>
([
=
](
auto
...
is
)
{
// It needs to be written as `decltype(is)::value` for gcc 5
return
args_impl
(
args_impl_ints
<
decltype
(
is
)
::
value
...
>
{},
ms
...);
});
}
}
// namespace matchers
}
// namespace migraph
#endif
test/matcher.cpp
0 → 100644
View file @
e2e3606f
#include <migraph/matcher.hpp>
#include <migraph/iterator_for.hpp>
#include <test.hpp>
#include <basic_ops.hpp>
namespace
matchers
=
migraph
::
matchers
;
template
<
class
M
>
migraph
::
matcher_result
find_match
(
migraph
::
program
&
p
,
M
&&
m
)
{
migraph
::
matcher_result
result
;
for
(
auto
ins
:
migraph
::
iterator_for
(
p
))
{
result
=
migraph
::
match_instruction
(
p
,
ins
,
m
);
if
(
result
.
result
!=
p
.
end
())
return
result
;
}
return
result
;
}
void
match1
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
1
);
auto
m
=
matchers
::
standard_shape
();
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
l
});
}
void
match_name1
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
);
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_name2
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"min"
);
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_name3
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg1
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg2
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_arg3
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
1
)(
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg4
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"pass"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
pass
});
}
void
match_arg5
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"pass"
)(
matchers
::
arg
(
1
)(
matchers
::
name
(
"sum"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_arg6
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"@literal"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg7
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
arg
(
0
)(
matchers
::
name
(
"@literal"
)),
matchers
::
arg
(
1
)(
matchers
::
name
(
"@literal"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_args1
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
matchers
::
name
(
"sum"
)(
matchers
::
args
(
matchers
::
name
(
"@literal"
),
matchers
::
name
(
"@literal"
)),
matchers
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
int
main
()
{
match1
();
match_name1
();
match_name2
();
match_name3
();
match_arg1
();
match_arg2
();
match_arg3
();
match_arg4
();
match_arg5
();
match_arg6
();
match_arg7
();
match_args1
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment