Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e32bbb45
Commit
e32bbb45
authored
Jan 14, 2019
by
Paul
Browse files
Fix bug in simplify reshapes
parent
96f7ae5b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
66 additions
and
6 deletions
+66
-6
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+23
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+26
-6
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+17
-0
No files found.
src/include/migraphx/operators.hpp
View file @
e32bbb45
...
...
@@ -608,6 +608,29 @@ struct reshape
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
as_shape
{
shape
s
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
s
,
"shape"
));
}
std
::
string
name
()
const
{
return
"as_shape"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
assert
(
inputs
.
front
().
elements
()
==
s
.
elements
());
return
s
;
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
dot
{
float
alpha
=
1.0
;
...
...
src/simplify_reshapes.cpp
View file @
e32bbb45
...
...
@@ -9,7 +9,18 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
is_reshaper
(
const
std
::
string
&
name
)
// Reshapers that can't handle nonstandard input shapes
bool
is_nonstandard_reshaper
(
instruction_ref
ins
)
{
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
};
// clang-format on
return
contains
(
names
,
ins
->
name
())
and
ins
->
inputs
().
front
()
->
name
()
==
"contiguous"
;
}
bool
is_reshaper
(
instruction_ref
ins
)
{
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
...
...
@@ -19,26 +30,28 @@ bool is_reshaper(const std::string& name)
"contiguous"
};
// clang-format on
return
contains
(
names
,
name
);
return
contains
(
names
,
ins
->
name
())
and
not
is_nonstandard_reshaper
(
ins
);
}
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
not
is_reshaper
(
ins
->
name
()
))
if
(
not
is_reshaper
(
ins
))
continue
;
if
(
ins
->
outputs
().
size
()
!=
1
)
continue
;
if
(
is_reshaper
(
ins
->
outputs
().
front
()
->
name
()
))
if
(
is_reshaper
(
ins
->
outputs
().
front
()))
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()
->
name
()
))
while
(
is_reshaper
(
reshapes
.
back
()))
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
reshapes
.
push_back
(
reshapes
.
back
()
->
inputs
().
front
());
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
...
...
@@ -58,6 +71,13 @@ void simplify_reshapes::apply(program& p) const
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
// Replace all reshapes with as_shape
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
name
()
!=
"reshape"
)
continue
;
p
.
replace_instruction
(
ins
,
op
::
as_shape
{
ins
->
get_shape
()},
ins
->
inputs
());
}
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
test/simplify_reshapes_test.cpp
View file @
e32bbb45
...
...
@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass)
EXPECT
(
result
!=
get_2x2
());
}
TEST_CASE
(
reshape_transpose
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
112
,
56
,
56
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
r1
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
4
,
28
,
56
,
56
}},
x
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
1
,
3
,
4
}},
r1
);
auto
ct
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
auto
r2
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
112
,
56
,
56
}},
ct
);
p
.
add_instruction
(
pass_op
{},
r2
);
EXPECT
(
p
.
get_shape
()
==
s
);
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
s
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment