Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0c1ff20d
Commit
0c1ff20d
authored
Jul 05, 2019
by
Paul
Browse files
Fix error with nonstandard shapes
parent
5f31ae3f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
6 deletions
+25
-6
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+1
-0
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+24
-6
No files found.
src/include/migraphx/matcher.hpp
View file @
0c1ff20d
...
@@ -330,6 +330,7 @@ inline auto outputs()
...
@@ -330,6 +330,7 @@ inline auto outputs()
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPHX_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
not_standard_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
{
return
ins
->
get_shape
().
broadcasted
();
return
ins
->
get_shape
().
broadcasted
();
...
...
src/targets/gpu/fuse_ops.cpp
View file @
0c1ff20d
...
@@ -200,12 +200,28 @@ struct hip_add_relu
...
@@ -200,12 +200,28 @@ struct hip_add_relu
}
}
};
};
void
move_broadcasted_back
(
std
::
vector
<
instruction_ref
>&
args
)
{
// Ensure the last arguments is the broadcasted one
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
});
if
(
it
!=
args
.
end
())
std
::
swap
(
*
it
,
*
std
::
prev
(
args
.
end
(),
2
));
}
void
move_standard_front
(
std
::
vector
<
instruction_ref
>&
args
)
{
// Ensure the first arguments is the standard one
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
return
arg
->
get_shape
().
standard
();
});
if
(
it
!=
args
.
end
())
std
::
swap
(
*
it
,
args
.
front
());
}
struct
find_add_relu
struct
find_add_relu
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
name
(
"hip::triadd"
)).
bind
(
"add"
)));
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
name
(
"hip::triadd"
)
,
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())
).
bind
(
"add"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
@@ -213,6 +229,9 @@ struct find_add_relu
...
@@ -213,6 +229,9 @@ struct find_add_relu
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
args
=
add_ins
->
inputs
();
auto
args
=
add_ins
->
inputs
();
move_standard_front
(
args
);
move_broadcasted_back
(
args
);
// Use the allocation from the relu operator
// Use the allocation from the relu operator
args
.
back
()
=
ins
->
inputs
().
back
();
args
.
back
()
=
ins
->
inputs
().
back
();
if
(
add_ins
->
name
()
==
"gpu::add"
)
if
(
add_ins
->
name
()
==
"gpu::add"
)
...
@@ -227,7 +246,7 @@ struct find_triadd
...
@@ -227,7 +246,7 @@ struct find_triadd
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
),
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
),
match
::
any
().
bind
(
"input"
)));
match
::
any
(
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())
).
bind
(
"input"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
@@ -242,10 +261,9 @@ struct find_triadd
...
@@ -242,10 +261,9 @@ struct find_triadd
if
(
std
::
count_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
)
>
1
)
if
(
std
::
count_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
)
>
1
)
return
;
return
;
args
.
insert
(
args
.
begin
(),
input_ins
);
args
.
insert
(
args
.
begin
(),
input_ins
);
// Ensure the last arguments is the broadcasted one
move_standard_front
(
args
);
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
);
move_broadcasted_back
(
args
);
if
(
it
!=
args
.
end
())
std
::
swap
(
*
it
,
*
std
::
prev
(
args
.
end
(),
2
));
args
.
back
()
=
ins
->
inputs
().
back
();
args
.
back
()
=
ins
->
inputs
().
back
();
p
.
replace_instruction
(
ins
,
hip_triadd
{},
args
);
p
.
replace_instruction
(
ins
,
hip_triadd
{},
args
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment