Commit 5c970b52 authored by charlie's avatar charlie
Browse files

check_shapes object checks for allowing dynamic shapes

parent 5d236dfc
......@@ -38,6 +38,7 @@ struct check_shapes
const shape* begin;
const shape* end;
const std::string name;
bool dynamic_allowed = false;
check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
{
......@@ -54,6 +55,15 @@ struct check_shapes
{
}
~check_shapes()
{
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
{
std::cerr << prefix() << "Dynamic shapes not supported" << std::endl;
std::abort();
}
}
std::string prefix() const
{
if(name.empty())
......@@ -92,6 +102,11 @@ struct check_shapes
return *this;
}
/*!
* Check that the first shape has exactly n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& only_dims(std::size_t n) const
{
assert(begin != nullptr);
......@@ -104,6 +119,11 @@ struct check_shapes
return *this;
}
/*!
* Check that the first shape has a maximum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& max_ndims(std::size_t n) const
{
assert(begin != nullptr);
......@@ -117,6 +137,11 @@ struct check_shapes
return *this;
}
/*!
* Check that the first shape has a minimum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& min_ndims(std::size_t n) const
{
assert(begin != nullptr);
......@@ -130,6 +155,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same shape.
*/
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
......@@ -137,6 +165,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same type.
*/
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
......@@ -144,6 +175,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same lens.
*/
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
......@@ -151,6 +185,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same number of dimensions.
*/
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
......@@ -158,6 +195,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are standard.
*/
const check_shapes& standard() const
{
if(!this->all_of([](const shape& s) { return s.standard(); }))
......@@ -165,6 +205,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are standard or scalar.
*/
const check_shapes& standard_or_scalar() const
{
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
......@@ -172,6 +215,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are packed.
*/
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
......@@ -179,6 +225,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are packed or broadcasted.
*/
const check_shapes& packed_or_broadcasted() const
{
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
......@@ -186,6 +235,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are tuples.
*/
const check_shapes& tuple_type() const
{
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
......@@ -193,6 +245,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are not transposed.
*/
const check_shapes& not_transposed() const
{
if(!this->all_of([](const shape& s) { return not s.transposed(); }))
......@@ -200,6 +255,9 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are not broadcasted.
*/
const check_shapes& not_broadcasted() const
{
if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
......@@ -207,6 +265,10 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes have the same n elements.
* \param n number of elements
*/
const check_shapes& elements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
......@@ -214,6 +276,9 @@ struct check_shapes
return *this;
}
/*!
* Check the batches of all the shapes do not have transposed strides.
*/
const check_shapes& batch_not_transposed() const
{
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
......@@ -221,6 +286,15 @@ struct check_shapes
return *this;
}
/*!
* Denotes that the shapes can be dynamic for the operator.
*/
const check_shapes& allow_dynamic()
{
dynamic_allowed = true;
return *this;
}
template <class F>
bool same(F f) const
{
......@@ -242,6 +316,16 @@ struct check_shapes
return std::all_of(begin, end, p);
}
template <class Predicate>
bool any_of(Predicate p) const
{
if(begin == end)
return true;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p);
}
const shape* get(long i) const
{
if(i >= size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment