Commit 2097978e authored by Paul's avatar Paul
Browse files

Improve checking of shapes

parent cfe1010f
......@@ -18,6 +18,12 @@ struct check_shapes
{
}
template<class Op>
check_shapes(const shape* b, const shape* e, const Op& op)
: begin(b), end(e), name(op.name())
{
}
check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
template <class Op>
......@@ -119,6 +125,13 @@ struct check_shapes
return *this;
}
const check_shapes& elements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Wrong number of elements");
return *this;
}
template <class F>
bool same(F f) const
{
......
......@@ -56,6 +56,8 @@ struct batch_norm_inference
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(5);
check_shapes{inputs.data(), inputs.data()+1, *this}.only_dims(4);
check_shapes{inputs.data()+1, inputs.data()+inputs.size(), *this}.same_shape().elements(inputs.front().lens()[1]);
return inputs.front();
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment