"stubs/vscode:/vscode.git/clone" did not exist on "35d4129f81523c279fac193cffc909bb8214acec"
Unverified Commit b75e6aae authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Merge branch 'develop' into qdq_skip_ops

parents c335de61 a60bdb67
...@@ -27,6 +27,17 @@ ...@@ -27,6 +27,17 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
// Similiar to decltype(auto) except it will propagate any substitution failures
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -112,84 +112,6 @@ struct reshape ...@@ -112,84 +112,6 @@ struct reshape
return {s0.type(), output_dyn_dims}; return {s0.type(), output_dyn_dims};
} }
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
// This will attempt to alias the dimensions of the input shape to the lens of
// `rdims`. Unlike reshape_lazy though we can modify memory layout with copies and this
// can remove previous nullopts that were sent back for the alias case
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
auto n = it - start;
assert((i + n) <= istrides.size());
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
(void)d;
rstrides.push_back(stride);
}
}
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
...@@ -219,14 +141,14 @@ struct reshape ...@@ -219,14 +141,14 @@ struct reshape
} }
} }
auto s = reshape_dims(inputs.front(), rdims); auto s = shape{inputs.front().type(), rdims};
if(s->elements() != inputs.front().elements()) if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " + MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " + std::to_string(s.elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
return *s; return s;
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -110,22 +110,69 @@ struct reshape_lazy ...@@ -110,22 +110,69 @@ struct reshape_lazy
return it; return it;
} }
template <class OptionalPair>
static OptionalPair try_merge_pairs(OptionalPair p2, OptionalPair p1)
{
if(not p1.has_value())
return nullopt;
if(not p2.has_value())
return nullopt;
auto dim1 = p1->first;
auto dim2 = p2->first;
auto stride1 = p1->second;
auto stride2 = p2->second;
auto elements = dim1 * dim2;
// Transposed
if(stride2 > stride1)
return nullopt;
// Broadcasted check to avoid division by zero
if(stride2 == 0)
{
if(stride1 == 0)
return {{elements, 0}};
return nullopt;
}
if(stride1 % stride2 != 0)
return nullopt;
auto space = (stride1 * dim1 + stride2 * dim2 - stride1) / stride2;
// Nonpacked
if(space != elements)
return nullopt;
return {{elements, stride2}};
}
template <class DimIterator, class StrideIterator>
static optional<std::size_t> merge_strides(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
if(dim_start == dim_last)
return nullopt;
(void)stride_start; // Is only used in the assert
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto make_pair_optional = [&](auto dim, auto stride) {
return std::make_optional(std::make_pair(dim, stride));
};
auto dim_stride_pair =
std::inner_product(std::make_reverse_iterator(dim_last - 1),
std::make_reverse_iterator(dim_start),
std::make_reverse_iterator(stride_last - 1),
make_pair_optional(*std::prev(dim_last), *std::prev(stride_last)),
MIGRAPHX_LIFT(try_merge_pairs),
make_pair_optional);
if(not dim_stride_pair.has_value())
return nullopt;
return dim_stride_pair->second;
}
template <class DimIterator, class StrideIterator> template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start, static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last, DimIterator dim_last,
StrideIterator stride_start, StrideIterator stride_start,
StrideIterator stride_last) StrideIterator stride_last)
{ {
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last)); return merge_strides(dim_start, dim_last, stride_start, stride_last).has_value();
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
} }
// This will attempt to alias the dimensions of the input shape to the lens of // This will attempt to alias the dimensions of the input shape to the lens of
......
...@@ -26,10 +26,12 @@ ...@@ -26,10 +26,12 @@
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
// Similiar to decltype(auto) except it will propagate any substitution failures
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \ #define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; } ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \ [](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
......
...@@ -2682,36 +2682,26 @@ TEST_CASE(reshape_shape_minus1_reshapes) ...@@ -2682,36 +2682,26 @@ TEST_CASE(reshape_shape_minus1_reshapes)
} }
} }
// This uses the permutation to compute the reshape since its simpler than
// trying to calculate strides. As we collapse or expand dimensions, we
// remove the collapsed dimensions or duplicate the expanded dimensions in
// the permutation. Then we renumber the permutation. So for dimensions of 4,
// 24, 1, 1, 1 with a permutation of 1, 0, 2, 3, 4 that reshapes to 4, 1, 3,
// 4, 2, we first remove the collapsed dimensions or duplicate the expanded
// dimensions which gives 1, 0, 0, 0, 0. Then after renumbering we get a
// final permutation of 4, 0, 1, 2, 3.
TEST_CASE(reshape_nonstandard) TEST_CASE(reshape_nonstandard)
{ {
auto input = migraphx::shape::from_permutation(migraphx::shape::float_type, auto input = migraphx::shape::from_permutation(migraphx::shape::float_type,
{4, 24, 1, 1, 1}, {4, 24, 1, 1, 1},
migraphx::invert_permutation({1, 0, 2, 3, 4})); migraphx::invert_permutation({1, 0, 2, 3, 4}));
std::vector<std::pair<std::vector<std::size_t>, std::vector<int64_t>>> tests{ std::vector<std::vector<std::size_t>> tests{{4, 24},
{{4, 24}, {1, 0}}, {4, 24, 1, 1, 1, 1},
{{4, 24, 1, 1, 1, 1}, {1, 0, 2, 3, 4, 5}}, {4, 8, 3, 1, 1},
{{4, 8, 3, 1, 1}, {2, 0, 1, 3, 4}}, {4, 1, 3, 4, 2},
{{4, 1, 3, 4, 2}, {4, 0, 1, 2, 3}}, {4, 1, 4, 3, 2},
{{4, 1, 4, 3, 2}, {4, 0, 1, 2, 3}}, {4, 2, 4, 3},
{{4, 2, 4, 3}, {3, 0, 1, 2}}, {4, 2, 12, 1},
{{4, 2, 12, 1}, {2, 0, 1, 3}}, {4, 2, 1, 12},
{{4, 2, 1, 12}, {3, 0, 1, 2}}, {4, 4, 2, 3},
{{4, 4, 2, 3}, {3, 0, 1, 2}}, {4, 8, 1, 3},
{{4, 8, 1, 3}, {3, 0, 1, 2}}, {4, 8, 3, 1}};
{{4, 8, 3, 1}, {2, 0, 1, 3}}};
for(auto dims : tests)
for(const auto& [dims, perm] : tests)
{ {
migraphx::shape output = migraphx::shape::from_permutation( migraphx::shape output = migraphx::shape{migraphx::shape::float_type, dims};
migraphx::shape::float_type, dims, migraphx::invert_permutation(perm));
expect_shape(output, migraphx::make_op("reshape", {{"dims", dims}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", dims}}), input);
} }
} }
...@@ -2721,8 +2711,7 @@ TEST_CASE(reshape_nonstandard_squeeze) ...@@ -2721,8 +2711,7 @@ TEST_CASE(reshape_nonstandard_squeeze)
auto input = migraphx::shape::from_permutation( auto input = migraphx::shape::from_permutation(
migraphx::shape::float_type, {2, 16, 16, 1280}, migraphx::invert_permutation({0, 2, 3, 1})); migraphx::shape::float_type, {2, 16, 16, 1280}, migraphx::invert_permutation({0, 2, 3, 1}));
std::vector<std::size_t> lens = {2, 256, 1280}; std::vector<std::size_t> lens = {2, 256, 1280};
migraphx::shape output = migraphx::shape::from_permutation( migraphx::shape output = migraphx::shape{migraphx::shape::float_type, lens};
migraphx::shape::float_type, lens, migraphx::invert_permutation({0, 2, 1}));
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
} }
...@@ -2746,52 +2735,80 @@ TEST_CASE(reshape_nonstandard_error) ...@@ -2746,52 +2735,80 @@ TEST_CASE(reshape_nonstandard_error)
} }
} }
TEST_CASE(reshape_transposed_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 4}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_nonpacked_unsqueeze1) TEST_CASE(reshape_nonpacked_unsqueeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}}; migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}, {32, 16, 2}}; migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_nonpacked_unsqueeze2) TEST_CASE(reshape_nonpacked_unsqueeze2)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}}; migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}, {64, 32, 2}}; migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_nonpacked_squeeze) TEST_CASE(reshape_nonpacked_squeeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}}; migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}, {2}}; migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_nonpacked_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_broadcast_unsqueeze1) TEST_CASE(reshape_broadcast_unsqueeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_broadcast_unsqueeze2) TEST_CASE(reshape_broadcast_unsqueeze2)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}, {0, 0, 80, 1}}; migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_broadcast_squeeze) TEST_CASE(reshape_broadcast_squeeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}}; migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {0, 1}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_squeeze3)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 0}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_broadcast_squeeze_memlayout_change) TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}, {0, 0, 0, 16}}; migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
...@@ -2960,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error) ...@@ -2960,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error)
} }
} }
TEST_CASE(reshape_lazy_transposed_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 4}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_nonpacked_unsqueeze1) TEST_CASE(reshape_lazy_nonpacked_unsqueeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}}; migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
...@@ -2974,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2) ...@@ -2974,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_nonpacked_squeeze) TEST_CASE(reshape_lazy_nonpacked_squeeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}}; migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}, {2}}; migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_nonpacked_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 1}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_unsqueeze1) TEST_CASE(reshape_lazy_broadcast_unsqueeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
...@@ -2995,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2) ...@@ -2995,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2)
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_broadcast_squeeze) TEST_CASE(reshape_lazy_broadcast_squeeze1)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}}; migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_lazy_broadcast_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {0, 1}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze3)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 0}};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze_error) TEST_CASE(reshape_lazy_broadcast_squeeze_error)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment