Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1c0b2a4a
Unverified
Commit
1c0b2a4a
authored
Jul 07, 2022
by
Charlie Lin
Committed by
GitHub
Jul 07, 2022
Browse files
Dyn shape update (#1199)
Initial sketch for changes to shape to handle dynamic dimensions
parent
bd503d89
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
390 additions
and
16 deletions
+390
-16
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+5
-0
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+6
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+71
-1
src/shape.cpp
src/shape.cpp
+194
-13
test/op_shape_test.cpp
test/op_shape_test.cpp
+2
-1
test/shape_test.cpp
test/shape_test.cpp
+112
-1
No files found.
src/include/migraphx/check_shapes.hpp
View file @
1c0b2a4a
...
...
@@ -71,6 +71,11 @@ struct check_shapes
return
end
-
begin
;
}
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template
<
class
...
Ts
>
const
check_shapes
&
has
(
Ts
...
ns
)
const
{
...
...
src/include/migraphx/permutation.hpp
View file @
1c0b2a4a
...
...
@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return
result
;
}
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
*/
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
*/
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
...
...
src/include/migraphx/shape.hpp
View file @
1c0b2a4a
...
...
@@ -82,6 +82,23 @@ struct shape
{
};
struct
dynamic_dimension
{
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
opt
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
);
bool
is_fixed
()
const
;
bool
has_optimal
()
const
;
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
};
static
const
std
::
vector
<
type_t
>&
types
();
static
std
::
string
name
(
type_t
t
);
...
...
@@ -92,6 +109,12 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape
(
type_t
t
,
std
::
initializer_list
<
std
::
size_t
>
d
);
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
...
...
@@ -112,10 +135,44 @@ struct shape
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
/*!
* Return the number of elements in the tensor.
*/
std
::
size_t
elements
()
const
;
/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/
std
::
size_t
bytes
()
const
;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std
::
size_t
type_size
()
const
;
const
std
::
vector
<
dynamic_dimension
>&
dyn_dims
()
const
;
/*!
* Minimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
min_lens
()
const
;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
max_lens
()
const
;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
;
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
/// Map multiple indices to space index
...
...
@@ -136,19 +193,27 @@ struct shape
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed with no padding
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool
transposed
()
const
;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool
broadcasted
()
const
;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool
standard
()
const
;
/// Returns true if all strides are equal to 0 (scalar tensor)
bool
scalar
()
const
;
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
...
@@ -252,6 +317,11 @@ struct shape
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
/*!
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std
::
size_t
element_space
()
const
;
private:
...
...
src/shape.cpp
View file @
1c0b2a4a
...
...
@@ -26,6 +26,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
...
...
@@ -65,13 +66,21 @@ struct shape_impl
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
m_type
(
t
),
m_dyn_dims
(
std
::
move
(
dims
))
{
}
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape
::
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
=
{};
std
::
vector
<
std
::
size_t
>
m_strides
=
{};
std
::
vector
<
shape
>
m_shapes
=
{};
bool
m_standard
=
false
;
std
::
vector
<
shape
::
dynamic_dimension
>
m_dyn_dims
=
{};
void
calculate_strides
()
{
m_strides
.
clear
();
...
...
@@ -87,6 +96,12 @@ struct shape_impl
std
::
size_t
element_space
()
const
{
if
(
not
m_dyn_dims
.
empty
())
{
auto
maxes
=
max_lens
();
return
std
::
accumulate
(
maxes
.
begin
(),
maxes
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
());
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
...
...
@@ -101,6 +116,11 @@ struct shape_impl
std
::
size_t
elements
()
const
{
if
(
not
m_dyn_dims
.
empty
())
{
MIGRAPHX_THROW
(
"SHAPE: elements() called on dynamic shape"
);
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
...
...
@@ -108,6 +128,35 @@ struct shape_impl
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
std
::
vector
<
std
::
size_t
>
min_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
min
;
});
return
ret
;
}
std
::
vector
<
std
::
size_t
>
max_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
max
;
});
return
ret
;
}
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
opt
;
});
return
ret
;
}
// Does the shape skip over elements?
bool
skips
()
const
{
...
...
@@ -165,6 +214,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{
}
shape
::
shape
(
type_t
t
,
std
::
initializer_list
<
std
::
size_t
>
d
)
:
shape
::
shape
(
t
,
std
::
vector
<
std
::
size_t
>
{
d
.
begin
(),
d
.
end
()})
{
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
dims
)))
{
}
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
...
...
@@ -180,9 +239,13 @@ shape shape::from_permutation(type_t t,
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
bytes
()
const
{
if
(
this
->
sub_shapes
().
empty
())
...
...
@@ -199,6 +262,7 @@ std::size_t shape::bytes() const
[
&
](
auto
x
,
auto
y
)
{
return
x
+
y
.
bytes
();
});
}
}
std
::
size_t
shape
::
type_size
()
const
{
std
::
size_t
n
=
0
;
...
...
@@ -206,20 +270,35 @@ std::size_t shape::type_size() const
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
;
}
std
::
size_t
shape
::
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: index() called on dynamic shape"
);
}
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
}
std
::
size_t
shape
::
index
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: index() called on dynamic shape"
);
}
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
l
.
begin
(),
l
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
});
}
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: index() called on dynamic shape"
);
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
standard
())
return
i
;
...
...
@@ -267,12 +346,20 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool
shape
::
packed
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
return
this
->
sub_shapes
().
empty
()
and
not
impl
->
skips
()
and
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
transposed
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
if
(
this
->
broadcasted
())
{
// TODO: Use a filter_iterator instead
...
...
@@ -292,6 +379,10 @@ bool shape::transposed() const
bool
shape
::
broadcasted
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
any_of
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
[](
auto
x
)
{
return
x
==
0
;
});
...
...
@@ -299,6 +390,10 @@ bool shape::broadcasted() const
bool
shape
::
scalar
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
// if any stride > 0, then accumulate will return false
return
this
->
sub_shapes
().
empty
()
and
...
...
@@ -317,6 +412,10 @@ shape shape::normalize_standard() const
shape
shape
::
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: with_lens() called on dynamic shape"
);
}
assert
(
l
.
size
()
==
this
->
lens
().
size
());
auto
perm
=
find_permutation
(
*
this
);
return
shape
::
from_permutation
(
t
,
l
,
perm
);
...
...
@@ -324,6 +423,10 @@ shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
shape
shape
::
with_lens
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
{
if
(
this
->
dynamic
())
{
MIGRAPHX_THROW
(
"SHAPE: with_lens() called on dynamic shape"
);
}
return
this
->
with_lens
(
this
->
type
(),
l
);
}
...
...
@@ -338,20 +441,80 @@ std::size_t shape::element_space() const { return impl->element_space(); }
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
bool
shape
::
dynamic
()
const
{
return
not
impl
->
m_dyn_dims
.
empty
();
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
return
impl
->
m_dyn_dims
;
}
std
::
vector
<
std
::
size_t
>
shape
::
min_lens
()
const
{
return
this
->
dynamic
()
?
impl
->
min_lens
()
:
this
->
lens
();
}
std
::
vector
<
std
::
size_t
>
shape
::
max_lens
()
const
{
return
this
->
dynamic
()
?
impl
->
max_lens
()
:
this
->
lens
();
}
std
::
vector
<
std
::
size_t
>
shape
::
opt_lens
()
const
{
return
this
->
dynamic
()
?
impl
->
opt_lens
()
:
this
->
lens
();
}
bool
shape
::
dynamic_dimension
::
is_fixed
()
const
{
return
this
->
min
==
this
->
max
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
template
<
class
Self
,
class
F
>
auto
shape
::
dynamic_dimension
::
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
,
"opt"
));
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
x
.
opt
==
y
.
opt
);
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
!
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
{
os
<<
"["
<<
x
.
min
<<
", "
<<
x
.
max
<<
", "
<<
x
.
opt
<<
"]"
;
return
os
;
}
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
{
return
x
.
impl
==
y
.
impl
or
(
x
.
type
()
==
y
.
type
()
and
x
.
lens
()
==
y
.
lens
()
and
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
if
(
x
.
dynamic
()
and
y
.
dynamic
())
{
return
x
.
impl
==
y
.
impl
or
(
x
.
type
()
==
y
.
type
()
and
x
.
dyn_dims
()
==
y
.
dyn_dims
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
}
return
x
.
impl
==
y
.
impl
or
(
x
.
dynamic
()
==
y
.
dynamic
()
and
x
.
type
()
==
y
.
type
()
and
x
.
lens
()
==
y
.
lens
()
and
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
{
if
(
x
.
sub_shapes
().
empty
())
{
os
<<
x
.
type_string
()
<<
", "
;
os
<<
"{"
<<
to_string_range
(
x
.
lens
())
<<
"}, "
;
os
<<
"{"
<<
to_string_range
(
x
.
strides
())
<<
"}"
;
if
(
x
.
dynamic
())
{
os
<<
"dynamic, "
;
os
<<
x
.
type_string
()
<<
", "
;
os
<<
"{"
<<
to_string_range
(
x
.
dyn_dims
())
<<
"}"
;
}
else
{
os
<<
x
.
type_string
()
<<
", "
;
os
<<
"{"
<<
to_string_range
(
x
.
lens
())
<<
"}, "
;
os
<<
"{"
<<
to_string_range
(
x
.
strides
())
<<
"}"
;
}
}
else
{
...
...
@@ -375,12 +538,14 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void
migraphx_to_value
(
value
&
v
,
const
shape
&
s
)
{
value
result
;
result
[
"type"
]
=
migraphx
::
to_value
(
s
.
type_string
());
result
[
"lens"
]
=
migraphx
::
to_value
(
s
.
lens
());
result
[
"strides"
]
=
migraphx
::
to_value
(
s
.
strides
());
result
[
"sub_shapes"
]
=
migraphx
::
to_value
(
s
.
sub_shapes
());
v
=
result
;
result
[
"type"
]
=
migraphx
::
to_value
(
s
.
type_string
());
result
[
"lens"
]
=
migraphx
::
to_value
(
s
.
lens
());
result
[
"strides"
]
=
migraphx
::
to_value
(
s
.
strides
());
result
[
"sub_shapes"
]
=
migraphx
::
to_value
(
s
.
sub_shapes
());
result
[
"dynamic_dimensions"
]
=
migraphx
::
to_value
(
s
.
dyn_dims
());
v
=
result
;
}
void
migraphx_from_value
(
const
value
&
v
,
shape
&
s
)
{
auto
t
=
v
.
at
(
"type"
).
get_string
();
...
...
@@ -390,9 +555,25 @@ void migraphx_from_value(const value& v, shape& s)
}
else
{
s
=
shape
{
shape
::
parse_type
(
t
),
v
.
at
(
"lens"
).
to_vector
<
std
::
size_t
>
(),
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
()};
if
(
v
.
at
(
"dynamic_dimensions"
).
empty
())
{
s
=
shape
{
shape
::
parse_type
(
t
),
v
.
at
(
"lens"
).
to_vector
<
std
::
size_t
>
(),
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
()};
}
else
{
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
migraphx
::
value
x
)
{
auto
x_min
=
x
.
at
(
"min"
).
template
to
<
size_t
>();
auto
x_max
=
x
.
at
(
"max"
).
template
to
<
size_t
>();
auto
x_opt
=
x
.
at
(
"opt"
).
template
to
<
size_t
>();
return
shape
::
dynamic_dimension
{
x_min
,
x_max
,
x_opt
};
});
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
}
}
}
...
...
test/op_shape_test.cpp
View file @
1c0b2a4a
...
...
@@ -981,7 +981,8 @@ TEST_CASE(multibroadcast)
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
3
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{}};
std
::
vector
<
std
::
size_t
>
empt
=
{};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
empt
};
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
input
);
}
{
...
...
test/shape_test.cpp
View file @
1c0b2a4a
...
...
@@ -38,7 +38,6 @@ TEST_CASE(test_shape_default)
EXPECT
(
s
.
elements
()
==
0
);
EXPECT
(
s
.
bytes
()
==
0
);
}
TEST_CASE
(
test_shape_assign
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
...
...
@@ -65,6 +64,118 @@ TEST_CASE(test_shape_standard)
EXPECT
(
not
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_min_max_opt
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
3
},
{
6
,
3
,
1
}};
EXPECT
(
s
.
min_lens
()
==
s
.
lens
());
EXPECT
(
s
.
max_lens
()
==
s
.
lens
());
EXPECT
(
s
.
opt_lens
()
==
s
.
lens
());
}
TEST_CASE
(
test_shape_dynamic_fixed
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
2
,
2
,
0
},
{
2
,
2
,
0
},
{
3
,
3
,
0
}}};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
dynamic
());
EXPECT
(
s
.
dyn_dims
().
size
()
==
3
);
EXPECT
(
s
.
dyn_dims
().
at
(
0
).
is_fixed
());
EXPECT
(
not
s
.
dyn_dims
().
at
(
0
).
has_optimal
());
EXPECT
(
s
.
min_lens
()
==
std
::
vector
<
std
::
size_t
>
{
2
,
2
,
3
});
EXPECT
(
s
.
max_lens
()
==
std
::
vector
<
std
::
size_t
>
{
2
,
2
,
3
});
EXPECT
(
s
.
opt_lens
()
==
std
::
vector
<
std
::
size_t
>
{
0
,
0
,
0
});
EXPECT
(
s
.
bytes
()
==
2
*
2
*
3
*
sizeof
(
float
));
}
TEST_CASE
(
test_shape_dynamic_not_fixed
)
{
using
migraphx
::
shape
;
std
::
vector
<
shape
::
dynamic_dimension
>
dims
=
{};
dims
.
push_back
(
shape
::
dynamic_dimension
{
2
,
5
,
2
});
dims
.
push_back
(
shape
::
dynamic_dimension
{
2
,
8
,
0
});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
dims
};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
dynamic
());
EXPECT
(
s
.
dyn_dims
().
size
()
==
2
);
EXPECT
(
not
s
.
dyn_dims
().
at
(
0
).
is_fixed
());
EXPECT
(
s
.
dyn_dims
().
at
(
0
).
has_optimal
());
EXPECT
(
s
.
min_lens
()
==
std
::
vector
<
std
::
size_t
>
{
2
,
2
});
EXPECT
(
s
.
max_lens
()
==
std
::
vector
<
std
::
size_t
>
{
5
,
8
});
EXPECT
(
s
.
opt_lens
()
==
std
::
vector
<
std
::
size_t
>
{
2
,
0
});
EXPECT
(
s
.
bytes
()
==
5
*
8
*
sizeof
(
float
));
}
TEST_CASE
(
test_shape_dynamic_compares
)
{
using
migraphx
::
shape
;
auto
a
=
shape
::
dynamic_dimension
{
2
,
5
,
2
};
auto
b
=
a
;
auto
c
=
shape
::
dynamic_dimension
{
2
,
5
,
2
};
auto
d
=
shape
::
dynamic_dimension
{
3
,
8
,
4
};
EXPECT
(
a
==
b
);
EXPECT
(
a
==
c
);
EXPECT
(
a
!=
d
);
migraphx
::
shape
s0
{
shape
::
float_type
,
{
a
,
d
}};
migraphx
::
shape
s1
=
s0
;
migraphx
::
shape
s2
{
shape
::
float_type
,
{
a
,
d
}};
migraphx
::
shape
s3
{
shape
::
int32_type
,
{
a
}};
EXPECT
(
s0
==
s1
);
EXPECT
(
s0
==
s2
);
EXPECT
(
s0
!=
s3
);
std
::
stringstream
ss0
;
std
::
stringstream
ss1
;
std
::
stringstream
ss3
;
ss0
<<
s0
;
ss1
<<
s1
;
ss3
<<
s3
;
EXPECT
(
ss0
.
str
()
==
ss1
.
str
());
EXPECT
(
ss0
.
str
()
!=
ss3
.
str
());
}
TEST_CASE
(
test_shape_dynamic_errors
)
{
using
migraphx
::
shape
;
std
::
vector
<
shape
::
dynamic_dimension
>
dims
=
{};
dims
.
push_back
(
shape
::
dynamic_dimension
{
2
,
5
,
2
});
dims
.
push_back
(
shape
::
dynamic_dimension
{
2
,
8
,
0
});
migraphx
::
shape
s
{
shape
::
float_type
,
dims
};
EXPECT
(
test
::
throws
([
&
]
{
s
.
elements
();
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
index
({
0
,
1
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
index
(
1
);
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
index
(
std
::
vector
<
std
::
size_t
>
{
0
,
1
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
with_lens
({
3
,
5
});
}));
EXPECT
(
test
::
throws
([
&
]
{
s
.
with_lens
(
shape
::
float_type
,
{
3
,
5
});
}));
}
TEST_CASE
(
test_shape_dynamic_serialize
)
{
using
migraphx
::
shape
;
std
::
vector
<
shape
::
dynamic_dimension
>
dims1
=
{};
dims1
.
push_back
(
shape
::
dynamic_dimension
{
2
,
5
,
2
});
dims1
.
push_back
(
shape
::
dynamic_dimension
{
2
,
8
,
0
});
migraphx
::
shape
s1
{
shape
::
float_type
,
dims1
};
auto
v1
=
migraphx
::
to_value
(
s1
);
std
::
vector
<
shape
::
dynamic_dimension
>
dims2
=
{};
dims2
.
push_back
(
shape
::
dynamic_dimension
{
2
,
5
,
2
});
migraphx
::
shape
s2
{
shape
::
uint64_type
,
dims2
};
auto
v2
=
migraphx
::
to_value
(
s2
);
EXPECT
(
v1
!=
v2
);
auto
s3
=
migraphx
::
from_value
<
shape
>
(
v1
);
EXPECT
(
s3
==
s1
);
auto
s4
=
migraphx
::
from_value
<
shape
>
(
v2
);
EXPECT
(
s4
==
s2
);
EXPECT
(
s3
!=
s4
);
}
TEST_CASE
(
test_shape_packed
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment