Commit 5b21a77f authored by Paul's avatar Paul
Browse files

Add check_shapes helper class

parent e15f5d2a
...@@ -8,6 +8,73 @@ ...@@ -8,6 +8,73 @@
namespace rtg { namespace rtg {
struct check_shapes
{
const std::vector<shape>* shapes;
check_shapes(const std::vector<shape>& s)
: shapes(&s)
{}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " + std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty()) {
if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match");
return *this;
}
template<class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) {
return f(s) == key;
});
}
template<class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
struct not_computable struct not_computable
{ {
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
...@@ -21,17 +88,10 @@ struct convolution ...@@ -21,17 +88,10 @@ struct convolution
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.size() != 2) check_shapes{inputs}.has(2).same_type().same_dims().only_dims(4);
RTG_THROW("Wrong number of arguments");
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
if(input.type() != weights.type())
RTG_THROW("Type doesn't match");
if(input.lens().size() != weights.lens().size())
RTG_THROW("Dimensions don't match");
if(input.lens().size() != 4)
RTG_THROW("Only 4d convolution supported");
auto t = input.type(); auto t = input.type();
return {t, return {t,
{ {
...@@ -74,12 +134,9 @@ struct pooling ...@@ -74,12 +134,9 @@ struct pooling
std::string name() const { return "pooling"; } std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) check_shapes{inputs}.has(1).only_dims(4);
RTG_THROW("Wrong number of arguments");
const shape& input = inputs.at(0);
if(input.lens().size() != 4)
RTG_THROW("Only 4d pooling supported");
const shape& input = inputs.at(0);
auto t = input.type(); auto t = input.type();
return {t, return {t,
{ {
...@@ -117,8 +174,7 @@ struct activation ...@@ -117,8 +174,7 @@ struct activation
std::string name() const { return "activation"; } std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) check_shapes{inputs}.has(1);
RTG_THROW("Wrong number of arguments");
return inputs.front(); return inputs.front();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment