Commit 51b79c51 authored by Paul's avatar Paul
Browse files

Format

parent e9e7f8c9
......@@ -96,33 +96,40 @@ struct reshape
return {s0.type(), output_dyn_dims};
}
template<class Iterator>
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if (x != dim)
if(x != dim)
return start;
return it;
}
template<class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start, DimIterator dim_last, StrideIterator stride_start, StrideIterator stride_last)
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last), std::make_reverse_iterator(dim_start+1), std::make_reverse_iterator(stride_last-1), std::make_reverse_iterator(stride_start), [&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
auto&& idims = inputs.front().lens();
auto&& istrides = inputs.front().strides();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
......@@ -150,7 +157,7 @@ struct reshape
}
shape s;
if (inputs.front().standard())
if(inputs.front().standard())
{
s = shape{inputs.front().type(), rdims};
}
......@@ -163,21 +170,22 @@ struct reshape
{
auto idim = idims[i];
auto rdim = rdims[r];
if (rdim == idim)
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin()+i;
auto it = compute_end_dim(start, idims.end(), rdim);
if (it == start)
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
break;
auto n = it - start;
if ((i+n) > istrides.size())
if((i + n) > istrides.size())
break;
if (not can_strides_merge(start, it+1, istrides.begin()+i, istrides.begin()+i+n))
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n))
break;
i += n;
rstrides.push_back(istrides[i]);
......@@ -185,15 +193,15 @@ struct reshape
// unsqueeze
else if(rdim < idim)
{
auto start = rdims.begin()+i;
auto it = compute_end_dim(start, rdims.end(), idim);
if (it == start)
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
break;
auto n = it - start;
if ((r+n) > rdims.size())
if((r + n) > rdims.size())
break;
auto stride = istrides[i] * idim;
std::for_each(start, it+1, [&](auto dim) {
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
......@@ -204,22 +212,21 @@ struct reshape
}
// Handle trailing 1s
if (rstrides.size() < rdims.size() and not rstrides.empty())
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d:range(rdims.begin()+rstrides.size(), rdims.end()))
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if (d != 1)
if(d != 1)
break;
rstrides.push_back(stride);
}
}
if (rdims.size() != rstrides.size())
if(rdims.size() != rstrides.size())
MIGRAPHX_THROW("Reshape on axis that is not standard");
s = shape{inputs.front().type(), rdims, rstrides};
}
assert(s.bytes() == inputs.front().bytes());
......
......@@ -2133,8 +2133,9 @@ TEST_CASE(reshape_nonstandard_unsqeeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 24, 1, 1, 1}, {1, 4, 1, 1, 1}};
std::vector<std::size_t> lens = {4, 1, 3, 4, 2};
std::vector<int64_t> perm = {4, 0, 1, 2, 3};
migraphx::shape output = migraphx::shape::from_permutation(migraphx::shape::float_type, lens, migraphx::invert_permutation(perm));
std::vector<int64_t> perm = {4, 0, 1, 2, 3};
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, lens, migraphx::invert_permutation(perm));
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment