Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
03afd1ed
Commit
03afd1ed
authored
Aug 26, 2019
by
Paul
Browse files
Fix matcher bugs in either_arg
parent
43a96492
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
227 additions
and
6 deletions
+227
-6
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+1
-1
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+7
-5
test/matcher.cpp
test/matcher.cpp
+193
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+26
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
03afd1ed
...
@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
...
@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
result
=
m
.
match
(
ctx
,
ins
);
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
!=
ctx
.
not_found
())
if
(
result
!=
ctx
.
not_found
())
ctx
.
instructions
.
emplace
(
name
,
ins
)
;
ctx
.
instructions
[
name
]
=
ins
;
return
result
;
return
result
;
});
});
}
}
...
...
src/simplify_algebra.cpp
View file @
03afd1ed
...
@@ -52,6 +52,7 @@ struct find_mul_conv
...
@@ -52,6 +52,7 @@ struct find_mul_conv
}
}
};
};
// a * (x + b) => a * x + a * b
struct
find_mul_add
struct
find_mul_add
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -60,7 +61,7 @@ struct find_mul_add
...
@@ -60,7 +61,7 @@ struct find_mul_add
match
::
name
(
"add"
)(
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
conv_const_weights
(),
match
::
is_constant
()).
bind
(
"
y
"
)),
match
::
any_of
(
conv_const_weights
(),
match
::
is_constant
()).
bind
(
"
b
"
)),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
used_once
()),
match
::
used_once
()),
match
::
is_constant
().
bind
(
"a"
)));
match
::
is_constant
().
bind
(
"a"
)));
...
@@ -70,12 +71,13 @@ struct find_mul_add
...
@@ -70,12 +71,13 @@ struct find_mul_add
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
a
uto
y
_ins
=
r
.
instructions
[
"y"
]
;
a
ssert
(
x
_ins
!
=
b_ins
)
;
auto
x
a_ins
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
x
_ins
,
a
_ins
);
auto
a
x
_ins
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
a
_ins
,
x
_ins
);
auto
y
a_ins
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
y
_ins
,
a
_ins
);
auto
a
b
_ins
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
a
_ins
,
b
_ins
);
p
.
replace_instruction
(
ins
,
op
::
add
{},
x
a_ins
,
y
a_ins
);
p
.
replace_instruction
(
ins
,
op
::
add
{},
a
x
_ins
,
a
b
_ins
);
}
}
};
};
...
...
test/matcher.cpp
View file @
03afd1ed
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
namespace
match
=
migraphx
::
match
;
namespace
match
=
migraphx
::
match
;
MIGRAPHX_PRED_MATCHER
(
throws
,
migraphx
::
instruction_ref
)
{
MIGRAPHX_THROW
(
"Matcher throws"
);
}
template
<
class
M
>
template
<
class
M
>
migraphx
::
match
::
matcher_result
find_match
(
migraphx
::
program
&
p
,
M
&&
m
)
migraphx
::
match
::
matcher_result
find_match
(
migraphx
::
program
&
p
,
M
&&
m
)
{
{
...
@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
...
@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
}
TEST_CASE
(
match_either_args_any1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any
().
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum1
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
name
(
"@literal"
).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum1
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"@literal"
).
bind
(
"x"
),
match
::
any
().
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum1
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"sum"
).
bind
(
"x"
),
match
::
any
().
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum2
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
name
(
"sum"
).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum2
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_all_of1
)
TEST_CASE
(
match_all_of1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
...
@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
bool
{
r
.
result
==
sum
});
}
}
TEST_CASE
(
match_lazy_any_of
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
p
.
add_instruction
(
pass_op
{},
one
);
auto
m
=
match
::
any_of
(
match
::
any
(),
throws
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
one
});
}
TEST_CASE
(
match_lazy_all_of
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
p
.
add_instruction
(
pass_op
{},
one
);
auto
m
=
match
::
all_of
(
match
::
none
(),
throws
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_lazy_none_of
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
p
.
add_instruction
(
pass_op
{},
one
);
auto
m
=
match
::
none_of
(
match
::
any
(),
throws
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_any_of1
)
TEST_CASE
(
match_any_of1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -396,6 +503,92 @@ TEST_CASE(match_any_of2)
...
@@ -396,6 +503,92 @@ TEST_CASE(match_any_of2)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
}
TEST_CASE
(
match_any_of_lazy1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
any
(),
match
::
any
()).
bind
(
"x"
),
match
::
args
(
match
::
name
(
"sum"
),
match
::
name
(
"sum"
)).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x"
));
EXPECT
(
bool
{
r
.
instructions
[
"x"
]
==
sum
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y"
));
}
TEST_CASE
(
match_any_of_lazy2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
name
(
"@literal"
),
match
::
name
(
"@literal"
)).
bind
(
"x"
),
match
::
args
(
match
::
any
(),
match
::
any
()).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x"
));
EXPECT
(
bool
{
r
.
instructions
[
"x"
]
==
sum
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y"
));
}
TEST_CASE
(
match_any_of_lazy3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
any
(),
match
::
any
()).
bind
(
"x"
),
match
::
args
(
match
::
name
(
"@literal"
),
match
::
name
(
"@literal"
)).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x"
));
EXPECT
(
bool
{
r
.
instructions
[
"x"
]
==
sum
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y"
));
}
TEST_CASE
(
match_any_of_lazy4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
name
(
"@literal"
).
bind
(
"x1"
),
match
::
name
(
"@literal"
).
bind
(
"y1"
)),
match
::
args
(
match
::
any
().
bind
(
"x2"
),
match
::
any
().
bind
(
"y2"
))));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x1"
));
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"y1"
));
EXPECT
(
bool
{
r
.
instructions
[
"x1"
]
==
one
});
EXPECT
(
bool
{
r
.
instructions
[
"y1"
]
==
two
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"x2"
));
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y2"
));
}
TEST_CASE
(
match_any_of_lazy5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
any
().
bind
(
"x1"
),
match
::
any
().
bind
(
"y1"
)),
match
::
args
(
match
::
name
(
"@literal"
).
bind
(
"x2"
),
match
::
name
(
"@literal"
).
bind
(
"y2"
))));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x1"
));
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"y1"
));
EXPECT
(
bool
{
r
.
instructions
[
"x1"
]
==
one
});
EXPECT
(
bool
{
r
.
instructions
[
"y1"
]
==
two
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"x2"
));
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y2"
));
}
TEST_CASE
(
match_none_of1
)
TEST_CASE
(
match_none_of1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/simplify_algebra_test.cpp
View file @
03afd1ed
...
@@ -150,4 +150,30 @@ TEST_CASE(simplify_mul_conv1)
...
@@ -150,4 +150,30 @@ TEST_CASE(simplify_mul_conv1)
EXPECT
(
new_conv
->
outputs
().
front
()
->
name
()
!=
"mul"
);
EXPECT
(
new_conv
->
outputs
().
front
()
->
name
()
!=
"mul"
);
}
}
TEST_CASE
(
simplify_mul_add
)
{
migraphx
::
program
p1
;
{
auto
x
=
p1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
one
=
p1
.
add_literal
(
1
);
auto
two
=
p1
.
add_literal
(
2
);
auto
sum
=
p1
.
add_instruction
(
migraphx
::
op
::
add
{},
one
,
x
);
auto
mul
=
p1
.
add_instruction
(
migraphx
::
op
::
mul
{},
sum
,
two
);
p1
.
add_instruction
(
pass_op
{},
mul
);
}
p1
.
compile
(
simplify_algebra_target
{});
migraphx
::
program
p2
;
{
auto
x
=
p2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
one
=
p2
.
add_literal
(
1
);
auto
two
=
p2
.
add_literal
(
2
);
auto
mul1
=
p2
.
add_instruction
(
migraphx
::
op
::
mul
{},
two
,
x
);
auto
mul2
=
p2
.
add_instruction
(
migraphx
::
op
::
mul
{},
two
,
one
);
auto
sum
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
mul1
,
mul2
);
p2
.
add_instruction
(
pass_op
{},
sum
);
}
EXPECT
(
p1
==
p2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment