Commit 0f3dcb50 authored by Paul's avatar Paul
Browse files

Formatting

parent 5b21a77f
...@@ -12,22 +12,22 @@ struct check_shapes ...@@ -12,22 +12,22 @@ struct check_shapes
{ {
const std::vector<shape>* shapes; const std::vector<shape>* shapes;
check_shapes(const std::vector<shape>& s) check_shapes(const std::vector<shape>& s) : shapes(&s) {}
: shapes(&s)
{}
const check_shapes& has(std::size_t n) const const check_shapes& has(std::size_t n) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
if(shapes->size() != n) if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " + std::to_string(shapes->size())); RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
std::to_string(shapes->size()));
return *this; return *this;
} }
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
if(!shapes->empty()) { if(!shapes->empty())
{
if(shapes->front().lens().size() != n) if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported"); RTG_THROW("Only " + std::to_string(n) + "d supported");
} }
...@@ -55,19 +55,17 @@ struct check_shapes ...@@ -55,19 +55,17 @@ struct check_shapes
return *this; return *this;
} }
template<class F> template <class F>
bool same(F f) const bool same(F f) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
if(shapes->empty()) if(shapes->empty())
return true; return true;
auto&& key = f(shapes->front()); auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return this->all_of([&](const shape& s) { return f(s) == key; });
return f(s) == key;
});
} }
template<class Predicate> template <class Predicate>
bool all_of(Predicate p) const bool all_of(Predicate p) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
...@@ -92,7 +90,7 @@ struct convolution ...@@ -92,7 +90,7 @@ struct convolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
auto t = input.type(); auto t = input.type();
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
...@@ -137,7 +135,7 @@ struct pooling ...@@ -137,7 +135,7 @@ struct pooling
check_shapes{inputs}.has(1).only_dims(4); check_shapes{inputs}.has(1).only_dims(4);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto t = input.type(); auto t = input.type();
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment