Commit bb13878f authored by Scott Thornton's avatar Scott Thornton
Browse files

Added same_ndims() in check_shapes

parent 81b0631e
......@@ -56,6 +56,13 @@ struct check_shapes
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
RTG_THROW("Dimensions do not match");
return *this;
}
template <class F>
bool same(F f) const
{
......@@ -87,7 +94,7 @@ struct convolution
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims().only_dims(4);
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment