Commit b6099eeb authored by charlie's avatar charlie
Browse files

Add more dynamic shape support

parent fdf09748
...@@ -117,7 +117,7 @@ struct check_shapes ...@@ -117,7 +117,7 @@ struct check_shapes
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() != n) if(begin->max_lens().size() != n)
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported"); MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
...@@ -134,7 +134,7 @@ struct check_shapes ...@@ -134,7 +134,7 @@ struct check_shapes
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() > n) if(begin->max_lens().size() > n)
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -152,7 +152,7 @@ struct check_shapes ...@@ -152,7 +152,7 @@ struct check_shapes
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() < n) if(begin->max_lens().size() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -184,8 +184,11 @@ struct check_shapes ...@@ -184,8 +184,11 @@ struct check_shapes
*/ */
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.lens(); })) if(!this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match"); MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(!this->same([](const shape& s) { return s.min_lens(); }))
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
return *this; return *this;
} }
...@@ -194,7 +197,7 @@ struct check_shapes ...@@ -194,7 +197,7 @@ struct check_shapes
*/ */
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.lens().size(); })) if(!this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment