Unverified Commit c51cf531 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic check_shapes (#1295)

Dynamic shape handling in shape object
parent 4d59b7c7
...@@ -38,20 +38,34 @@ struct check_shapes ...@@ -38,20 +38,34 @@ struct check_shapes
const shape* begin; const shape* begin;
const shape* end; const shape* end;
const std::string name; const std::string name;
const bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n) check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d)
{ {
check_dynamic();
} }
template <class Op> template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name()) check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic();
} }
template <class Op> template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()) : begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic();
}
void check_dynamic() const
{
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
{
MIGRAPHX_THROW(prefix() + "Dynamic shapes not supported");
}
} }
std::string prefix() const std::string prefix() const
...@@ -92,44 +106,62 @@ struct check_shapes ...@@ -92,44 +106,62 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check that the first shape has exactly n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() != n) if(begin->max_lens().size() != n)
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported"); MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
} }
/*!
* Check that the first shape has a maximum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& max_ndims(std::size_t n) const const check_shapes& max_ndims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() > n) if(begin->max_lens().size() > n)
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
" dimensions"); " dimensions");
} }
return *this; return *this;
} }
/*!
* Check that the first shape has a minimum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& min_ndims(std::size_t n) const const check_shapes& min_ndims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->lens().size() < n) if(begin->max_lens().size() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions"); " dimensions");
} }
return *this; return *this;
} }
/*!
* Check all shapes have the same shape.
*/
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(!this->same([](const shape& s) { return s; }))
...@@ -137,6 +169,9 @@ struct check_shapes ...@@ -137,6 +169,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same type.
*/
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(!this->same([](const shape& s) { return s.type(); }))
...@@ -144,20 +179,32 @@ struct check_shapes ...@@ -144,20 +179,32 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same lens.
*/
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.lens(); })) if(!this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match"); MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(!this->same([](const shape& s) { return s.min_lens(); }))
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
return *this; return *this;
} }
/*!
* Check all shapes have the same number of dimensions.
*/
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.lens().size(); })) if(!this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
/*!
* Check all shapes are standard.
*/
const check_shapes& standard() const const check_shapes& standard() const
{ {
if(!this->all_of([](const shape& s) { return s.standard(); })) if(!this->all_of([](const shape& s) { return s.standard(); }))
...@@ -165,6 +212,9 @@ struct check_shapes ...@@ -165,6 +212,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are standard or scalar.
*/
const check_shapes& standard_or_scalar() const const check_shapes& standard_or_scalar() const
{ {
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); })) if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
...@@ -172,6 +222,9 @@ struct check_shapes ...@@ -172,6 +222,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed.
*/
const check_shapes& packed() const const check_shapes& packed() const
{ {
if(!this->all_of([](const shape& s) { return s.packed(); })) if(!this->all_of([](const shape& s) { return s.packed(); }))
...@@ -179,6 +232,9 @@ struct check_shapes ...@@ -179,6 +232,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed or broadcasted.
*/
const check_shapes& packed_or_broadcasted() const const check_shapes& packed_or_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); })) if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
...@@ -186,6 +242,9 @@ struct check_shapes ...@@ -186,6 +242,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are tuples.
*/
const check_shapes& tuple_type() const const check_shapes& tuple_type() const
{ {
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; })) if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
...@@ -193,6 +252,9 @@ struct check_shapes ...@@ -193,6 +252,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are not transposed.
*/
const check_shapes& not_transposed() const const check_shapes& not_transposed() const
{ {
if(!this->all_of([](const shape& s) { return not s.transposed(); })) if(!this->all_of([](const shape& s) { return not s.transposed(); }))
...@@ -200,6 +262,9 @@ struct check_shapes ...@@ -200,6 +262,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are not broadcasted.
*/
const check_shapes& not_broadcasted() const const check_shapes& not_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return not s.broadcasted(); })) if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
...@@ -207,6 +272,10 @@ struct check_shapes ...@@ -207,6 +272,10 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same n elements.
* \param n number of elements
*/
const check_shapes& elements(std::size_t n) const const check_shapes& elements(std::size_t n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
...@@ -214,6 +283,9 @@ struct check_shapes ...@@ -214,6 +283,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check the batches of all the shapes do not have transposed strides.
*/
const check_shapes& batch_not_transposed() const const check_shapes& batch_not_transposed() const
{ {
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); })) if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
...@@ -242,6 +314,16 @@ struct check_shapes ...@@ -242,6 +314,16 @@ struct check_shapes
return std::all_of(begin, end, p); return std::all_of(begin, end, p);
} }
template <class Predicate>
bool any_of(Predicate p) const
{
if(begin == end)
return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p);
}
const shape* get(long i) const const shape* get(long i) const
{ {
if(i >= size()) if(i >= size())
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "test.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
/*!
* Tests for check_shapes object handling dynamic shapes
*/
using migraphx::shape;
bool create_shapes(bool dynamic_allowed)
{
try
{
shape a{shape::int64_type, {3}};
shape b{shape::float_type, {{3, 6, 0}, {4, 4, 0}}};
auto op = migraphx::make_op("add");
migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2);
return true;
}
catch(...)
{
return false;
}
}
TEST_CASE(allow_dynamic_shape) { EXPECT(create_shapes(true)); }
TEST_CASE(fail_dynamic_shape) { EXPECT(!create_shapes(false)); }
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment