Commit 5b21a77f authored by Paul's avatar Paul
Browse files

Add check_shapes helper class

parent e15f5d2a
......@@ -8,6 +8,73 @@
namespace rtg {
struct check_shapes
{
const std::vector<shape>* shapes;
check_shapes(const std::vector<shape>& s)
: shapes(&s)
{}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " + std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty()) {
if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match");
return *this;
}
template<class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) {
return f(s) == key;
});
}
template<class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
struct not_computable
{
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
......@@ -21,17 +88,10 @@ struct convolution
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.size() != 2)
RTG_THROW("Wrong number of arguments");
check_shapes{inputs}.has(2).same_type().same_dims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
if(input.type() != weights.type())
RTG_THROW("Type doesn't match");
if(input.lens().size() != weights.lens().size())
RTG_THROW("Dimensions don't match");
if(input.lens().size() != 4)
RTG_THROW("Only 4d convolution supported");
auto t = input.type();
return {t,
{
......@@ -74,12 +134,9 @@ struct pooling
std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
RTG_THROW("Wrong number of arguments");
const shape& input = inputs.at(0);
if(input.lens().size() != 4)
RTG_THROW("Only 4d pooling supported");
check_shapes{inputs}.has(1).only_dims(4);
const shape& input = inputs.at(0);
auto t = input.type();
return {t,
{
......@@ -117,8 +174,7 @@ struct activation
std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
RTG_THROW("Wrong number of arguments");
check_shapes{inputs}.has(1);
return inputs.front();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment