"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "f717f6bf7ef4f9b6e95932744ac89e07ed0f0c52"
Commit a48f046e authored by Paul's avatar Paul
Browse files

Move check_shapes to a seperate class

parent 415476ae
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraph/shape.hpp>
#include <algorithm>
namespace migraph {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
MIGRAPH_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
MIGRAPH_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
template <class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return f(s) == key; });
}
template <class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
} // namespace migraph
#endif
...@@ -3,98 +3,13 @@ ...@@ -3,98 +3,13 @@
#include <array> #include <array>
#include <migraph/operation.hpp> #include <migraph/operation.hpp>
#include <migraph/check_shapes.hpp>
#include <migraph/stringutils.hpp> #include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp> #include <migraph/streamutils.hpp>
#include <cmath> #include <cmath>
namespace migraph { namespace migraph {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
MIGRAPH_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
MIGRAPH_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
template <class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return f(s) == key; });
}
template <class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
struct not_computable struct not_computable
{ {
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, shape, std::vector<argument>) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment