Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a60bdb67
Unverified
Commit
a60bdb67
authored
Dec 13, 2023
by
Paul Fultz II
Committed by
GitHub
Dec 13, 2023
Browse files
Lazy reshape fixes (#2505)
parent
b6976b94
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
12 deletions
+96
-12
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+11
-0
src/include/migraphx/op/reshape_lazy.hpp
src/include/migraphx/op/reshape_lazy.hpp
+57
-10
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+2
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+26
-2
No files found.
src/include/migraphx/functional.hpp
View file @
a60bdb67
...
...
@@ -27,6 +27,17 @@
#include <utility>
#include <migraphx/config.hpp>
// Similiar to decltype(auto) except it will propagate any substitution failures
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/reshape_lazy.hpp
View file @
a60bdb67
...
...
@@ -110,22 +110,69 @@ struct reshape_lazy
return
it
;
}
template
<
class
OptionalPair
>
static
OptionalPair
try_merge_pairs
(
OptionalPair
p2
,
OptionalPair
p1
)
{
if
(
not
p1
.
has_value
())
return
nullopt
;
if
(
not
p2
.
has_value
())
return
nullopt
;
auto
dim1
=
p1
->
first
;
auto
dim2
=
p2
->
first
;
auto
stride1
=
p1
->
second
;
auto
stride2
=
p2
->
second
;
auto
elements
=
dim1
*
dim2
;
// Transposed
if
(
stride2
>
stride1
)
return
nullopt
;
// Broadcasted check to avoid division by zero
if
(
stride2
==
0
)
{
if
(
stride1
==
0
)
return
{{
elements
,
0
}};
return
nullopt
;
}
if
(
stride1
%
stride2
!=
0
)
return
nullopt
;
auto
space
=
(
stride1
*
dim1
+
stride2
*
dim2
-
stride1
)
/
stride2
;
// Nonpacked
if
(
space
!=
elements
)
return
nullopt
;
return
{{
elements
,
stride2
}};
}
template
<
class
DimIterator
,
class
StrideIterator
>
static
optional
<
std
::
size_t
>
merge_strides
(
DimIterator
dim_start
,
DimIterator
dim_last
,
StrideIterator
stride_start
,
StrideIterator
stride_last
)
{
if
(
dim_start
==
dim_last
)
return
nullopt
;
(
void
)
stride_start
;
// Is only used in the assert
assert
(
std
::
distance
(
dim_start
,
dim_last
)
==
std
::
distance
(
stride_start
,
stride_last
));
auto
make_pair_optional
=
[
&
](
auto
dim
,
auto
stride
)
{
return
std
::
make_optional
(
std
::
make_pair
(
dim
,
stride
));
};
auto
dim_stride_pair
=
std
::
inner_product
(
std
::
make_reverse_iterator
(
dim_last
-
1
),
std
::
make_reverse_iterator
(
dim_start
),
std
::
make_reverse_iterator
(
stride_last
-
1
),
make_pair_optional
(
*
std
::
prev
(
dim_last
),
*
std
::
prev
(
stride_last
)),
MIGRAPHX_LIFT
(
try_merge_pairs
),
make_pair_optional
);
if
(
not
dim_stride_pair
.
has_value
())
return
nullopt
;
return
dim_stride_pair
->
second
;
}
template
<
class
DimIterator
,
class
StrideIterator
>
static
auto
can_strides_merge
(
DimIterator
dim_start
,
DimIterator
dim_last
,
StrideIterator
stride_start
,
StrideIterator
stride_last
)
{
assert
(
std
::
distance
(
dim_start
,
dim_last
)
==
std
::
distance
(
stride_start
,
stride_last
));
auto
cstride
=
*
std
::
prev
(
stride_last
);
return
std
::
equal
(
std
::
make_reverse_iterator
(
dim_last
),
std
::
make_reverse_iterator
(
dim_start
+
1
),
std
::
make_reverse_iterator
(
stride_last
-
1
),
std
::
make_reverse_iterator
(
stride_start
),
[
&
](
auto
dim
,
auto
stride
)
{
cstride
*=
dim
;
return
stride
==
cstride
;
});
return
merge_strides
(
dim_start
,
dim_last
,
stride_start
,
stride_last
).
has_value
();
}
// This will attempt to alias the dimensions of the input shape to the lens of
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
a60bdb67
...
...
@@ -26,10 +26,12 @@
#include <migraphx/kernels/integral_constant.hpp>
// Similiar to decltype(auto) except it will propagate any substitution failures
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
...
...
test/op_shape_test.cpp
View file @
a60bdb67
...
...
@@ -2977,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error)
}
}
TEST_CASE(reshape_lazy_transposed_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 4}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_nonpacked_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
...
...
@@ -2991,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE
(
reshape_lazy_nonpacked_squeeze
)
TEST_CASE(reshape_lazy_nonpacked_squeeze
1
)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_nonpacked_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 1}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
...
...
@@ -3012,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE
(
reshape_lazy_broadcast_squeeze
)
TEST_CASE(reshape_lazy_broadcast_squeeze
1
)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {0, 1}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze3)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 0}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze_error)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment