Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
41e236d3
Commit
41e236d3
authored
Oct 06, 2018
by
Paul
Browse files
Add matcher for fusable conv
parent
7dbc89d3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
6 deletions
+34
-6
src/include/migraph/matcher.hpp
src/include/migraph/matcher.hpp
+4
-4
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+30
-2
No files found.
src/include/migraph/matcher.hpp
View file @
41e236d3
...
@@ -76,7 +76,7 @@ struct bindable_matcher
...
@@ -76,7 +76,7 @@ struct bindable_matcher
{
{
M
m
;
M
m
;
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
std
::
move
(
name
));
}
auto
bind
(
std
::
string
name
)
const
{
return
bind_match
(
m
,
std
::
move
(
name
));
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
{
...
@@ -137,7 +137,7 @@ struct basic_matcher
...
@@ -137,7 +137,7 @@ struct basic_matcher
});
});
}
}
auto
bind
(
std
::
string
name
)
{
return
bind_match
(
m
,
name
);
}
auto
bind
(
std
::
string
name
)
const
{
return
bind_match
(
m
,
std
::
move
(
name
)
)
;
}
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
instruction_ref
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
{
...
@@ -181,7 +181,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
...
@@ -181,7 +181,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{ \
{ \
bool operator()(__VA_ARGS__) const; \
bool operator()(__VA_ARGS__) const; \
}; \
}; \
const constexpr auto name = migraph::match::basic_matcher<predicate_matcher<name##_m>>{{}}; \
const constexpr auto name = migraph::match::basic_matcher<
migraph::match::
predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const
inline bool name##_m::operator()(__VA_ARGS__) const
struct
matcher_result
struct
matcher_result
...
@@ -310,7 +310,7 @@ auto args(Ms... ms)
...
@@ -310,7 +310,7 @@ auto args(Ms... ms)
});
});
}
}
auto
either_arg
(
std
::
size_t
i
,
std
::
size_t
j
)
inline
auto
either_arg
(
std
::
size_t
i
,
std
::
size_t
j
)
{
{
return
[
=
](
auto
m1
,
auto
m2
)
{
return
[
=
](
auto
m1
,
auto
m2
)
{
return
match
::
any_of
(
match
::
all_of
(
arg
(
i
)(
m1
),
arg
(
j
)(
m2
)),
return
match
::
any_of
(
match
::
all_of
(
arg
(
i
)(
m1
),
arg
(
j
)(
m2
)),
...
...
src/targets/gpu/fuse_ops.cpp
View file @
41e236d3
...
@@ -76,6 +76,34 @@ struct fusion
...
@@ -76,6 +76,34 @@ struct fusion
}
}
};
};
MIGRAPH_PRED_MATCHER
(
bias_shape
,
instruction_ref
ins
)
{
auto
&&
s
=
ins
->
get_shape
();
return
s
.
broadcasted
()
and
s
.
strides
().
size
()
==
4
and
s
.
strides
()[
0
]
==
0
and
s
.
strides
()[
1
]
!=
0
and
s
.
strides
()[
2
]
==
0
and
s
.
strides
()[
3
]
==
0
;
}
// TODO: Move to another header
template
<
class
T
,
class
...
Ts
>
std
::
array
<
T
,
sizeof
...(
Ts
)
+
1
>
make_array
(
T
x
,
Ts
...
xs
)
{
return
{
std
::
move
(
x
),
std
::
move
(
static_cast
<
T
>
(
xs
))...};
}
MIGRAPH_PRED_MATCHER
(
fusable_conv
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"gpu::convolution"
)
return
false
;
auto
op
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
()).
op
;
return
op
.
padding
==
make_array
<
size_t
>
(
0
,
0
)
and
op
.
stride
==
make_array
<
size_t
>
(
1
,
1
)
and
op
.
dilation
==
make_array
<
size_t
>
(
1
,
1
);
}
struct
hip_add_relu
struct
hip_add_relu
{
{
std
::
string
name
()
const
{
return
"hip::add_relu"
;
}
std
::
string
name
()
const
{
return
"hip::add_relu"
;
}
...
@@ -168,17 +196,17 @@ struct match_conv_bias
...
@@ -168,17 +196,17 @@ struct match_conv_bias
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
broadc
as
t
_shape
().
bind
(
"bias"
),
match
::
name
(
"gpu::convolution"
).
bind
(
"conv"
)));
bi
as_shape
().
bind
(
"bias"
),
fusable_conv
(
).
bind
(
"conv"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
{
auto
conv_ins
=
r
.
instructions
[
"conv"
];
auto
conv_ins
=
r
.
instructions
[
"conv"
];
auto
bias_ins
=
r
.
instructions
[
"bias"
];
auto
bias_ins
=
r
.
instructions
[
"bias"
];
auto
ins
=
r
.
result
;
auto
input_ins
=
conv_ins
->
inputs
().
at
(
0
);
auto
input_ins
=
conv_ins
->
inputs
().
at
(
0
);
auto
weights_ins
=
conv_ins
->
inputs
().
at
(
1
);
auto
weights_ins
=
conv_ins
->
inputs
().
at
(
1
);
auto
conv_op
=
any_cast
<
miopen_convolution
>
(
conv_ins
->
get_operator
()).
op
;
auto
conv_op
=
any_cast
<
miopen_convolution
>
(
conv_ins
->
get_operator
()).
op
;
auto
ins
=
r
.
result
;
auto
alloc_ins
=
ins
->
inputs
().
back
();
auto
alloc_ins
=
ins
->
inputs
().
back
();
auto
old_ws_ins
=
conv_ins
->
inputs
().
at
(
2
);
auto
old_ws_ins
=
conv_ins
->
inputs
().
at
(
2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment