Unverified Commit a60bdb67 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Lazy reshape fixes (#2505)

parent b6976b94
...@@ -27,6 +27,17 @@ ...@@ -27,6 +27,17 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
// Similiar to decltype(auto) except it will propagate any substitution failures
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -110,22 +110,69 @@ struct reshape_lazy ...@@ -110,22 +110,69 @@ struct reshape_lazy
return it; return it;
} }
template <class OptionalPair>
static OptionalPair try_merge_pairs(OptionalPair p2, OptionalPair p1)
{
if(not p1.has_value())
return nullopt;
if(not p2.has_value())
return nullopt;
auto dim1 = p1->first;
auto dim2 = p2->first;
auto stride1 = p1->second;
auto stride2 = p2->second;
auto elements = dim1 * dim2;
// Transposed
if(stride2 > stride1)
return nullopt;
// Broadcasted check to avoid division by zero
if(stride2 == 0)
{
if(stride1 == 0)
return {{elements, 0}};
return nullopt;
}
if(stride1 % stride2 != 0)
return nullopt;
auto space = (stride1 * dim1 + stride2 * dim2 - stride1) / stride2;
// Nonpacked
if(space != elements)
return nullopt;
return {{elements, stride2}};
}
template <class DimIterator, class StrideIterator> template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start, static optional<std::size_t> merge_strides(DimIterator dim_start,
DimIterator dim_last, DimIterator dim_last,
StrideIterator stride_start, StrideIterator stride_start,
StrideIterator stride_last) StrideIterator stride_last)
{ {
if(dim_start == dim_last)
return nullopt;
(void)stride_start; // Is only used in the assert
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last)); assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last); auto make_pair_optional = [&](auto dim, auto stride) {
return std::equal(std::make_reverse_iterator(dim_last), return std::make_optional(std::make_pair(dim, stride));
std::make_reverse_iterator(dim_start + 1), };
auto dim_stride_pair =
std::inner_product(std::make_reverse_iterator(dim_last - 1),
std::make_reverse_iterator(dim_start),
std::make_reverse_iterator(stride_last - 1), std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start), make_pair_optional(*std::prev(dim_last), *std::prev(stride_last)),
[&](auto dim, auto stride) { MIGRAPHX_LIFT(try_merge_pairs),
cstride *= dim; make_pair_optional);
return stride == cstride; if(not dim_stride_pair.has_value())
}); return nullopt;
return dim_stride_pair->second;
}
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
return merge_strides(dim_start, dim_last, stride_start, stride_last).has_value();
} }
// This will attempt to alias the dimensions of the input shape to the lens of // This will attempt to alias the dimensions of the input shape to the lens of
......
...@@ -26,10 +26,12 @@ ...@@ -26,10 +26,12 @@
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
// Similiar to decltype(auto) except it will propagate any substitution failures
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \ #define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; } ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \ [](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
......
...@@ -2977,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error) ...@@ -2977,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error)
} }
} }
TEST_CASE(reshape_lazy_transposed_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 4}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_nonpacked_unsqueeze1) TEST_CASE(reshape_lazy_nonpacked_unsqueeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}}; migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
...@@ -2991,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2) ...@@ -2991,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_nonpacked_squeeze) TEST_CASE(reshape_lazy_nonpacked_squeeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}}; migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}, {2}}; migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_nonpacked_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 1}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_unsqueeze1) TEST_CASE(reshape_lazy_broadcast_unsqueeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
...@@ -3012,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2) ...@@ -3012,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_broadcast_squeeze) TEST_CASE(reshape_lazy_broadcast_squeeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}}; migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_broadcast_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {0, 1}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze3)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 0}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze_error) TEST_CASE(reshape_lazy_broadcast_squeeze_error)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment