Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c46b5480
Commit
c46b5480
authored
Aug 12, 2019
by
Paul
Browse files
Make sure recursions are use only once
parent
efb704c5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
6 deletions
+41
-6
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+6
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+31
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+4
-3
No files found.
src/include/migraphx/matcher.hpp
View file @
c46b5480
...
@@ -469,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
...
@@ -469,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
});
});
}
}
template
<
class
...
Ts
>
inline
auto
name
(
std
::
string
s
,
Ts
...
xs
)
{
return
name
(
std
::
unordered_set
<
std
::
string
>
{
s
,
xs
...});
}
inline
auto
nargs
(
std
::
size_t
n
)
inline
auto
nargs
(
std
::
size_t
n
)
{
{
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
ins
->
inputs
().
size
()
==
n
;
});
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
ins
->
inputs
().
size
()
==
n
;
});
...
...
src/simplify_algebra.cpp
View file @
c46b5480
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/mul.hpp>
...
@@ -19,7 +20,7 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
...
@@ -19,7 +20,7 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
auto
conv_const_weights
()
auto
conv_const_weights
()
{
{
return
match
::
name
(
"convolution"
)(
match
::
used_once
_recursive
(
4
),
return
match
::
name
(
"convolution"
)(
match
::
used_once
(
),
match
::
args
(
match
::
any
(),
match
::
is_constant
().
bind
(
"w"
)));
match
::
args
(
match
::
any
(),
match
::
is_constant
().
bind
(
"w"
)));
}
}
...
@@ -61,7 +62,7 @@ struct find_mul_add
...
@@ -61,7 +62,7 @@ struct find_mul_add
match
::
any
().
bind
(
"x"
),
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
conv_const_weights
(),
match
::
is_constant
()).
bind
(
"y"
)),
match
::
any_of
(
conv_const_weights
(),
match
::
is_constant
()).
bind
(
"y"
)),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
used_once
_recursive
(
4
)),
match
::
used_once
(
)),
match
::
is_constant
().
bind
(
"a"
)));
match
::
is_constant
().
bind
(
"a"
)));
}
}
...
@@ -135,16 +136,43 @@ struct find_double_add_lit_broadcast
...
@@ -135,16 +136,43 @@ struct find_double_add_lit_broadcast
}
}
};
};
struct
find_inner_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"mul"
,
"add"
)(
match
::
args
(
match
::
name
(
"broadcast"
).
bind
(
"x"
),
match
::
name
(
"broadcast"
).
bind
(
"y"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
y_ins
=
r
.
instructions
[
"y"
];
auto
xbroadcast
=
any_cast
<
op
::
broadcast
>
(
x_ins
->
get_operator
());
auto
ybroadcast
=
any_cast
<
op
::
broadcast
>
(
y_ins
->
get_operator
());
if
(
xbroadcast
.
axis
!=
ybroadcast
.
axis
)
return
;
auto
op
=
p
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
x_ins
->
inputs
().
front
(),
y_ins
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
xbroadcast
,
op
);
}
};
void
simplify_algebra
::
apply
(
program
&
p
)
const
void
simplify_algebra
::
apply
(
program
&
p
)
const
{
{
// Run simplifications multiple times
// Run simplifications multiple times
for
(
int
i
=
0
;
i
<
4
;
i
++
)
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
match
::
find_matches
(
p
,
match
::
find_matches
(
p
,
match
::
skip_matches
(
match
::
is_unused
(),
match
::
is_constant
())
,
find_inner_broadcast
{}
,
find_double_add_lit_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_mul_conv
{},
find_mul_conv
{},
find_mul_add
{});
find_mul_add
{});
dead_code_elimination
{}.
apply
(
p
);
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/fuse_ops.cpp
View file @
c46b5480
...
@@ -265,6 +265,7 @@ struct find_add_relu
...
@@ -265,6 +265,7 @@ struct find_add_relu
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
used_once
(),
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
name
(
"hip::triadd"
),
match
::
name
(
"hip::triadd"
),
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
any_of
(
match
::
name
(
"@literal"
),
...
@@ -294,7 +295,7 @@ struct find_triadd
...
@@ -294,7 +295,7 @@ struct find_triadd
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
),
match
::
name
(
"gpu::add"
)
(
match
::
used_once
())
.
bind
(
"add"
),
match
::
any
(
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
any
(
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())))
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())))
.
bind
(
"input"
)));
.
bind
(
"input"
)));
...
@@ -325,7 +326,7 @@ struct find_mul_add
...
@@ -325,7 +326,7 @@ struct find_mul_add
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::add"
)(
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::mul"
).
bind
(
"mul"
),
match
::
any
().
bind
(
"b"
)));
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::mul"
)
(
match
::
used_once
())
.
bind
(
"mul"
),
match
::
any
().
bind
(
"b"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
@@ -349,7 +350,7 @@ struct find_mul_add_relu
...
@@ -349,7 +350,7 @@ struct find_mul_add_relu
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
name
(
"hip::mul_add"
).
bind
(
"mul_add"
)));
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
name
(
"hip::mul_add"
)
(
match
::
used_once
())
.
bind
(
"mul_add"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment