Commit 3a1980fc authored by Paul's avatar Paul
Browse files

Check reshape sizes

parent 49e7cd87
......@@ -287,8 +287,7 @@ struct reshape
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
RTG_THROW("Wrong number of arguments");
check_shapes{inputs}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
......@@ -301,7 +300,10 @@ struct reshape
rdims.pop_back();
std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
}
return {inputs.front().type(), rdims};
shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements())
RTG_THROW("Wrong number of elements for reshape");
return s;
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment