Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13a8bcaa
Commit
13a8bcaa
authored
Jul 13, 2022
by
charlie
Browse files
Merge branch 'dyn_check_shapes' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_conv
parents
d5636acd
d6afa9e9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
93 additions
and
13 deletions
+93
-13
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+81
-4
src/program.cpp
src/program.cpp
+12
-9
No files found.
src/include/migraphx/check_shapes.hpp
View file @
13a8bcaa
...
@@ -38,22 +38,34 @@ struct check_shapes
...
@@ -38,22 +38,34 @@ struct check_shapes
const
shape
*
begin
;
const
shape
*
begin
;
const
shape
*
end
;
const
shape
*
end
;
const
std
::
string
name
;
const
std
::
string
name
;
const
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
)
:
begin
(
b
),
end
(
e
),
name
(
n
)
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
{
}
}
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
())
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
}
}
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
)
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
())
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
())
,
dynamic_allowed
(
d
)
{
{
}
}
~
check_shapes
()
{
if
(
not
dynamic_allowed
and
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
{
std
::
cerr
<<
prefix
()
<<
"Dynamic shapes not supported"
<<
std
::
endl
;
std
::
abort
();
}
}
std
::
string
prefix
()
const
std
::
string
prefix
()
const
{
{
if
(
name
.
empty
())
if
(
name
.
empty
())
...
@@ -92,6 +104,11 @@ struct check_shapes
...
@@ -92,6 +104,11 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check that the first shape has exactly n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
begin
!=
nullptr
);
...
@@ -104,6 +121,11 @@ struct check_shapes
...
@@ -104,6 +121,11 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check that the first shape has a maximum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
begin
!=
nullptr
);
...
@@ -117,6 +139,11 @@ struct check_shapes
...
@@ -117,6 +139,11 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check that the first shape has a minimum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
begin
!=
nullptr
);
...
@@ -130,6 +157,9 @@ struct check_shapes
...
@@ -130,6 +157,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same shape.
*/
const
check_shapes
&
same_shape
()
const
const
check_shapes
&
same_shape
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
...
@@ -137,6 +167,9 @@ struct check_shapes
...
@@ -137,6 +167,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same type.
*/
const
check_shapes
&
same_type
()
const
const
check_shapes
&
same_type
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
...
@@ -144,6 +177,9 @@ struct check_shapes
...
@@ -144,6 +177,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same lens.
*/
const
check_shapes
&
same_dims
()
const
const
check_shapes
&
same_dims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
...
@@ -151,6 +187,9 @@ struct check_shapes
...
@@ -151,6 +187,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same number of dimensions.
*/
const
check_shapes
&
same_ndims
()
const
const
check_shapes
&
same_ndims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
...
@@ -158,6 +197,9 @@ struct check_shapes
...
@@ -158,6 +197,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are standard.
*/
const
check_shapes
&
standard
()
const
const
check_shapes
&
standard
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
...
@@ -165,6 +207,9 @@ struct check_shapes
...
@@ -165,6 +207,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are standard or scalar.
*/
const
check_shapes
&
standard_or_scalar
()
const
const
check_shapes
&
standard_or_scalar
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
...
@@ -172,6 +217,9 @@ struct check_shapes
...
@@ -172,6 +217,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are packed.
*/
const
check_shapes
&
packed
()
const
const
check_shapes
&
packed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
...
@@ -179,6 +227,9 @@ struct check_shapes
...
@@ -179,6 +227,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are packed or broadcasted.
*/
const
check_shapes
&
packed_or_broadcasted
()
const
const
check_shapes
&
packed_or_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
...
@@ -186,6 +237,9 @@ struct check_shapes
...
@@ -186,6 +237,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are tuples.
*/
const
check_shapes
&
tuple_type
()
const
const
check_shapes
&
tuple_type
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
...
@@ -193,6 +247,9 @@ struct check_shapes
...
@@ -193,6 +247,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are not transposed.
*/
const
check_shapes
&
not_transposed
()
const
const
check_shapes
&
not_transposed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
...
@@ -200,6 +257,9 @@ struct check_shapes
...
@@ -200,6 +257,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes are not broadcasted.
*/
const
check_shapes
&
not_broadcasted
()
const
const
check_shapes
&
not_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
...
@@ -207,6 +267,10 @@ struct check_shapes
...
@@ -207,6 +267,10 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check all shapes have the same n elements.
* \param n number of elements
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
...
@@ -214,6 +278,9 @@ struct check_shapes
...
@@ -214,6 +278,9 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
/*!
* Check the batches of all the shapes do not have transposed strides.
*/
const
check_shapes
&
batch_not_transposed
()
const
const
check_shapes
&
batch_not_transposed
()
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
...
@@ -242,6 +309,16 @@ struct check_shapes
...
@@ -242,6 +309,16 @@ struct check_shapes
return
std
::
all_of
(
begin
,
end
,
p
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
}
template
<
class
Predicate
>
bool
any_of
(
Predicate
p
)
const
{
if
(
begin
==
end
)
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
const
shape
*
get
(
long
i
)
const
const
shape
*
get
(
long
i
)
const
{
{
if
(
i
>=
size
())
if
(
i
>=
size
())
...
...
src/program.cpp
View file @
13a8bcaa
...
@@ -740,11 +740,13 @@ void program::perf_report(std::ostream& os,
...
@@ -740,11 +740,13 @@ void program::perf_report(std::ostream& os,
double
overhead_percent
=
overhead_time
*
100.0
/
total_time
;
double
overhead_percent
=
overhead_time
*
100.0
/
total_time
;
double
total_instruction_time
=
0.0
;
double
total_instruction_time
=
0.0
;
std
::
unordered_map
<
std
::
string
,
double
>
op_times
;
std
::
unordered_map
<
std
::
string
,
double
>
op_times
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
op_n
;
for
(
auto
&&
p
:
ins_vec
)
for
(
auto
&&
p
:
ins_vec
)
{
{
double
avg
=
common_average
(
p
.
second
);
double
avg
=
common_average
(
p
.
second
);
op_times
[
perf_group
(
p
.
first
->
get_operator
())]
+=
avg
;
op_times
[
perf_group
(
p
.
first
->
get_operator
())]
+=
avg
;
total_instruction_time
+=
avg
;
total_instruction_time
+=
avg
;
op_n
[
perf_group
(
p
.
first
->
get_operator
())]
++
;
}
}
double
calculate_overhead_time
=
total_time
-
total_instruction_time
;
double
calculate_overhead_time
=
total_time
-
total_instruction_time
;
double
calculate_overhead_percent
=
calculate_overhead_time
*
100.0
/
total_time
;
double
calculate_overhead_percent
=
calculate_overhead_time
*
100.0
/
total_time
;
...
@@ -765,18 +767,19 @@ void program::perf_report(std::ostream& os,
...
@@ -765,18 +767,19 @@ void program::perf_report(std::ostream& os,
os
<<
std
::
endl
;
os
<<
std
::
endl
;
os
<<
"Summary:"
<<
std
::
endl
;
os
<<
"Summary:"
<<
std
::
endl
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
op_times_sorted
;
std
::
vector
<
std
::
tuple
<
double
,
std
::
size_t
,
std
::
string
>>
op_times_sorted
;
std
::
transform
(
op_times
.
begin
(),
std
::
transform
(
op_times
.
end
(),
op_times
.
begin
(),
op_times
.
end
(),
std
::
back_inserter
(
op_times_sorted
),
[
&
](
auto
p
)
{
std
::
back_inserter
(
op_times_sorted
),
auto
&&
name
=
p
.
first
;
[](
auto
p
)
{
return
std
::
make_pair
(
p
.
second
,
p
.
first
);
});
return
std
::
make_tuple
(
p
.
second
,
op_n
.
at
(
name
),
name
);
});
std
::
sort
(
op_times_sorted
.
begin
(),
op_times_sorted
.
end
(),
std
::
greater
<>
{});
std
::
sort
(
op_times_sorted
.
begin
(),
op_times_sorted
.
end
(),
std
::
greater
<>
{});
for
(
auto
&&
p
:
op_times_sorted
)
for
(
auto
&&
[
avg
,
nn
,
name
]
:
op_times_sorted
)
{
{
auto
&&
name
=
p
.
second
;
double
avg
=
p
.
first
;
double
percent
=
std
::
ceil
(
100.0
*
avg
/
total_instruction_time
);
double
percent
=
std
::
ceil
(
100.0
*
avg
/
total_instruction_time
);
os
<<
name
<<
": "
<<
avg
<<
"ms, "
<<
percent
<<
"%"
<<
std
::
endl
;
double
per_ins
=
avg
/
nn
;
os
<<
name
<<
": "
<<
avg
<<
"ms / "
<<
nn
<<
" = "
<<
per_ins
<<
"ms, "
<<
percent
<<
"%"
<<
std
::
endl
;
}
}
os
<<
std
::
endl
;
os
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment