Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5c970b52
Commit
5c970b52
authored
Jul 07, 2022
by
charlie
Browse files
check_shapes object checks for allowing dynamic shapes
parent
5d236dfc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
84 additions
and
0 deletions
+84
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+84
-0
No files found.
src/include/migraphx/check_shapes.hpp
View file @
5c970b52
...
...
@@ -38,6 +38,7 @@ struct check_shapes
const
shape
*
begin
;
const
shape
*
end
;
const
std
::
string
name
;
bool
dynamic_allowed
=
false
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
)
:
begin
(
b
),
end
(
e
),
name
(
n
)
{
...
...
@@ -54,6 +55,15 @@ struct check_shapes
{
}
~
check_shapes
()
{
if
(
not
dynamic_allowed
and
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
{
std
::
cerr
<<
prefix
()
<<
"Dynamic shapes not supported"
<<
std
::
endl
;
std
::
abort
();
}
}
std
::
string
prefix
()
const
{
if
(
name
.
empty
())
...
...
@@ -92,6 +102,11 @@ struct check_shapes
return
*
this
;
}
/*!
* Check that the first shape has exactly n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
...
...
@@ -104,6 +119,11 @@ struct check_shapes
return
*
this
;
}
/*!
* Check that the first shape has a maximum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
...
...
@@ -117,6 +137,11 @@ struct check_shapes
return
*
this
;
}
/*!
* Check that the first shape has a minimum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
...
...
@@ -130,6 +155,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same shape.
*/
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
...
...
@@ -137,6 +165,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same type.
*/
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
...
...
@@ -144,6 +175,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same lens.
*/
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
();
}))
...
...
@@ -151,6 +185,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same number of dimensions.
*/
const
check_shapes
&
same_ndims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
().
size
();
}))
...
...
@@ -158,6 +195,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are standard.
*/
const
check_shapes
&
standard
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
...
...
@@ -165,6 +205,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are standard or scalar.
*/
const
check_shapes
&
standard_or_scalar
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
...
...
@@ -172,6 +215,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are packed.
*/
const
check_shapes
&
packed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
...
...
@@ -179,6 +225,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are packed or broadcasted.
*/
const
check_shapes
&
packed_or_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
...
...
@@ -186,6 +235,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are tuples.
*/
const
check_shapes
&
tuple_type
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
...
...
@@ -193,6 +245,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are not transposed.
*/
const
check_shapes
&
not_transposed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
...
...
@@ -200,6 +255,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes are not broadcasted.
*/
const
check_shapes
&
not_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
...
...
@@ -207,6 +265,10 @@ struct check_shapes
return
*
this
;
}
/*!
* Check all shapes have the same n elements.
* \param n number of elements
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
...
...
@@ -214,6 +276,9 @@ struct check_shapes
return
*
this
;
}
/*!
* Check the batches of all the shapes do not have transposed strides.
*/
const
check_shapes
&
batch_not_transposed
()
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
...
...
@@ -221,6 +286,15 @@ struct check_shapes
return
*
this
;
}
/*!
* Denotes that the shapes can be dynamic for the operator.
*/
const
check_shapes
&
allow_dynamic
()
{
dynamic_allowed
=
true
;
return
*
this
;
}
template
<
class
F
>
bool
same
(
F
f
)
const
{
...
...
@@ -242,6 +316,16 @@ struct check_shapes
return
std
::
all_of
(
begin
,
end
,
p
);
}
template
<
class
Predicate
>
bool
any_of
(
Predicate
p
)
const
{
if
(
begin
==
end
)
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
const
shape
*
get
(
long
i
)
const
{
if
(
i
>=
size
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment