"vscode:/vscode.git/clone" did not exist on "1e188f76dd66949a915b975a4619f21fdd9a75f3"
Unverified Commit b6976b94 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Make reshape copy always have standard shape as output (#2508)

parent 5fe1b075
......@@ -112,84 +112,6 @@ struct reshape
return {s0.type(), output_dyn_dims};
}
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
// This will attempt to alias the dimensions of the input shape to the lens of
// `rdims`. Unlike reshape_lazy though we can modify memory layout with copies and this
// can remove previous nullopts that were sent back for the alias case
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
auto n = it - start;
assert((i + n) <= istrides.size());
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
(void)d;
rstrides.push_back(stride);
}
}
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
......@@ -219,14 +141,14 @@ struct reshape
}
}
auto s = reshape_dims(inputs.front(), rdims);
auto s = shape{inputs.front().type(), rdims};
if(s->elements() != inputs.front().elements())
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(s.elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
return *s;
return s;
}
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -2682,36 +2682,26 @@ TEST_CASE(reshape_shape_minus1_reshapes)
}
}
// This uses the permutation to compute the reshape since its simpler than
// trying to calculate strides. As we collapse or expand dimensions, we
// remove the collapsed dimensions or duplicate the expanded dimensions in
// the permutation. Then we renumber the permutation. So for dimensions of 4,
// 24, 1, 1, 1 with a permutation of 1, 0, 2, 3, 4 that reshapes to 4, 1, 3,
// 4, 2, we first remove the collapsed dimensions or duplicate the expanded
// dimensions which gives 1, 0, 0, 0, 0. Then after renumbering we get a
// final permutation of 4, 0, 1, 2, 3.
TEST_CASE(reshape_nonstandard)
{
auto input = migraphx::shape::from_permutation(migraphx::shape::float_type,
{4, 24, 1, 1, 1},
migraphx::invert_permutation({1, 0, 2, 3, 4}));
std::vector<std::pair<std::vector<std::size_t>, std::vector<int64_t>>> tests{
{{4, 24}, {1, 0}},
{{4, 24, 1, 1, 1, 1}, {1, 0, 2, 3, 4, 5}},
{{4, 8, 3, 1, 1}, {2, 0, 1, 3, 4}},
{{4, 1, 3, 4, 2}, {4, 0, 1, 2, 3}},
{{4, 1, 4, 3, 2}, {4, 0, 1, 2, 3}},
{{4, 2, 4, 3}, {3, 0, 1, 2}},
{{4, 2, 12, 1}, {2, 0, 1, 3}},
{{4, 2, 1, 12}, {3, 0, 1, 2}},
{{4, 4, 2, 3}, {3, 0, 1, 2}},
{{4, 8, 1, 3}, {3, 0, 1, 2}},
{{4, 8, 3, 1}, {2, 0, 1, 3}}};
for(const auto& [dims, perm] : tests)
std::vector<std::vector<std::size_t>> tests{{4, 24},
{4, 24, 1, 1, 1, 1},
{4, 8, 3, 1, 1},
{4, 1, 3, 4, 2},
{4, 1, 4, 3, 2},
{4, 2, 4, 3},
{4, 2, 12, 1},
{4, 2, 1, 12},
{4, 4, 2, 3},
{4, 8, 1, 3},
{4, 8, 3, 1}};
for(auto dims : tests)
{
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, dims, migraphx::invert_permutation(perm));
migraphx::shape output = migraphx::shape{migraphx::shape::float_type, dims};
expect_shape(output, migraphx::make_op("reshape", {{"dims", dims}}), input);
}
}
......@@ -2721,8 +2711,7 @@ TEST_CASE(reshape_nonstandard_squeeze)
auto input = migraphx::shape::from_permutation(
migraphx::shape::float_type, {2, 16, 16, 1280}, migraphx::invert_permutation({0, 2, 3, 1}));
std::vector<std::size_t> lens = {2, 256, 1280};
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, lens, migraphx::invert_permutation({0, 2, 1}));
migraphx::shape output = migraphx::shape{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
}
......@@ -2746,52 +2735,80 @@ TEST_CASE(reshape_nonstandard_error)
}
}
TEST_CASE(reshape_transposed_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 4}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_nonpacked_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}, {32, 16, 2}};
migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_nonpacked_unsqueeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}, {64, 32, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_nonpacked_squeeze)
TEST_CASE(reshape_nonpacked_squeeze1)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_nonpacked_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_unsqueeze2)
{
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}, {0, 0, 80, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_squeeze)
TEST_CASE(reshape_broadcast_squeeze1)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_squeeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {0, 1}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_squeeze3)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 0}};
migraphx::shape output{migraphx::shape::float_type, {64}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}, {0, 0, 0, 16}};
migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment