Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c0e18e78
"src/vscode:/vscode.git/clone" did not exist on "3541b5af8ffcd91c0143786c4bccfcd9fe78e9e1"
Commit
c0e18e78
authored
May 03, 2022
by
charlie
Browse files
Dynamic shape handling in shape object
parent
764273e4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
117 additions
and
1 deletion
+117
-1
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+5
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+21
-0
src/permutation.cpp
src/permutation.cpp
+7
-0
src/shape.cpp
src/shape.cpp
+50
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+1
-1
test/shape_test.cpp
test/shape_test.cpp
+33
-0
No files found.
src/include/migraphx/check_shapes.hpp
View file @
c0e18e78
...
...
@@ -48,6 +48,11 @@ struct check_shapes
return
end
-
begin
;
}
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template
<
class
...
Ts
>
const
check_shapes
&
has
(
Ts
...
ns
)
const
{
...
...
src/include/migraphx/shape.hpp
View file @
c0e18e78
...
...
@@ -59,6 +59,15 @@ struct shape
{
};
struct
dynamic_dimension
{
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
opt
=
0
;
bool
is_fixed
()
const
{
return
min
==
max
;
};
bool
has_optimal
()
const
{
return
opt
!=
0
;
};
};
static
const
std
::
vector
<
type_t
>&
types
();
static
std
::
string
name
(
type_t
t
);
...
...
@@ -69,6 +78,8 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
...
...
@@ -93,6 +104,8 @@ struct shape
std
::
size_t
bytes
()
const
;
std
::
size_t
type_size
()
const
;
const
std
::
vector
<
dynamic_dimension
>&
dyn_dims
()
const
;
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
/// Map multiple indices to space index
...
...
@@ -115,17 +128,24 @@ struct shape
/// Returns true if the shape is packed with no padding
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool
transposed
()
const
;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool
broadcasted
()
const
;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool
standard
()
const
;
/// Returns true if all strides are equal to 0 (scalar tensor)
bool
scalar
()
const
;
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
...
@@ -225,6 +245,7 @@ struct shape
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
/// size of the data buffer
std
::
size_t
element_space
()
const
;
private:
...
...
src/permutation.cpp
View file @
c0e18e78
...
...
@@ -13,11 +13,18 @@ shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
return
{
s
.
type
(),
reorder_dims
(
s
.
lens
(),
permutation
),
reorder_dims
(
s
.
strides
(),
permutation
)};
}
/*!
* Inverts the permutation using the less_than operator
*/
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
)
{
return
sort_permutation
(
permutation
,
std
::
less
<>
{});
}
/*!
* Computes a permutation for the lengths based on decesending stride order.
* Compares the lengths if the strides are the same.
*/
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
)
{
std
::
vector
<
std
::
int64_t
>
result
(
s
.
lens
().
size
());
...
...
src/shape.cpp
View file @
c0e18e78
...
...
@@ -3,6 +3,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
...
...
@@ -45,11 +46,20 @@ struct shape_impl
}
shape_impl
(
const
std
::
vector
<
shape
>&
subs
)
:
m_type
(
shape
::
tuple_type
),
m_shapes
(
subs
)
{}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
m_type
(
t
),
m_dynamic
(
true
),
m_dyn_dims
(
std
::
move
(
dims
))
{
}
shape
::
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
=
{};
std
::
vector
<
std
::
size_t
>
m_strides
=
{};
std
::
vector
<
shape
>
m_shapes
=
{};
bool
m_standard
=
false
;
bool
m_dynamic
=
false
;
std
::
vector
<
shape
::
dynamic_dimension
>
m_dyn_dims
=
{};
void
calculate_strides
()
{
...
...
@@ -66,6 +76,11 @@ struct shape_impl
std
::
size_t
element_space
()
const
{
if
(
m_dynamic
)
{
MIGRAPHX_THROW
(
"SHAPE: element_space() called on dynamic shape"
);
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
...
...
@@ -80,6 +95,11 @@ struct shape_impl
std
::
size_t
elements
()
const
{
if
(
m_dynamic
)
{
MIGRAPHX_THROW
(
"SHAPE: elements() called on dynamic shape"
);
}
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
m_lens
.
empty
())
return
0
;
...
...
@@ -137,6 +157,11 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
shape
::
dynamic_dimension
>
dims
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
dims
)))
{
}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
shape
shape
::
from_permutation
(
type_t
t
,
...
...
@@ -150,9 +175,13 @@ shape shape::from_permutation(type_t t,
}
shape
::
type_t
shape
::
type
()
const
{
return
impl
->
m_type
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
impl
->
m_lens
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
impl
->
m_strides
;
}
std
::
size_t
shape
::
elements
()
const
{
return
impl
->
elements
();
}
std
::
size_t
shape
::
bytes
()
const
{
if
(
this
->
sub_shapes
().
empty
())
...
...
@@ -176,6 +205,9 @@ std::size_t shape::type_size() const
this
->
visit_type
([
&
](
auto
as
)
{
n
=
as
.
size
();
});
return
n
;
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
return
impl
->
m_dyn_dims
;
}
std
::
size_t
shape
::
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
{
assert
(
l
.
size
()
<=
this
->
lens
().
size
());
...
...
@@ -235,13 +267,23 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
});
}
bool
shape
::
dynamic
()
const
{
return
(
impl
->
m_dynamic
);
}
bool
shape
::
packed
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
return
this
->
sub_shapes
().
empty
()
and
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
transposed
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
if
(
this
->
broadcasted
())
{
// TODO: Use a filter_iterator instead
...
...
@@ -261,6 +303,10 @@ bool shape::transposed() const
bool
shape
::
broadcasted
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
...
...
@@ -270,6 +316,10 @@ bool shape::broadcasted() const
bool
shape
::
scalar
()
const
{
if
(
this
->
dynamic
())
{
return
false
;
}
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
// if any stride > 0, then accumulate will return false
return
this
->
sub_shapes
().
empty
()
and
...
...
test/op_shape_test.cpp
View file @
c0e18e78
...
...
@@ -958,7 +958,7 @@ TEST_CASE(multibroadcast)
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
3
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{}
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
};
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
input
);
}
{
...
...
test/shape_test.cpp
View file @
c0e18e78
...
...
@@ -42,6 +42,39 @@ TEST_CASE(test_shape_standard)
EXPECT
(
not
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_dynamic_fixed
)
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dims
=
{};
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
2
,
0
});
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
2
,
0
});
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
3
,
3
,
0
});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
dims
};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
dynamic
());
EXPECT
(
s
.
dyn_dims
().
size
()
==
3
);
EXPECT
(
s
.
dyn_dims
().
at
(
0
).
is_fixed
());
EXPECT
(
not
s
.
dyn_dims
().
at
(
0
).
has_optimal
());
}
TEST_CASE
(
test_shape_dynamic_not_fixed
)
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dims
=
{};
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
5
,
2
});
dims
.
emplace_back
(
migraphx
::
shape
::
dynamic_dimension
{
2
,
8
,
0
});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
dims
};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
dynamic
());
EXPECT
(
s
.
dyn_dims
().
size
()
==
2
);
EXPECT
(
not
s
.
dyn_dims
().
at
(
0
).
is_fixed
());
EXPECT
(
s
.
dyn_dims
().
at
(
0
).
has_optimal
());
}
TEST_CASE
(
test_shape_packed
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment