Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2e51006e
Commit
2e51006e
authored
Jun 30, 2019
by
Paul
Browse files
Use matchers
parent
0fcf61e0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
366 additions
and
131 deletions
+366
-131
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+120
-45
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+120
-86
test/matcher.cpp
test/matcher.cpp
+126
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
2e51006e
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -20,6 +21,12 @@ struct matcher_context
...
@@ -20,6 +21,12 @@ struct matcher_context
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
instruction_ref
not_found
()
const
{
return
last
;
}
template
<
class
M
>
bool
matched
(
M
m
,
instruction_ref
ins
)
{
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
}
private:
private:
instruction_ref
last
;
instruction_ref
last
;
};
};
...
@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
...
@@ -205,12 +212,10 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return
result
;
return
result
;
}
}
/// Find matches
in a
program
/// Find matches
for an instruction in the
program
template
<
class
...
Ms
>
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
void
find_matches
(
program
&
p
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
{
for
(
auto
ins
:
iterator_for
(
p
))
{
bool
match
=
false
;
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
...
@@ -223,46 +228,73 @@ void find_matches(program& p, Ms&&... ms)
...
@@ -223,46 +228,73 @@ void find_matches(program& p, Ms&&... ms)
match
=
true
;
match
=
true
;
},
},
ms
...);
ms
...);
}
/// Find matches in a program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
p
))
{
find_matches
(
p
,
ins
,
ms
...);
}
}
}
}
template
<
class
...
T
s
>
template
<
class
Op
,
bool
Start
,
bool
Matche
s
>
auto
all_of
(
Ts
...
ms
)
struct
folder
{
{
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
Op
op
;
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
()
)
;
})(
true
,
ms
...);
})(
Start
,
ms
...);
if
(
m
atches
)
if
(
matches
==
M
atches
)
return
ins
;
return
ins
;
return
ctx
.
not_found
();
return
ctx
.
not_found
();
});
});
}
}
template
<
class
...
Ts
>
template
<
class
Selector
>
auto
none_of
(
Ts
...
ms
)
auto
operator
[](
Selector
select
)
const
{
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
return
[
=
](
auto
...
ms
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
();
Op
op
;
})(
true
,
ms
...);
bool
matches
=
Start
;
if
(
matches
)
select
(
start
,
[
&
](
auto
ins
)
{
return
ins
;
matches
=
op
(
matches
,
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
());
})(
Start
,
ms
...));
});
if
(
matches
==
Matches
)
return
start
;
return
ctx
.
not_found
();
return
ctx
.
not_found
();
});
});
};
}
};
const
constexpr
auto
all_of
=
folder
<
std
::
logical_and
<
bool
>
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
false
>
{};
inline
auto
inputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
inputs
())
f
(
x
);
};
}
}
template
<
class
...
Ts
>
inline
auto
outputs
()
auto
any_of
(
Ts
...
ms
)
{
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
return
[](
auto
ins
,
auto
f
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
for
(
auto
&&
x
:
ins
->
outputs
())
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
f
(
x
);
})(
false
,
ms
...);
};
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
...
@@ -273,6 +305,21 @@ MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
...
@@ -273,6 +305,21 @@ MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
return
ins
->
get_shape
().
broadcasted
();
return
ins
->
get_shape
().
broadcasted
();
}
}
MIGRAPHX_PRED_MATCHER
(
transpose_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
transposed
();
}
MIGRAPHX_PRED_MATCHER
(
same_shapes
,
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
empty
())
return
false
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
{
{
if
(
ins
->
outputs
().
size
()
==
1
)
if
(
ins
->
outputs
().
size
()
==
1
)
...
@@ -289,10 +336,38 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
...
@@ -289,10 +336,38 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
return
ctx
.
not_found
();
}
}
inline
auto
name
(
std
::
string
name
)
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
{
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
instruction_ref
>
([
&
](
auto
self
,
auto
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
{
auto
next
=
ins
->
outputs
().
front
();
if
(
ctx
.
matched
(
m
,
next
))
{
auto
skipped_next
=
self
(
next
);
if
(
skipped_next
!=
ctx
.
not_found
())
return
skipped_next
;
}
return
next
;
}
return
ctx
.
not_found
();
})(
start
);
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
[
=
,
s
=
std
::
move
(
s
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
}
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
(
[
=
,
name
=
std
::
move
(
name
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
[
=
,
name
s
=
std
::
move
(
name
s
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
()
)
>
0
;
});
}
}
inline
auto
nargs
(
std
::
size_t
n
)
inline
auto
nargs
(
std
::
size_t
n
)
...
...
src/simplify_reshapes.cpp
View file @
2e51006e
...
@@ -6,12 +6,13 @@
...
@@ -6,12 +6,13 @@
#include <migraphx/op/concat.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <unordered_set>
#include <unordered_set>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
is_reshaper
(
instruction_ref
ins
)
const
auto
&
reshaper_names
(
)
{
{
// clang-format off
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
...
@@ -21,16 +22,12 @@ bool is_reshaper(instruction_ref ins)
...
@@ -21,16 +22,12 @@ bool is_reshaper(instruction_ref ins)
"unsqueeze"
"unsqueeze"
};
};
// clang-format on
// clang-format on
return
contains
(
names
,
ins
->
name
())
;
return
names
;
}
}
bool
is_
transpose_output
(
instruction_ref
ins
)
bool
is_
reshaper
(
instruction_ref
ins
)
{
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
contains
(
reshaper_names
(),
ins
->
name
());
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
is_transpose_output
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
}
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
...
@@ -89,21 +86,16 @@ std::vector<int64_t> find_permutation(const shape& s)
...
@@ -89,21 +86,16 @@ std::vector<int64_t> find_permutation(const shape& s)
return
sort_permutation
(
s
.
strides
(),
std
::
greater
<>
{});
return
sort_permutation
(
s
.
strides
(),
std
::
greater
<>
{});
}
}
void
simplify_reshapes
::
apply
(
program
&
p
)
const
struct
find_reshaper
{
{
auto
end
=
std
::
prev
(
p
.
end
());
auto
matcher
()
const
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
continue
;
}
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
continue
;
if
(
is_reshaper
(
ins
))
{
{
if
(
std
::
any_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
&
is_reshaper
))
auto
ins
=
mr
.
result
;
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
while
(
is_reshaper
(
reshapes
.
back
()))
{
{
...
@@ -130,10 +122,29 @@ void simplify_reshapes::apply(program& p) const
...
@@ -130,10 +122,29 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
}
}
else
if
(
ins
->
name
()
==
"transpose"
)
};
MIGRAPHX_PRED_MATCHER
(
is_transpose_output
,
instruction_ref
start
)
{
return
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
self
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
})(
start
);
}
struct
find_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
{
if
(
is_transpose_output
(
ins
))
auto
ins
=
mr
.
result
;
continue
;
auto
x
=
ins
;
auto
x
=
ins
;
auto
t
=
ins
;
auto
t
=
ins
;
std
::
vector
<
std
::
int64_t
>
dims
(
ins
->
get_shape
().
lens
().
size
());
std
::
vector
<
std
::
int64_t
>
dims
(
ins
->
get_shape
().
lens
().
size
());
...
@@ -145,7 +156,7 @@ void simplify_reshapes::apply(program& p) const
...
@@ -145,7 +156,7 @@ void simplify_reshapes::apply(program& p) const
t
=
find_transpose_input
(
x
);
t
=
find_transpose_input
(
x
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
continue
;
return
;
if
(
is_no_transpose
(
dims
))
if
(
is_no_transpose
(
dims
))
{
{
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
...
@@ -155,14 +166,20 @@ void simplify_reshapes::apply(program& p) const
...
@@ -155,14 +166,20 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
ins
,
op
::
transpose
{{
dims
}},
t
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
op
::
transpose
{{
dims
}},
t
->
inputs
().
front
());
}
}
}
}
else
if
(
ins
->
name
()
==
"concat"
)
};
struct
find_concat_transpose
{
auto
matcher
()
const
{
{
if
(
ins
->
inputs
().
empty
())
return
match
::
name
(
"concat"
)(
match
::
same_shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
continue
;
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
if
(
none_of
(
ins
->
inputs
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
().
transposed
();
})
or
none_of
(
ins
->
inputs
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
==
s
;
}))
continue
;
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
permutation
=
find_permutation
(
s
);
auto
permutation
=
find_permutation
(
s
);
auto
ipermutaion
=
invert_permutation
(
permutation
);
auto
ipermutaion
=
invert_permutation
(
permutation
);
...
@@ -178,6 +195,23 @@ void simplify_reshapes::apply(program& p) const
...
@@ -178,6 +195,23 @@ void simplify_reshapes::apply(program& p) const
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
ipermutaion
},
concat
);
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
ipermutaion
},
concat
);
p
.
replace_instruction
(
ins
,
t
);
p
.
replace_instruction
(
ins
,
t
);
}
}
};
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
match
::
find_matches
(
p
,
ins
,
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{}
);
}
}
}
}
...
...
test/matcher.cpp
View file @
2e51006e
...
@@ -359,6 +359,132 @@ TEST_CASE(match_none_of2)
...
@@ -359,6 +359,132 @@ TEST_CASE(match_none_of2)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
}
TEST_CASE
(
match_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_skip_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass1
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
minus_pass2
=
p
.
add_instruction
(
pass_op
{},
minus_pass1
);
auto
minus_pass3
=
p
.
add_instruction
(
pass_op
{},
minus_pass2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass3
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
two
});
}
TEST_CASE
(
match_skip_output5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
p
.
add_instruction
(
pass_op
{},
sum3
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_skip_output6
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
p
.
add_instruction
(
pass_op
{},
sum3
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output7
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus1
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus2
=
p
.
add_instruction
(
minus_op
{},
two
,
minus1
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
minus2
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"minus"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus1
});
}
TEST_CASE
(
match_bind1
)
TEST_CASE
(
match_bind1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment