Commit 47371cfd authored by Paul's avatar Paul
Browse files

Merge branch 'simplify-more-reshapes' into sd-opt

parents 359bb1cd 6690765c
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -115,6 +116,7 @@ struct reshape ...@@ -115,6 +116,7 @@ struct reshape
StrideIterator stride_start, StrideIterator stride_start,
StrideIterator stride_last) StrideIterator stride_last)
{ {
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last); auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last), return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1), std::make_reverse_iterator(dim_start + 1),
...@@ -126,43 +128,14 @@ struct reshape ...@@ -126,43 +128,14 @@ struct reshape
}); });
} }
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
auto&& istrides = inputs.front().strides();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
{ {
if(dims[i] == 0) if(input.standard())
rdims[i] = idims[i]; return shape{input.type(), rdims};
// since rdims using size_t type, -1 is the max value
// is size_t that cause later compuation incorrect
if(dims[i] == -1)
rdims[i] = 1;
}
if(n_neg_dims > 0) const auto& idims = input.lens();
{ const auto& istrides = input.strides();
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
shape s;
if(inputs.front().standard())
{
s = shape{inputs.front().type(), rdims};
}
else
{
std::vector<std::size_t> rstrides; std::vector<std::size_t> rstrides;
std::size_t i = 0; std::size_t i = 0;
std::size_t r = 0; std::size_t r = 0;
...@@ -180,26 +153,26 @@ struct reshape ...@@ -180,26 +153,26 @@ struct reshape
auto start = idims.begin() + i; auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim); auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start) if(it == start)
break; return nullopt;
auto n = it - start; auto n = it - start;
if((i + n) > istrides.size()) if((i + n) > istrides.size())
break; return nullopt;
if(not can_strides_merge( if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n)) start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
break; return nullopt;
i += n; i += n;
rstrides.push_back(istrides[i]); rstrides.push_back(istrides[i]);
} }
// unsqueeze // unsqueeze
else if(rdim < idim) else // if(rdim < idim)
{ {
auto start = rdims.begin() + i; auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim); auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start) if(it == start)
break; return nullopt;
auto n = it - start; auto n = it - start;
if((r + n) > rdims.size()) if((r + n) > rdims.size())
break; return nullopt;
auto stride = istrides[i] * idim; auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) { std::for_each(start, it + 1, [&](auto dim) {
stride /= dim; stride /= dim;
...@@ -218,24 +191,58 @@ struct reshape ...@@ -218,24 +191,58 @@ struct reshape
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end())) for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{ {
if(d != 1) if(d != 1)
break; return nullopt;
rstrides.push_back(stride); rstrides.push_back(stride);
} }
} }
if(rdims.size() != rstrides.size()) if(rdims.size() != rstrides.size())
MIGRAPHX_THROW("Reshape on axis that is not standard"); return nullopt;
s = shape{inputs.front().type(), rdims, rstrides}; return shape{input.type(), rdims, rstrides};
} }
assert(s.bytes() == inputs.front().bytes()); shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
// auto&& istrides = inputs.front().strides();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
// since rdims using size_t type, -1 is the max value
// is size_t that cause later compuation incorrect
if(dims[i] == -1)
rdims[i] = 1;
}
if(s.elements() != inputs.front().elements()) if(n_neg_dims > 0)
{
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
auto s = reshape_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("Reshape on axis that is not packed.");
if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
return s;
assert(s->bytes() == inputs.front().bytes());
return *s;
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -2161,6 +2161,8 @@ TEST_CASE(reshape_shape) ...@@ -2161,6 +2161,8 @@ TEST_CASE(reshape_shape)
for(auto&& new_shape : for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0}, {3, 2}}) std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0}, {3, 2}})
{ {
std::cout << "input: " << input << std::endl;
std::cout << "dims: " << migraphx::to_string_range(new_shape) << std::endl;
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input); throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
} }
...@@ -2181,7 +2183,7 @@ TEST_CASE(reshape_shape) ...@@ -2181,7 +2183,7 @@ TEST_CASE(reshape_shape)
} }
} }
TEST_CASE(reshape_nonstandard_unsqeeze) TEST_CASE(reshape_nonstandard_unsqueeze)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 24, 1, 1, 1}, {1, 4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 24, 1, 1, 1}, {1, 4, 1, 1, 1}};
std::vector<std::size_t> lens = {4, 1, 3, 4, 2}; std::vector<std::size_t> lens = {4, 1, 3, 4, 2};
...@@ -2191,6 +2193,23 @@ TEST_CASE(reshape_nonstandard_unsqeeze) ...@@ -2191,6 +2193,23 @@ TEST_CASE(reshape_nonstandard_unsqeeze)
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
} }
TEST_CASE(reshape_nonstandard_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {327680, 16, 1, 256}};
std::vector<std::size_t> lens = {2, 256, 1280};
std::vector<int64_t> perm = {0, 2, 1};
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, lens, migraphx::invert_permutation(perm));
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
}
TEST_CASE(reshape_broadcast_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_dyn_shape) TEST_CASE(reshape_dyn_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}}; migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment