Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0c212157
Commit
0c212157
authored
Jul 08, 2019
by
Paul
Browse files
Add pass to simplify mul conv
parent
3987066f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
12 deletions
+64
-12
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+5
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+41
-12
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+18
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
0c212157
...
@@ -369,6 +369,11 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
...
@@ -369,6 +369,11 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
return
ctx
.
not_found
();
}
}
MIGRAPHX_PRED_MATCHER
(
is_constant
,
instruction_ref
ins
)
{
return
ins
->
can_eval
();
}
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
auto
skip_output
(
Ms
...
ms
)
{
{
...
...
src/simplify_algebra.cpp
View file @
0c212157
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
find_add_
lit_broadcast
auto
lit_broadcast
()
{
{
auto
lit_broadcast
()
const
return
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
name
(
"broadcast"
));
{
}
return
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
name
(
"broadcast"
));
auto
not_lit_broadcast
()
}
{
auto
not_lit_broadcast
()
const
return
match
::
none_of
(
match
::
name
(
"@literal"
),
match
::
name
(
"broadcast"
));
}
auto
op_lit_broadcast
(
std
::
string
op
,
std
::
string
x
,
std
::
string
y
)
{
return
match
::
name
(
op
)(
match
::
either_arg
(
0
,
1
)(
lit_broadcast
().
bind
(
std
::
move
(
x
)),
not_lit_broadcast
().
bind
(
std
::
move
(
y
))));
}
struct
find_mul_conv
{
auto
matcher
()
const
{
{
return
match
::
none_of
(
match
::
name
(
"@literal"
),
match
::
name
(
"broadcast"
));
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"conv"
)(
match
::
used_once
(),
match
::
args
(
match
::
any
(),
match
::
is_constant
().
bind
(
"w"
))).
bind
(
"conv"
),
match
::
name
(
"broadcast"
).
bind
(
"a"
)));
}
}
auto
add_lit_broadcast
(
std
::
string
x
,
std
::
string
y
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
{
return
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
lit_broadcast
().
bind
(
std
::
move
(
x
)),
auto
ins
=
r
.
result
;
not_lit_broadcast
().
bind
(
std
::
move
(
y
))));
auto
conv_ins
=
r
.
instructions
[
"conv"
];
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
w_ins
=
r
.
instructions
[
"w"
];
auto
broadcast_op
=
any_cast
<
op
::
broadcast
>
(
a_ins
->
get_operator
());
if
(
broadcast_op
.
axis
!=
1
)
return
;
auto
new_a
=
p
.
insert_instruction
(
ins
,
op
::
broadcast
{
0
,
w_ins
->
get_shape
().
lens
()},
a_ins
->
inputs
().
front
());
auto
new_mul
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
new_a
,
w_ins
);
auto
new_conv
=
p
.
insert_instruction
(
ins
,
conv_ins
->
get_operator
(),
conv_ins
->
inputs
().
front
(),
new_mul
);
p
.
replace_instruction
(
ins
,
new_conv
);
}
}
};
struct
find_add_lit_broadcast
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"add"
)(
return
match
::
name
(
"add"
)(
match
::
args
(
add
_lit_broadcast
(
"a"
,
"x"
),
add
_lit_broadcast
(
"b"
,
"y"
)));
match
::
args
(
op
_lit_broadcast
(
"add"
,
"a"
,
"x"
),
op
_lit_broadcast
(
"add"
,
"b"
,
"y"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
@@ -59,7 +88,7 @@ struct find_add_lit_broadcast
...
@@ -59,7 +88,7 @@ struct find_add_lit_broadcast
}
}
};
};
void
simplify_algebra
::
apply
(
program
&
p
)
const
{
match
::
find_matches
(
p
,
find_add_lit_broadcast
{});
}
void
simplify_algebra
::
apply
(
program
&
p
)
const
{
match
::
find_matches
(
p
,
find_add_lit_broadcast
{}
,
find_mul_conv
{}
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
test/simplify_algebra_test.cpp
View file @
0c212157
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
#include <test.hpp>
...
@@ -99,6 +102,21 @@ TEST_CASE(simplify_add3)
...
@@ -99,6 +102,21 @@ TEST_CASE(simplify_add3)
EXPECT
(
p1
==
p2
);
EXPECT
(
p1
==
p2
);
}
}
TEST_CASE
(
simplify_mul_conv1
)
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
128
,
28
,
28
}});
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
int32_type
,
{
256
,
128
,
3
,
3
}}));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{{
1
,
1
},{
2
,
2
},{
1
,
1
}},
x
,
w
);
auto
a
=
p
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
int32_type
,
{
256
}}));
auto
b
=
p
.
add_instruction
(
migraphx
::
op
::
broadcast
{
1
,
{
1
,
256
,
14
,
14
}},
a
);
auto
mul
=
p
.
add_instruction
(
migraphx
::
op
::
mul
{},
conv
,
b
);
p
.
add_instruction
(
pass_op
{},
mul
);
EXPECT
(
conv
->
outputs
().
front
()
->
name
()
==
"mul"
);
p
.
compile
(
simplify_algebra_target
{});
EXPECT
(
conv
->
outputs
().
front
()
->
name
()
!=
"mul"
);
}
// TODO: Add test case
// TODO: Add test case
void
simplify_add4
()
void
simplify_add4
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment