Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c0e18e78
Commit
c0e18e78
authored
May 03, 2022
by
charlie
Browse files
Dynamic shape handling in shape object
parent
764273e4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
117 additions
and
1 deletion
+117
-1
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+5
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+21
-0
src/permutation.cpp
src/permutation.cpp
+7
-0
src/shape.cpp
src/shape.cpp
+50
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+1
-1
test/shape_test.cpp
test/shape_test.cpp
+33
-0
No files found.
src/include/migraphx/check_shapes.hpp
View file @
c0e18e78
...
...
@@ -48,6 +48,11 @@ struct check_shapes
return
end
-
begin
;
}
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template
<
class
...
Ts
>
const
check_shapes
&
has
(
Ts
...
ns
)
const
{
...
...
src/include/migraphx/shape.hpp
View file @
c0e18e78
...
...
@@ -59,6 +59,15 @@ struct shape
{
};
struct
dynamic_dimension
{
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
opt
=
0
;
bool
is_fixed
()
const
{
return
min
==
max
;
};
bool
has_optimal
()
const
{
return
opt
!=
0
;
};
};
static
const
std
::
vector
<
type_t
>&
types
();
static
std
::
string
name
(
type_t
t
);
...
...
@@ -69,6 +78,8 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
...
...
@@ -93,6 +104,8 @@ struct shape
std
::
size_t
bytes
()
const
;
std
::
size_t
type_size
()
const
;
const
std
::
vector
<
dynamic_dimension
>&
dyn_dims
()
const
;
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
/// Map multiple indices to space index
...
...
@@ -115,17 +128,24 @@ struct shape
/// Returns true if the shape is packed with no padding
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool
transposed
()
const
;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool
broadcasted
()
const
;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool
standard
()
const
;
/// Returns true if all strides are equal to 0 (scalar tensor)
bool
scalar
()
const
;
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
...
@@ -225,6 +245,7 @@ struct shape
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
/// size of the data buffer
std
::
size_t
element_space
()
const
;
private:
...
...
src/permutation.cpp
View file @
c0e18e78
...
...
@@ -13,11 +13,18 @@ shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
return
{
s
.
type
(),
reorder_dims
(
s
.
lens
(),
permutation
),
reorder_dims
(
s
.
strides
(),
permutation
)};
}
/*!
* Inverts the permutation using the less_than operator
*/
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
)
{
return
sort_permutation
(
permutation
,
std
::
less
<>
{});
}
/*!
* Computes a permutation for the lengths based on decesending stride order.
* Compares the lengths if the strides are the same.
*/
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
)
{
std
::
vector
<
std
::
int64_t
>
result
(
s
.
lens
().
size
());
...
...
src/shape.cpp
View file @
c0e18e78
...
...
@@ -3,6 +3,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
...
...
@@ -45,11 +46,20 @@ struct shape_impl
}
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
m_type
(
t
),
m_dynamic
(
true
),
m_dyn_dims
(
std
::
move
(
dims
))
{
}
shape
::
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
=
{};
std
::
vector
<
std
::
size_t
>
m_strides
=
{};
std
::
vector
<
shape
>
m_shapes
=
{};
bool
m_standard
=
false
;
bool
m_dynamic
=
false
;
std
::
vector
<
shape
::
dynamic_dimension
>
m_dyn_dims
=
{};
void
calculate_strides
()
{
...
...
@@ -66,6 +76,11 @@ struct shape_impl
std
::
size_t
element_space
()
const
{
if
(
m_dynamic
)
{
MIGRAPHX_THROW
(
"SHAPE: element_space() called on dynamic shape"
);
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
...
...
@@ -80,6 +95,11 @@ struct shape_impl
std
::
size_t
elements
()
const
{
if
(
m_dynamic
)
{
MIGRAPHX_THROW
(
"SHAPE: elements() called on dynamic shape"
);
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
...
...
@@ -137,6 +157,11 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
dims
)))
{
}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
shape
shape
::
from_permutation
(
type_t
t
,
...
...
@@ -150,9 +175,13 @@ shape shape::from_permutation(type_t t,
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
bytes
()
const
{
if
(
this
->
sub_shapes
().
empty
())
...
...
@@ -176,6 +205,9 @@ std::size_t shape::type_size() const
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
;
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
return
impl
->
m_dyn_dims
;
}
std
::
size_t
shape
::
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
{
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
...
...
@@ -235,13 +267,23 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
});
}
bool
shape
::
dynamic
()
const
{
return
(
impl
->
m_dynamic
);
}
bool
shape
::
packed
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
return
this
->
sub_shapes
().
empty
()
and
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
transposed
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
if
(
this
->
broadcasted
())
{
// TODO: Use a filter_iterator instead
...
...
@@ -261,6 +303,10 @@ bool shape::transposed() const
bool
shape
::
broadcasted
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
...
...
@@ -270,6 +316,10 @@ bool shape::broadcasted() const
bool
shape
::
scalar
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
// if any stride > 0, then accumulate will return false
return
this
->
sub_shapes
().
empty
()
and
...
...
test/op_shape_test.cpp
View file @
c0e18e78
...
...
@@ -958,7 +958,7 @@ TEST_CASE(multibroadcast)
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
3
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{}
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
};
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
input
);
}
{
...
...
test/shape_test.cpp
View file @
c0e18e78
...
...
@@ -42,6 +42,39 @@ TEST_CASE(test_shape_standard)
EXPECT
(
not
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_dynamic_fixed
)
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dims
=
{};
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
2
,
0
});
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
2
,
0
});
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
3
,
3
,
0
});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
dims
};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
dynamic
());
EXPECT
(
s
.
dyn_dims
().
size
()
==
3
);
EXPECT
(
s
.
dyn_dims
().
at
(
0
).
is_fixed
());
EXPECT
(
not
s
.
dyn_dims
().
at
(
0
).
has_optimal
());
}
TEST_CASE
(
test_shape_dynamic_not_fixed
)
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dims
=
{};
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
5
,
2
});
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
8
,
0
});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
dims
};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
dynamic
());
EXPECT
(
s
.
dyn_dims
().
size
()
==
2
);
EXPECT
(
not
s
.
dyn_dims
().
at
(
0
).
is_fixed
());
EXPECT
(
s
.
dyn_dims
().
at
(
0
).
has_optimal
());
}
TEST_CASE
(
test_shape_packed
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment