Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
845391cd
Unverified
Commit
845391cd
authored
Nov 18, 2021
by
Philip Meier
Committed by
GitHub
Nov 18, 2021
Browse files
enable selective benchmarks (#4960)
Co-authored-by:
Francisco Massa
<
fvsmassa@gmail.com
>
parent
408c9bea
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
129 additions
and
51 deletions
+129
-51
torchvision/prototype/datasets/benchmark.py
torchvision/prototype/datasets/benchmark.py
+129
-51
No files found.
torchvision/prototype/datasets/benchmark.py
View file @
845391cd
...
...
@@ -24,31 +24,58 @@ from torchvision.prototype import datasets as new_datasets
from
torchvision.transforms
import
PILToTensor
def
main
(
name
,
*
,
number
=
5
,
temp_root
=
None
,
num_workers
=
0
):
def
main
(
name
,
*
,
legacy
=
True
,
new
=
True
,
start
=
True
,
iteration
=
True
,
num_starts
=
3
,
num_samples
=
10_000
,
temp_root
=
None
,
num_workers
=
0
,
):
for
benchmark
in
DATASET_BENCHMARKS
:
if
benchmark
.
name
==
name
:
break
else
:
raise
ValueError
(
f
"No DatasetBenchmark available for dataset '
{
name
}
'"
)
if
legacy
and
start
:
print
(
"legacy"
,
"cold_start"
,
Measurement
.
time
(
benchmark
.
legacy_cold_start
(
temp_root
,
num_workers
=
num_workers
),
number
=
num
ber
),
Measurement
.
time
(
benchmark
.
legacy_cold_start
(
temp_root
,
num_workers
=
num_workers
),
number
=
num
_starts
),
)
print
(
"legacy"
,
"warm_start"
,
Measurement
.
time
(
benchmark
.
legacy_warm_start
(
temp_root
,
num_workers
=
num_workers
),
number
=
num
ber
),
Measurement
.
time
(
benchmark
.
legacy_warm_start
(
temp_root
,
num_workers
=
num_workers
),
number
=
num
_starts
),
)
if
legacy
and
iteration
:
print
(
"legacy"
,
"iter"
,
Measurement
.
iterations_per_time
(
benchmark
.
legacy_iteration
(
temp_root
,
num_workers
=
num_workers
),
number
=
number
),
"iteration"
,
Measurement
.
iterations_per_time
(
benchmark
.
legacy_iteration
(
temp_root
,
num_workers
=
num_workers
,
num_samples
=
num_samples
)
),
)
if
new
and
start
:
print
(
"new"
,
"cold_start"
,
Measurement
.
time
(
benchmark
.
new_cold_start
(
num_workers
=
num_workers
),
number
=
num_starts
),
)
print
(
"new"
,
"cold_start"
,
Measurement
.
time
(
benchmark
.
new_cold_start
(
num_workers
=
num_workers
),
number
=
number
))
print
(
"new"
,
"iter"
,
Measurement
.
iterations_per_time
(
benchmark
.
new_iter
(
num_workers
=
num_workers
),
number
=
number
))
if
new
and
iteration
:
print
(
"new"
,
"iteration"
,
Measurement
.
iterations_per_time
(
benchmark
.
new_iteration
(
num_workers
=
num_workers
,
num_samples
=
num_samples
)),
)
class
DatasetBenchmark
:
...
...
@@ -91,16 +118,17 @@ class DatasetBenchmark:
return
fn
def
new_iter
(
self
,
*
,
num_workers
):
def
new_iter
ation
(
self
,
*
,
num_samples
,
num_workers
):
def
fn
(
timer
):
dataset
=
self
.
new_dataset
(
num_workers
=
num_workers
)
num_samples
=
0
num_sample
=
0
with
timer
:
for
_
in
dataset
:
num_samples
+=
1
num_sample
+=
1
if
num_sample
==
num_samples
:
break
return
num_sample
s
return
num_sample
return
fn
...
...
@@ -155,7 +183,7 @@ class DatasetBenchmark:
@
contextlib
.
contextmanager
def
legacy_root
(
self
,
temp_root
):
new_root
=
new_datasets
.
home
()
/
self
.
name
new_root
=
pathlib
.
Path
(
new_datasets
.
home
()
)
/
self
.
name
legacy_root
=
pathlib
.
Path
(
tempfile
.
mkdtemp
(
dir
=
temp_root
))
if
os
.
stat
(
new_root
).
st_dev
!=
os
.
stat
(
legacy_root
).
st_dev
:
...
...
@@ -198,15 +226,16 @@ class DatasetBenchmark:
return
fn
def
legacy_iteration
(
self
,
temp_root
,
*
,
num_workers
):
def
legacy_iteration
(
self
,
temp_root
,
*
,
num_samples
,
num_workers
):
def
fn
(
timer
):
with
self
.
legacy_root
(
temp_root
)
as
root
:
dataset
=
self
.
legacy_dataset
(
root
,
num_workers
=
num_workers
)
with
timer
:
for
_
in
dataset
:
pass
for
num_sample
,
_
in
enumerate
(
dataset
,
1
):
if
num_sample
==
num_samples
:
break
return
len
(
dataset
)
return
num_sample
return
fn
...
...
@@ -261,24 +290,14 @@ class Measurement:
@
classmethod
def
time
(
cls
,
fn
,
*
,
number
):
results
=
Measurement
.
_timeit
(
fn
,
number
=
number
)
times
=
torch
.
tensor
(
tuple
(
zip
(
*
results
))[
1
])
mean
,
std
=
Measurement
.
_compute_mean_and_std
(
times
)
# TODO format that into engineering format
return
f
"
{
mean
:.
3
g
}
±
{
std
:.
3
g
}
s"
return
cls
.
_format
(
times
,
unit
=
"s"
)
@
classmethod
def
iterations_per_time
(
cls
,
fn
,
*
,
number
):
outputs
,
times
=
zip
(
*
Measurement
.
_timeit
(
fn
,
number
=
number
))
num_samples
=
outputs
[
0
]
assert
all
(
other_num_samples
==
num_samples
for
other_num_samples
in
outputs
[
1
:])
iterations_per_time
=
torch
.
tensor
(
num_samples
)
/
torch
.
tensor
(
times
)
mean
,
std
=
Measurement
.
_compute_mean_and_std
(
iterations_per_time
)
# TODO format that into engineering format
return
f
"
{
mean
:.
1
f
}
±
{
std
:.
1
f
}
it/s"
def
iterations_per_time
(
cls
,
fn
):
num_samples
,
time
=
Measurement
.
_timeit
(
fn
,
number
=
1
)[
0
]
iterations_per_second
=
torch
.
tensor
(
num_samples
)
/
torch
.
tensor
(
time
)
return
cls
.
_format
(
iterations_per_second
,
unit
=
"it/s"
)
class
Timer
:
def
__init__
(
self
):
...
...
@@ -300,7 +319,7 @@ class Measurement:
return
self
.
_stop
-
self
.
_start
@
classmethod
def
_timeit
(
cls
,
fn
,
*
,
number
):
def
_timeit
(
cls
,
fn
,
number
):
results
=
[]
for
_
in
range
(
number
):
timer
=
cls
.
Timer
()
...
...
@@ -308,9 +327,19 @@ class Measurement:
results
.
append
((
output
,
timer
.
delta
))
return
results
@
classmethod
def
_format
(
cls
,
measurements
,
*
,
unit
):
measurements
=
torch
.
as_tensor
(
measurements
).
to
(
torch
.
float64
).
flatten
()
if
measurements
.
numel
()
==
1
:
# TODO format that into engineering format
return
f
"
{
float
(
measurements
):.
3
f
}
{
unit
}
"
mean
,
std
=
Measurement
.
_compute_mean_and_std
(
measurements
)
# TODO format that into engineering format
return
f
"
{
mean
:.
3
f
}
±
{
std
:.
3
f
}
{
unit
}
"
@
classmethod
def
_compute_mean_and_std
(
cls
,
t
):
t
=
t
.
flatten
()
mean
=
float
(
t
.
mean
())
std
=
float
(
t
.
std
(
0
,
unbiased
=
t
.
numel
()
>
1
))
return
mean
,
std
...
...
@@ -476,6 +505,7 @@ DATASET_BENCHMARKS = [
),
),
DatasetBenchmark
(
"voc"
,
legacy_cls
=
legacy_datasets
.
VOCDetection
),
DatasetBenchmark
(
"imagenet"
,
legacy_cls
=
legacy_datasets
.
ImageNet
),
]
...
...
@@ -487,13 +517,51 @@ def parse_args(argv=None):
)
parser
.
add_argument
(
"name"
,
help
=
"Name of the dataset to benchmark."
)
parser
.
add_argument
(
"-n"
,
"--num
ber
"
,
"--num
-starts
"
,
type
=
int
,
default
=
5
,
help
=
"Number of
iteration
s of each benchmark."
,
default
=
3
,
help
=
"Number of
warm and cold start
s of each benchmark.
Default to 3.
"
,
)
parser
.
add_argument
(
"-N"
,
"--num-samples"
,
type
=
int
,
default
=
10_000
,
help
=
"Maximum number of samples to draw during iteration benchmarks. Defaults to 10_000."
,
)
parser
.
add_argument
(
"--nl"
,
"--no-legacy"
,
dest
=
"legacy"
,
action
=
"store_false"
,
help
=
"Skip legacy benchmarks."
,
)
parser
.
add_argument
(
"--nn"
,
"--no-new"
,
dest
=
"new"
,
action
=
"store_false"
,
help
=
"Skip new benchmarks."
,
)
parser
.
add_argument
(
"--ns"
,
"--no-start"
,
dest
=
"start"
,
action
=
"store_false"
,
help
=
"Skip start benchmarks."
,
)
parser
.
add_argument
(
"--ni"
,
"--no-iteration"
,
dest
=
"iteration"
,
action
=
"store_false"
,
help
=
"Skip iteration benchmarks."
,
)
parser
.
add_argument
(
"-t"
,
"--temp-root"
,
...
...
@@ -509,8 +577,8 @@ def parse_args(argv=None):
type
=
int
,
default
=
0
,
help
=
(
"Number of subprocesses used to load the data. Setting this to 0 will load all data in the main
process
"
"and thus disable multi-processing."
"Number of subprocesses used to load the data. Setting this to 0
(default)
will load all data in the main "
"
process
and thus disable multi-processing."
),
)
...
...
@@ -521,7 +589,17 @@ if __name__ == "__main__":
args
=
parse_args
()
try
:
main
(
args
.
name
,
number
=
args
.
number
,
temp_root
=
args
.
temp_root
,
num_workers
=
args
.
num_workers
)
main
(
args
.
name
,
legacy
=
args
.
legacy
,
new
=
args
.
new
,
start
=
args
.
start
,
iteration
=
args
.
iteration
,
num_starts
=
args
.
num_starts
,
num_samples
=
args
.
num_samples
,
temp_root
=
args
.
temp_root
,
num_workers
=
args
.
num_workers
,
)
except
Exception
as
error
:
msg
=
str
(
error
)
print
(
msg
or
f
"Unspecified
{
type
(
error
)
}
was raised during execution."
,
file
=
sys
.
stderr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment