Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9b8d62d1
"driver/vscode:/vscode.git/clone" did not exist on "aac345ab5677f5875bc904a31deb068588d463f0"
Commit
9b8d62d1
authored
Feb 05, 2019
by
Paul
Browse files
Fix reshape bugs with transpose
parent
15d3cf62
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
108 additions
and
49 deletions
+108
-49
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+11
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+62
-38
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+1
-8
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+34
-3
No files found.
src/include/migraphx/operators.hpp
View file @
9b8d62d1
...
@@ -358,6 +358,17 @@ struct contiguous
...
@@ -358,6 +358,17 @@ struct contiguous
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
return
{
t
,
lens
};
return
{
t
,
lens
};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
assert
(
output_shape
.
standard
());
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
});
});
return
result
;
}
};
};
struct
concat
struct
concat
...
...
src/simplify_reshapes.cpp
View file @
9b8d62d1
...
@@ -9,65 +9,89 @@
...
@@ -9,65 +9,89 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Reshapers that can't handle nonstandard input shapes
bool
is_nonstandard_reshaper
(
instruction_ref
ins
)
{
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
};
// clang-format on
return
contains
(
names
,
ins
->
name
())
and
ins
->
inputs
().
front
()
->
name
()
==
"contiguous"
;
}
bool
is_reshaper
(
instruction_ref
ins
)
bool
is_reshaper
(
instruction_ref
ins
)
{
{
// clang-format off
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
,
"reshape"
,
"transpose"
,
// "broadcast",
"contiguous"
"contiguous"
};
};
// clang-format on
// clang-format on
return
contains
(
names
,
ins
->
name
())
and
not
is_nonstandard_reshaper
(
ins
);
return
contains
(
names
,
ins
->
name
());
}
bool
is_transpose_output
(
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
is_transpose_output
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
size
()
!=
1
)
return
ins
;
if
(
ins
->
inputs
().
front
()
->
name
()
==
"contiguous"
)
return
find_transpose_input
(
ins
->
inputs
().
front
());
if
(
ins
->
inputs
().
front
()
->
name
()
==
"transpose"
)
return
ins
->
inputs
().
front
();
return
ins
;
}
}
void
simplify_reshapes
::
apply
(
program
&
p
)
const
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
not
is_reshaper
(
ins
))
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
if
(
ins
->
outputs
().
size
()
!=
1
)
continue
;
continue
;
if
(
is_reshaper
(
ins
->
outputs
().
front
()))
if
(
is_reshaper
(
ins
))
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
if
(
std
::
any_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
&
is_reshaper
))
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
continue
;
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
// Gather reshapes
reshapes
.
push_back
(
input
);
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
}
while
(
is_reshaper
(
reshapes
.
back
()))
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
break
;
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
if
(
r
.
first
!=
r
.
second
)
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
}
}
if
(
r
.
first
!=
r
.
second
)
else
if
(
ins
->
name
()
==
"transpose"
)
{
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
if
(
is_transpose_output
(
ins
))
continue
;
auto
x
=
ins
;
auto
t
=
ins
;
do
{
x
=
t
;
t
=
find_transpose_input
(
x
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
continue
;
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
}
}
}
// Replace all reshapes with as_shape
// Replace all reshapes with as_shape
...
...
src/targets/cpu/lowering.cpp
View file @
9b8d62d1
...
@@ -287,14 +287,7 @@ struct cpu_contiguous
...
@@ -287,14 +287,7 @@ struct cpu_contiguous
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
assert
(
output_shape
.
standard
());
return
op
.
compute
(
output_shape
,
args
);
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
});
});
return
result
;
}
}
};
};
...
...
test/simplify_reshapes_test.cpp
View file @
9b8d62d1
...
@@ -27,9 +27,9 @@ TEST_CASE(double_contig)
...
@@ -27,9 +27,9 @@ TEST_CASE(double_contig)
p
.
compile
(
simplify_reshapes_target
{});
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
4
);
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
EXPECT
(
result
=
=
get_2x2
());
EXPECT
(
result
!
=
get_2x2
());
}
}
TEST_CASE
(
double_transpose
)
TEST_CASE
(
double_transpose
)
...
@@ -95,7 +95,6 @@ TEST_CASE(double_transpose_sin_pass)
...
@@ -95,7 +95,6 @@ TEST_CASE(double_transpose_sin_pass)
p
.
compile
(
simplify_reshapes_target
{});
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
// std::cout << p << std::endl;
// TODO: Fix this
// TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1);
// EXPECT(std::distance(p.begin(), p.end()) == 1);
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
...
@@ -134,4 +133,36 @@ TEST_CASE(reshape_transpose)
...
@@ -134,4 +133,36 @@ TEST_CASE(reshape_transpose)
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
);
}
}
TEST_CASE
(
transpose_contiguous
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
x
);
auto
c1
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_op
{},
c1
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
);
}
TEST_CASE
(
transpose_double_contiguous
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
x
);
auto
c1
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
auto
c2
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
c1
);
p
.
add_instruction
(
pass_op
{},
c2
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
EXPECT
(
p
.
has_instruction
(
t
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment