"include/vscode:/vscode.git/clone" did not exist on "075abf92cceb660722e7ba7993c3d192896c35d7"
Commit 22b88c37 authored by Ted Themistokleous

Simplify reshape now that reshape_lazy is assumed to perform the aliasing

The idea is to let reshape fall back to a contiguous operation, which
performs the copy required to convert a nonstandard shape into a
standard one.
parent c7b2fd1c
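
Editorial note (not part of this change): a hedged sketch of the decomposition the commit message alludes to. When the input to reshape is nonstandard (for example a transposed view), an aliasing reshape is in general not possible, so the reshape is expected to be expressed as a contiguous copy followed by reshape_lazy, which can then alias the standard copy. The helper name lower_reshape and its placement are assumptions; make_op, insert_instruction, and replace_instruction are the ordinary MIGraphX module-building calls.

    // Hypothetical sketch only: decompose reshape of a nonstandard input into
    // contiguous + reshape_lazy. The helper and the pass it would live in are
    // assumptions, not code taken from this commit.
    #include <migraphx/instruction.hpp>
    #include <migraphx/make_op.hpp>
    #include <migraphx/module.hpp>
    #include <cstdint>
    #include <vector>

    static void lower_reshape(migraphx::module& m,
                              migraphx::instruction_ref ins,
                              const std::vector<int64_t>& dims)
    {
        auto input = ins->inputs().front();
        if(not input->get_shape().standard())
        {
            // Nonstandard layout: materialize a standard copy first so the
            // lazy (aliasing) reshape that follows is always legal.
            input = m.insert_instruction(ins, migraphx::make_op("contiguous"), input);
        }
        m.replace_instruction(
            ins, migraphx::make_op("reshape_lazy", {{"dims", dims}}), input);
    }
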
@@ -97,19 +97,6 @@ struct reshape
         return {s0.type(), output_dyn_dims};
     }
-    template <class Iterator>
-    static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
-    {
-        std::size_t x = 1;
-        auto it = std::find_if(start, last, [&](auto i) {
-            x *= i;
-            return x >= dim;
-        });
-        if(x != dim)
-            return start;
-        return it;
-    }
     template <class DimIterator, class StrideIterator>
     static auto can_strides_merge(DimIterator dim_start,
                                   DimIterator dim_last,
@@ -128,81 +115,6 @@ struct reshape
         });
     }
-    // This will reshape the dimensions of the input shape to use the lens of
-    // `rdims`. If this can't be done without changing memory layout then it
-    // will return nullopt
-    static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
-    {
-        if(input.standard())
-            return shape{input.type(), rdims};
-        const auto& idims    = input.lens();
-        const auto& istrides = input.strides();
-        std::vector<std::size_t> rstrides;
-        std::size_t i = 0;
-        std::size_t r = 0;
-        while(i < idims.size() and r < rdims.size())
-        {
-            auto idim = idims[i];
-            auto rdim = rdims[r];
-            if(rdim == idim)
-            {
-                rstrides.push_back(istrides[i]);
-            }
-            // squeeze
-            else if(rdim > idim)
-            {
-                auto start = idims.begin() + i;
-                auto it    = compute_end_dim(start, idims.end(), rdim);
-                if(it == start)
-                    return nullopt;
-                auto n = it - start;
-                assert((i + n) <= istrides.size());
-                if(not can_strides_merge(
-                       start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
-                    return nullopt;
-                i += n;
-                rstrides.push_back(istrides[i]);
-            }
-            // unsqueeze
-            else // if(rdim < idim)
-            {
-                auto start = rdims.begin() + i;
-                auto it    = compute_end_dim(start, rdims.end(), idim);
-                if(it == start)
-                    return nullopt;
-                auto n = it - start;
-                assert((r + n) <= rdims.size());
-                auto stride = istrides[i] * idim;
-                std::for_each(start, it + 1, [&](auto dim) {
-                    stride /= dim;
-                    rstrides.push_back(stride);
-                });
-                r += n;
-            }
-            i++;
-            r++;
-        }
-        // Handle trailing 1s
-        if(rstrides.size() < rdims.size() and not rstrides.empty())
-        {
-            auto stride = rstrides.back();
-            for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
-            {
-                if(d != 1)
-                    return nullopt;
-                rstrides.push_back(stride);
-            }
-        }
-        if(rdims.size() != rstrides.size())
-            return nullopt;
-        return shape{input.type(), rdims, rstrides};
-    }
     shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
     {
         check_shapes{inputs, *this}.has(1);
@@ -232,17 +144,7 @@ struct reshape
             }
         }
-        auto s = reshape_dims(inputs.front(), rdims);
-        if(not s.has_value())
-            MIGRAPHX_THROW("Reshape on axis that is not packed.");
-        if(s->elements() != inputs.front().elements())
-            MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
-                           std::to_string(s->elements()) + " elements whereas the input has " +
-                           std::to_string(inputs.front().elements()));
-        assert(s->bytes() == inputs.front().bytes());
-        return *s;
+        return shape{inputs.front().type(), rdims};
     }
     shape compute_shape(std::vector<shape> inputs) const
...
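
Editorial aside (not part of the commit): a hedged worked example of the case the removed reshape_dims handled by returning nullopt. A packed 3x4 tensor transposed to lens {4, 3} has strides {1, 4}; no choice of strides makes that memory readable as a flat {12} tensor, so the old static_compute_shape would throw "Reshape on axis that is not packed.", while after this change the required copy is expected to come from a contiguous op in front of reshape_lazy. The snippet uses only migraphx::shape.

    // Standalone illustration using only the existing migraphx::shape API.
    #include <migraphx/shape.hpp>
    #include <iostream>

    int main()
    {
        // lens {4, 3} with strides {1, 4}: the transpose of a packed 3x4 tensor.
        migraphx::shape transposed{migraphx::shape::float_type, {4, 3}, {1, 4}};
        std::cout << std::boolalpha;
        std::cout << "standard: " << transposed.standard() << "\n"; // false
        std::cout << "elements: " << transposed.elements() << "\n"; // 12
        // The transposed iteration order visits offsets 0,4,8,1,5,9,... which is
        // not an arithmetic progression, so an aliasing reshape to {12} is
        // impossible and a contiguous copy must happen first.
        return 0;
    }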