Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
db4bc970
Commit
db4bc970
authored
Apr 22, 2019
by
Shucai Xiao
Browse files
change the algorithm of eliminate contiguous.
parent
ce8139e5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
6 deletions
+41
-6
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+38
-3
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+2
-2
src/targets/gpu/include/migraphx/gpu/oper.hpp
src/targets/gpu/include/migraphx/gpu/oper.hpp
+1
-1
No files found.
src/eliminate_contiguous.cpp
View file @
db4bc970
...
...
@@ -9,19 +9,54 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
try_compute_shape
(
co
nst
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
arg
s
)
bool
try_compute_shape
(
i
nst
ruction_ref
ins
,
const
std
::
vector
<
shape
>&
input
s
)
{
try
{
compute_shape
(
op
,
args
);
shape
new_shape
=
ins
->
get_operator
().
compute_shape
(
inputs
);
// If the output shape is a standard shape, no need to try its output
if
(
new_shape
.
standard
())
{
return
true
;
}
auto
outputs
=
ins
->
outputs
();
// If the current instruction has no output, it means the last output shape
// is non-standard, then we cannot eliminate the contiguous
if
(
outputs
.
empty
())
{
return
false
;
}
for
(
auto
output
:
outputs
)
{
auto
args
=
output
->
inputs
();
std
::
vector
<
shape
>
input_shapes
;
for
(
auto
arg
:
args
)
{
input_shapes
.
push_back
((
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
());
}
if
(
!
try_compute_shape
(
output
,
input_shapes
))
{
return
false
;
}
}
}
catch
(...)
{
return
false
;
}
return
true
;
}
bool
try_compute_shape
(
instruction_ref
ins
,
const
std
::
vector
<
instruction_ref
>&
args
)
{
auto
inputs
=
to_shapes
(
args
);
return
try_compute_shape
(
ins
,
inputs
);
}
void
eliminate_contiguous
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
...
...
@@ -44,7 +79,7 @@ void eliminate_contiguous::apply(program& p) const
auto
new_args
=
args
;
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
->
get_operator
()
,
new_args
))
if
(
try_compute_shape
(
ins
,
new_args
))
{
instruction
::
replace_argument
(
ins
,
arg
,
prev
);
}
...
...
src/targets/cpu/lowering.cpp
View file @
db4bc970
...
...
@@ -795,7 +795,7 @@ struct cpu_binary
std
::
string
name
()
const
{
return
op
.
name
();
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
if
(
inputs
.
at
(
0
)
==
inputs
.
at
(
1
)
and
inputs
.
at
(
0
).
packed
()
and
inputs
.
at
(
1
).
packed
()
)
if
(
inputs
.
at
(
0
)
==
inputs
.
at
(
1
)
and
inputs
.
at
(
0
).
packed
())
{
return
inputs
.
at
(
0
);
}
...
...
@@ -811,7 +811,7 @@ struct cpu_binary
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
auto
s1
=
input1
.
get_shape
();
auto
s2
=
input2
.
get_shape
();
if
(
s1
==
s2
and
s1
.
packed
()
and
s2
.
packed
()
)
if
(
s1
==
s2
and
s1
.
packed
())
{
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
input2
.
begin
(),
output
.
begin
(),
op
.
fcn
());
...
...
src/targets/gpu/include/migraphx/gpu/oper.hpp
View file @
db4bc970
...
...
@@ -70,7 +70,7 @@ struct binary_device : oper<Derived>
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
if
(
inputs
.
at
(
0
)
==
inputs
.
at
(
1
)
and
inputs
.
at
(
0
).
packed
()
and
inputs
.
at
(
1
).
packed
()
)
if
(
inputs
.
at
(
0
)
==
inputs
.
at
(
1
)
and
inputs
.
at
(
0
).
packed
())
{
return
inputs
.
at
(
0
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment