Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0dae73fa
Unverified
Commit
0dae73fa
authored
Jun 12, 2023
by
Paul Fultz II
Committed by
GitHub
Jun 12, 2023
Browse files
Enable reshape on nonstandard shapes (#1681)
parent
c900e382
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
230 additions
and
5 deletions
+230
-5
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+117
-5
test/op_shape_test.cpp
test/op_shape_test.cpp
+113
-0
No files found.
src/include/migraphx/op/reshape.hpp
View file @
0dae73fa
...
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -96,9 +97,115 @@ struct reshape
return
{
s0
.
type
(),
output_dyn_dims
};
}
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>=
dim
;
});
if
(
x
!=
dim
)
return
start
;
return
it
;
}
template
<
class
DimIterator
,
class
StrideIterator
>
static
auto
can_strides_merge
(
DimIterator
dim_start
,
DimIterator
dim_last
,
StrideIterator
stride_start
,
StrideIterator
stride_last
)
{
assert
(
std
::
distance
(
dim_start
,
dim_last
)
==
std
::
distance
(
stride_start
,
stride_last
));
auto
cstride
=
*
std
::
prev
(
stride_last
);
return
std
::
equal
(
std
::
make_reverse_iterator
(
dim_last
),
std
::
make_reverse_iterator
(
dim_start
+
1
),
std
::
make_reverse_iterator
(
stride_last
-
1
),
std
::
make_reverse_iterator
(
stride_start
),
[
&
](
auto
dim
,
auto
stride
)
{
cstride
*=
dim
;
return
stride
==
cstride
;
});
}
// This will reshape the dimesions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static
optional
<
shape
>
reshape_dims
(
const
shape
&
input
,
const
std
::
vector
<
std
::
size_t
>&
rdims
)
{
if
(
input
.
standard
())
return
shape
{
input
.
type
(),
rdims
};
const
auto
&
idims
=
input
.
lens
();
const
auto
&
istrides
=
input
.
strides
();
std
::
vector
<
std
::
size_t
>
rstrides
;
std
::
size_t
i
=
0
;
std
::
size_t
r
=
0
;
while
(
i
<
idims
.
size
()
and
r
<
rdims
.
size
())
{
auto
idim
=
idims
[
i
];
auto
rdim
=
rdims
[
r
];
if
(
rdim
==
idim
)
{
rstrides
.
push_back
(
istrides
[
i
]);
}
// squeeze
else
if
(
rdim
>
idim
)
{
auto
start
=
idims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
idims
.
end
(),
rdim
);
if
(
it
==
start
)
return
nullopt
;
auto
n
=
it
-
start
;
assert
((
i
+
n
)
<=
istrides
.
size
());
if
(
not
can_strides_merge
(
start
,
it
+
1
,
istrides
.
begin
()
+
i
,
istrides
.
begin
()
+
i
+
n
+
1
))
return
nullopt
;
i
+=
n
;
rstrides
.
push_back
(
istrides
[
i
]);
}
// unsqueeze
else
// if(rdim < idim)
{
auto
start
=
rdims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
rdims
.
end
(),
idim
);
if
(
it
==
start
)
return
nullopt
;
auto
n
=
it
-
start
;
assert
((
r
+
n
)
<=
rdims
.
size
());
auto
stride
=
istrides
[
i
]
*
idim
;
std
::
for_each
(
start
,
it
+
1
,
[
&
](
auto
dim
)
{
stride
/=
dim
;
rstrides
.
push_back
(
stride
);
});
r
+=
n
;
}
i
++
;
r
++
;
}
// Handle trailing 1s
if
(
rstrides
.
size
()
<
rdims
.
size
()
and
not
rstrides
.
empty
())
{
auto
stride
=
rstrides
.
back
();
for
(
auto
d
:
range
(
rdims
.
begin
()
+
rstrides
.
size
(),
rdims
.
end
()))
{
if
(
d
!=
1
)
return
nullopt
;
rstrides
.
push_back
(
stride
);
}
}
if
(
rdims
.
size
()
!=
rstrides
.
size
())
return
nullopt
;
return
shape
{
input
.
type
(),
rdims
,
rstrides
};
}
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
{
check_shapes
{
inputs
,
*
this
}.
standard
(
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
...
...
@@ -125,12 +232,17 @@ struct reshape
}
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
auto
s
=
reshape_dims
(
inputs
.
front
(),
rdims
);
if
(
not
s
.
has_value
())
MIGRAPHX_THROW
(
"Reshape on axis that is not packed."
);
if
(
s
->
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPHX_THROW
(
"Reshape: Wrong number of elements for reshape: reshape has "
+
std
::
to_string
(
s
.
elements
())
+
" elements whereas the input has "
+
std
::
to_string
(
s
->
elements
())
+
" elements whereas the input has "
+
std
::
to_string
(
inputs
.
front
().
elements
()));
return
s
;
assert
(
s
->
bytes
()
==
inputs
.
front
().
bytes
());
return
*
s
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
test/op_shape_test.cpp
View file @
0dae73fa
...
...
@@ -2208,6 +2208,119 @@ TEST_CASE(reshape_shape)
}
}
// This uses the permutation to compute the reshape since its simpler than
// trying to calculate strides. As we collapse or expand dimensions, we
// remove the collapsed dimensions or duplicate the expanded dimensions in
// the permutation. Then we renumber the permutation. So for dimensions of 4,
// 24, 1, 1, 1 with a permutation of 1, 0, 2, 3, 4 that reshapes to 4, 1, 3,
// 4, 2, we first remove the collapsed dimensions or duplicate the expanded
// dimensions which gives 1, 0, 0, 0, 0. Then after renumbering we get a
// final permutation of 4, 0, 1, 2, 3.
TEST_CASE
(
reshape_nonstandard
)
{
auto
input
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
4
,
24
,
1
,
1
,
1
},
migraphx
::
invert_permutation
({
1
,
0
,
2
,
3
,
4
}));
std
::
vector
<
std
::
pair
<
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
int64_t
>>>
tests
{
{{
4
,
24
},
{
1
,
0
}},
{{
4
,
24
,
1
,
1
,
1
,
1
},
{
1
,
0
,
2
,
3
,
4
,
5
}},
{{
4
,
8
,
3
,
1
,
1
},
{
2
,
0
,
1
,
3
,
4
}},
{{
4
,
1
,
3
,
4
,
2
},
{
4
,
0
,
1
,
2
,
3
}},
{{
4
,
1
,
4
,
3
,
2
},
{
4
,
0
,
1
,
2
,
3
}},
{{
4
,
2
,
4
,
3
},
{
3
,
0
,
1
,
2
}},
{{
4
,
2
,
12
,
1
},
{
2
,
0
,
1
,
3
}},
{{
4
,
2
,
1
,
12
},
{
3
,
0
,
1
,
2
}},
{{
4
,
4
,
2
,
3
},
{
3
,
0
,
1
,
2
}},
{{
4
,
8
,
1
,
3
},
{
3
,
0
,
1
,
2
}},
{{
4
,
8
,
3
,
1
},
{
2
,
0
,
1
,
3
}}};
for
(
const
auto
&
[
dims
,
perm
]
:
tests
)
{
migraphx
::
shape
output
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
dims
,
migraphx
::
invert_permutation
(
perm
));
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
input
);
}
}
TEST_CASE
(
reshape_nonstandard_squeeze
)
{
auto
input
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
migraphx
::
invert_permutation
({
0
,
2
,
3
,
1
}));
std
::
vector
<
std
::
size_t
>
lens
=
{
2
,
256
,
1280
};
migraphx
::
shape
output
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
lens
,
migraphx
::
invert_permutation
({
0
,
2
,
1
}));
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
}
TEST_CASE
(
reshape_nonstandard_error
)
{
auto
input
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
4
,
24
,
1
,
1
,
1
},
migraphx
::
invert_permutation
({
1
,
0
,
2
,
3
,
4
}));
for
(
auto
&&
new_shape
:
std
::
vector
<
std
::
vector
<
int64_t
>>
{{
4
,
8
,
3
,
2
,
2
},
{
1
},
{
4
,
8
,
4
},
{
4
,
24
,
1
,
1
,
1
,
1
,
2
},
{
8
,
4
,
4
},
{
4
,
1
,
3
,
-
1
,
-
1
},
{
4
,
3
,
0
},
{
4
,
3
,
2
},
{
3
,
0
},
{
3
,
2
}})
{
throws_shape
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
input
);
}
}
TEST_CASE
(
reshape_nonpacked_unsqueeze1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
8
},
{
32
,
16
,
2
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_nonpacked_unsqueeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
16
},
{
64
,
32
,
2
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_nonpacked_squeeze
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
},
{
2
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_unsqueeze1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
{
0
,
0
,
0
,
1
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_unsqueeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
16
,
80
},
{
0
,
0
,
80
,
1
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_squeeze
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
{
0
,
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_squeeze_error
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
{
0
,
0
,
0
,
1
}};
std
::
vector
<
int64_t
>
new_shape
=
{
2
,
16
,
20480
};
throws_shape
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
input
);
}
TEST_CASE
(
reshape_dyn_shape
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment