Commit 51b79c51 authored by Paul's avatar Paul
Browse files

Format

parent e9e7f8c9
...@@ -96,7 +96,7 @@ struct reshape ...@@ -96,7 +96,7 @@ struct reshape
return {s0.type(), output_dyn_dims}; return {s0.type(), output_dyn_dims};
} }
template<class Iterator> template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim) static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{ {
std::size_t x = 1; std::size_t x = 1;
...@@ -104,16 +104,23 @@ struct reshape ...@@ -104,16 +104,23 @@ struct reshape
x *= i; x *= i;
return x >= dim; return x >= dim;
}); });
if (x != dim) if(x != dim)
return start; return start;
return it; return it;
} }
template<class DimIterator, class StrideIterator> template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start, DimIterator dim_last, StrideIterator stride_start, StrideIterator stride_last) static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{ {
auto cstride = *std::prev(stride_last); auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last), std::make_reverse_iterator(dim_start+1), std::make_reverse_iterator(stride_last-1), std::make_reverse_iterator(stride_start), [&](auto dim, auto stride) { return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim; cstride *= dim;
return stride == cstride; return stride == cstride;
}); });
...@@ -150,7 +157,7 @@ struct reshape ...@@ -150,7 +157,7 @@ struct reshape
} }
shape s; shape s;
if (inputs.front().standard()) if(inputs.front().standard())
{ {
s = shape{inputs.front().type(), rdims}; s = shape{inputs.front().type(), rdims};
} }
...@@ -163,21 +170,22 @@ struct reshape ...@@ -163,21 +170,22 @@ struct reshape
{ {
auto idim = idims[i]; auto idim = idims[i];
auto rdim = rdims[r]; auto rdim = rdims[r];
if (rdim == idim) if(rdim == idim)
{ {
rstrides.push_back(istrides[i]); rstrides.push_back(istrides[i]);
} }
// squeeze // squeeze
else if(rdim > idim) else if(rdim > idim)
{ {
auto start = idims.begin()+i; auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim); auto it = compute_end_dim(start, idims.end(), rdim);
if (it == start) if(it == start)
break; break;
auto n = it - start; auto n = it - start;
if ((i+n) > istrides.size()) if((i + n) > istrides.size())
break; break;
if (not can_strides_merge(start, it+1, istrides.begin()+i, istrides.begin()+i+n)) if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n))
break; break;
i += n; i += n;
rstrides.push_back(istrides[i]); rstrides.push_back(istrides[i]);
...@@ -185,15 +193,15 @@ struct reshape ...@@ -185,15 +193,15 @@ struct reshape
// unsqueeze // unsqueeze
else if(rdim < idim) else if(rdim < idim)
{ {
auto start = rdims.begin()+i; auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim); auto it = compute_end_dim(start, rdims.end(), idim);
if (it == start) if(it == start)
break; break;
auto n = it - start; auto n = it - start;
if ((r+n) > rdims.size()) if((r + n) > rdims.size())
break; break;
auto stride = istrides[i] * idim; auto stride = istrides[i] * idim;
std::for_each(start, it+1, [&](auto dim) { std::for_each(start, it + 1, [&](auto dim) {
stride /= dim; stride /= dim;
rstrides.push_back(stride); rstrides.push_back(stride);
}); });
...@@ -204,22 +212,21 @@ struct reshape ...@@ -204,22 +212,21 @@ struct reshape
} }
// Handle trailing 1s // Handle trailing 1s
if (rstrides.size() < rdims.size() and not rstrides.empty()) if(rstrides.size() < rdims.size() and not rstrides.empty())
{ {
auto stride = rstrides.back(); auto stride = rstrides.back();
for(auto d:range(rdims.begin()+rstrides.size(), rdims.end())) for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{ {
if (d != 1) if(d != 1)
break; break;
rstrides.push_back(stride); rstrides.push_back(stride);
} }
} }
if (rdims.size() != rstrides.size()) if(rdims.size() != rstrides.size())
MIGRAPHX_THROW("Reshape on axis that is not standard"); MIGRAPHX_THROW("Reshape on axis that is not standard");
s = shape{inputs.front().type(), rdims, rstrides}; s = shape{inputs.front().type(), rdims, rstrides};
} }
assert(s.bytes() == inputs.front().bytes()); assert(s.bytes() == inputs.front().bytes());
......
...@@ -2134,7 +2134,8 @@ TEST_CASE(reshape_nonstandard_unsqeeze) ...@@ -2134,7 +2134,8 @@ TEST_CASE(reshape_nonstandard_unsqeeze)
migraphx::shape input{migraphx::shape::float_type, {4, 24, 1, 1, 1}, {1, 4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 24, 1, 1, 1}, {1, 4, 1, 1, 1}};
std::vector<std::size_t> lens = {4, 1, 3, 4, 2}; std::vector<std::size_t> lens = {4, 1, 3, 4, 2};
std::vector<int64_t> perm = {4, 0, 1, 2, 3}; std::vector<int64_t> perm = {4, 0, 1, 2, 3};
migraphx::shape output = migraphx::shape::from_permutation(migraphx::shape::float_type, lens, migraphx::invert_permutation(perm)); migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, lens, migraphx::invert_permutation(perm));
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment