Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
41e236d3
Commit
41e236d3
authored
Oct 06, 2018
by
Paul
Browse files
Add matcher for fusable conv
parent
7dbc89d3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
6 deletions
+34
-6
src/include/migraph/matcher.hpp
src/include/migraph/matcher.hpp
+4
-4
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+30
-2
No files found.
src/include/migraph/matcher.hpp
View file @
41e236d3
...
...
@@ -76,7 +76,7 @@ struct bindable_matcher
{
M
m
;
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
std
::
move
(
name
));
}
auto
bind
(
std
::
string
name
)
const
{
return
bind_match
(
m
,
std
::
move
(
name
));
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
...
...
@@ -137,7 +137,7 @@ struct basic_matcher
});
}
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
auto
bind
(
std
::
string
name
)
const
{
return
bind_match
(
m
,
std
::
move
(
name
)
)
;
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
...
...
@@ -181,7 +181,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::match::basic_matcher<predicate_matcher<name##_m>>{{}}; \
const constexpr auto name = migraph::match::basic_matcher<
migraph::match::
predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const
struct
matcher_result
...
...
@@ -310,7 +310,7 @@ auto args(Ms... ms)
});
}
auto
either_arg
(
std
::
size_t
i
,
std
::
size_t
j
)
inline
auto
either_arg
(
std
::
size_t
i
,
std
::
size_t
j
)
{
return
[
=
](
auto
m1
,
auto
m2
)
{
return
match
::
any_of
(
match
::
all_of
(
arg
(
i
)(
m1
),
arg
(
j
)(
m2
)),
...
...
src/targets/gpu/fuse_ops.cpp
View file @
41e236d3
...
...
@@ -76,6 +76,34 @@ struct fusion
}
};
MIGRAPH_PRED_MATCHER
(
bias_shape
,
instruction_ref
ins
)
{
auto
&&
s
=
ins
->
get_shape
();
return
s
.
broadcasted
()
and
s
.
strides
().
size
()
==
4
and
s
.
strides
()[
0
]
==
0
and
s
.
strides
()[
1
]
!=
0
and
s
.
strides
()[
2
]
==
0
and
s
.
strides
()[
3
]
==
0
;
}
// TODO: Move to another header
template
<
class
T
,
class
...
Ts
>
std
::
array
<
T
,
sizeof
...(
Ts
)
+
1
>
make_array
(
T
x
,
Ts
...
xs
)
{
return
{
std
::
move
(
x
),
std
::
move
(
static_cast
<
T
>
(
xs
))...};
}
MIGRAPH_PRED_MATCHER
(
fusable_conv
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"gpu::convolution"
)
return
false
;
auto
op
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
()).
op
;
return
op
.
padding
==
make_array
<
size_t
>
(
0
,
0
)
and
op
.
stride
==
make_array
<
size_t
>
(
1
,
1
)
and
op
.
dilation
==
make_array
<
size_t
>
(
1
,
1
);
}
struct
hip_add_relu
{
std
::
string
name
()
const
{
return
"hip::add_relu"
;
}
...
...
@@ -168,17 +196,17 @@ struct match_conv_bias
auto
matcher
()
const
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
broadc
as
t
_shape
().
bind
(
"bias"
),
match
::
name
(
"gpu::convolution"
).
bind
(
"conv"
)));
bi
as_shape
().
bind
(
"bias"
),
fusable_conv
(
).
bind
(
"conv"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
conv_ins
=
r
.
instructions
[
"conv"
];
auto
bias_ins
=
r
.
instructions
[
"bias"
];
auto
ins
=
r
.
result
;
auto
input_ins
=
conv_ins
->
inputs
().
at
(
0
);
auto
weights_ins
=
conv_ins
->
inputs
().
at
(
1
);
auto
conv_op
=
any_cast
<
miopen_convolution
>
(
conv_ins
->
get_operator
()).
op
;
auto
ins
=
r
.
result
;
auto
alloc_ins
=
ins
->
inputs
().
back
();
auto
old_ws_ins
=
conv_ins
->
inputs
().
at
(
2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment