Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
09a9b6f7
Unverified
Commit
09a9b6f7
authored
Oct 19, 2021
by
Philip Meier
Committed by
GitHub
Oct 19, 2021
Browse files
improve datasets benchmark (#4638)
parent
39d052c9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
189 additions
and
66 deletions
+189
-66
torchvision/prototype/datasets/benchmark.py
torchvision/prototype/datasets/benchmark.py
+189
-66
No files found.
torchvision/prototype/datasets/benchmark.py
View file @
09a9b6f7
import
argparse
import
argparse
import
collections.abc
import
contextlib
import
contextlib
import
copy
import
copy
import
inspect
import
inspect
...
@@ -11,27 +12,41 @@ import sys
...
@@ -11,27 +12,41 @@ import sys
import
tempfile
import
tempfile
import
time
import
time
import
unittest.mock
import
unittest.mock
import
warnings
import
torch
import
torch
from
torch.utils.data
import
DataLoader
from
torch.utils.data.dataloader_experimental
import
DataLoader2
from
torchvision
import
datasets
as
legacy_datasets
from
torchvision
import
datasets
as
legacy_datasets
from
torchvision.datasets.vision
import
StandardTransform
from
torchvision.prototype
import
datasets
as
new_datasets
from
torchvision.prototype
import
datasets
as
new_datasets
from
torchvision.transforms
import
ToTensor
from
torchvision.transforms
import
PIL
ToTensor
def main(name, *, number=5, temp_root=None, num_workers=0):
    """Benchmark the dataset called *name* and print the timings.

    For the legacy implementation this measures cold start, warm start, and
    full-iteration speed; for the new prototype implementation cold start and
    full-iteration speed.

    Args:
        name: Name of the dataset; must match a registered DatasetBenchmark.
        number: Number of repetitions per measurement.
        temp_root: Root directory for the temporary legacy dataset roots.
        num_workers: Number of data-loading subprocesses (0 = main process).

    Raises:
        ValueError: If no benchmark is registered under *name*.
    """
    benchmark = next((candidate for candidate in DATASET_BENCHMARKS if candidate.name == name), None)
    if benchmark is None:
        raise ValueError(f"No DatasetBenchmark available for dataset '{name}'")

    print(
        "legacy",
        "cold_start",
        Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=number),
    )
    print(
        "legacy",
        "warm_start",
        Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=number),
    )
    print(
        "legacy",
        "iter",
        Measurement.iterations_per_time(benchmark.legacy_iteration(temp_root, num_workers=num_workers), number=number),
    )

    print("new", "cold_start", Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=number))
    print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter(num_workers=num_workers), number=number))
class
DatasetBenchmark
:
class
DatasetBenchmark
:
...
@@ -63,16 +78,20 @@ class DatasetBenchmark:
...
@@ -63,16 +78,20 @@ class DatasetBenchmark:
self
.
prepare_legacy_root
=
prepare_legacy_root
self
.
prepare_legacy_root
=
prepare_legacy_root
def new_dataset(self, *, num_workers=0):
    """Load the prototype dataset wrapped in a DataLoader2.

    Wrapping in DataLoader2 lets the new datapipe-based dataset exercise
    multi-processing the same way the legacy DataLoader path does.
    """
    datapipe = new_datasets.load(self.name, **self.new_config)
    return DataLoader2(datapipe, num_workers=num_workers)
def new_cold_start(self, *, num_workers):
    """Return a callable that times construction plus first sample of the new dataset."""

    def fn(timer):
        with timer:
            loader = self.new_dataset(num_workers=num_workers)
            # Pulling one sample forces the pipeline to actually start up.
            next(iter(loader))

    return fn
dataset
=
self
.
new_dataset
()
def
new_iter
(
self
,
*
,
num_workers
):
def
fn
(
timer
):
dataset
=
self
.
new_dataset
(
num_workers
=
num_workers
)
num_samples
=
0
num_samples
=
0
with
timer
:
with
timer
:
...
@@ -81,6 +100,8 @@ class DatasetBenchmark:
...
@@ -81,6 +100,8 @@ class DatasetBenchmark:
return
num_samples
return
num_samples
return
fn
def
suppress_output
(
self
):
def
suppress_output
(
self
):
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
context_manager
():
def
context_manager
():
...
@@ -90,12 +111,16 @@ class DatasetBenchmark:
...
@@ -90,12 +111,16 @@ class DatasetBenchmark:
return
context_manager
()
return
context_manager
()
def legacy_dataset(self, root, *, num_workers=0, download=None):
    """Instantiate the legacy dataset under *root*, wrapped in a DataLoader.

    Args:
        root: Directory holding (symlinks to) the raw data.
        num_workers: Number of data-loading subprocesses.
        download: When not None, overrides the "download" special option
            if the legacy dataset supports it.
    """
    special_options = self.legacy_special_options.copy()
    if download is not None and "download" in special_options:
        special_options["download"] = download

    # Legacy datasets tend to print progress/extraction messages; silence them
    # so they do not pollute the benchmark output.
    with self.suppress_output():
        legacy = self.legacy_cls(str(root), **self.legacy_config, **special_options)
        return DataLoader(legacy, shuffle=True, num_workers=num_workers)
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
patch_download_and_integrity_checks
(
self
):
def
patch_download_and_integrity_checks
(
self
):
...
@@ -127,10 +152,20 @@ class DatasetBenchmark:
...
@@ -127,10 +152,20 @@ class DatasetBenchmark:
return
file_names
return
file_names
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
legacy_root
(
self
):
def
legacy_root
(
self
,
temp_root
):
new_root
=
new_datasets
.
home
()
/
self
.
name
new_root
=
new_datasets
.
home
()
/
self
.
name
legacy_root
=
pathlib
.
Path
(
tempfile
.
mkdtemp
())
legacy_root
=
pathlib
.
Path
(
tempfile
.
mkdtemp
(
dir
=
temp_root
))
if
os
.
stat
(
new_root
).
st_dev
!=
os
.
stat
(
legacy_root
).
st_dev
:
warnings
.
warn
(
"The temporary root directory for the legacy dataset was created on a different storage device than "
"the raw data that is used by the new dataset. If the devices have different I/O stats, this will "
"distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the "
"temporary directories."
,
RuntimeWarning
,
)
try
:
for
file_name
in
self
.
_find_resource_file_names
():
for
file_name
in
self
.
_find_resource_file_names
():
(
legacy_root
/
file_name
).
symlink_to
(
new_root
/
file_name
)
(
legacy_root
/
file_name
).
symlink_to
(
new_root
/
file_name
)
...
@@ -138,33 +173,41 @@ class DatasetBenchmark:
...
@@ -138,33 +173,41 @@ class DatasetBenchmark:
self
.
prepare_legacy_root
(
self
,
legacy_root
)
self
.
prepare_legacy_root
(
self
,
legacy_root
)
with
self
.
patch_download_and_integrity_checks
():
with
self
.
patch_download_and_integrity_checks
():
try
:
yield
legacy_root
yield
legacy_root
finally
:
finally
:
shutil
.
rmtree
(
legacy_root
)
shutil
.
rmtree
(
legacy_root
)
def legacy_cold_start(self, temp_root, *, num_workers):
    """Return a callable timing legacy dataset construction plus first sample."""

    def fn(timer):
        with self.legacy_root(temp_root) as root:
            with timer:
                loader = self.legacy_dataset(root, num_workers=num_workers)
                next(iter(loader))

    return fn
def legacy_warm_start(self, temp_root, *, num_workers):
    """Return a callable timing legacy construction after a first, untimed one.

    The untimed first construction pays one-time costs (download/extraction),
    so the timed construction measures the warm-start path only.
    """

    def fn(timer):
        with self.legacy_root(temp_root) as root:
            self.legacy_dataset(root, num_workers=num_workers)
            with timer:
                loader = self.legacy_dataset(root, num_workers=num_workers, download=False)
                next(iter(loader))

    return fn
def legacy_iteration(self, temp_root, *, num_workers):
    """Return a callable that times one full pass over the legacy dataset.

    The callable reports the number of samples so the caller can compute
    iterations per unit time.
    """

    def fn(timer):
        with self.legacy_root(temp_root) as root:
            loader = self.legacy_dataset(root, num_workers=num_workers)
            with timer:
                for _ in loader:
                    pass
            return len(loader)

    return fn
def
_find_legacy_cls
(
self
):
def
_find_legacy_cls
(
self
):
legacy_clss
=
{
legacy_clss
=
{
name
.
lower
():
dataset_class
name
.
lower
():
dataset_class
...
@@ -203,11 +246,11 @@ class DatasetBenchmark:
...
@@ -203,11 +246,11 @@ class DatasetBenchmark:
special_options
[
"download"
]
=
True
special_options
[
"download"
]
=
True
if
"transform"
in
available_special_kwargs
:
if
"transform"
in
available_special_kwargs
:
special_options
[
"transform"
]
=
ToTensor
()
special_options
[
"transform"
]
=
PIL
ToTensor
()
if
"target_transform"
in
available_special_kwargs
:
if
"target_transform"
in
available_special_kwargs
:
special_options
[
"target_transform"
]
=
torch
.
tensor
special_options
[
"target_transform"
]
=
torch
.
tensor
elif
"transforms"
in
available_special_kwargs
:
elif
"transforms"
in
available_special_kwargs
:
special_options
[
"transforms"
]
=
Standard
Transform
(
ToTensor
(),
ToTensor
())
special_options
[
"transforms"
]
=
Joint
Transform
(
PIL
ToTensor
(),
PIL
ToTensor
())
return
special_options
return
special_options
...
@@ -271,6 +314,21 @@ class Measurement:
...
@@ -271,6 +314,21 @@ class Measurement:
return
mean
,
std
return
mean
,
std
def no_split(config):
    """Return *config* as a plain dict with the "split" entry removed.

    Used for legacy datasets that do not take a ``split`` argument.
    """
    mapped = dict(config)
    del mapped["split"]
    return mapped
def bool_split(name="train"):
    """Return a config map that replaces "split" with a boolean flag *name*.

    The returned mapper sets ``name`` to True when the split is "train" and
    to False otherwise, for legacy datasets that take e.g. ``train=True``.
    """

    def legacy_config_map(config):
        mapped = dict(config)
        mapped[name] = mapped.pop("split") == "train"
        return mapped

    return legacy_config_map
def
base_folder
(
rel_folder
=
None
):
def
base_folder
(
rel_folder
=
None
):
if
rel_folder
is
None
:
if
rel_folder
is
None
:
...
@@ -295,6 +353,29 @@ def base_folder(rel_folder=None):
...
@@ -295,6 +353,29 @@ def base_folder(rel_folder=None):
return
prepare_legacy_root
return
prepare_legacy_root
class JointTransform:
    """Apply one transform per input and return the results as a tuple.

    Can be called either with the inputs as separate arguments,
    ``fn(img, target)``, or with a single sequence of them, ``fn((img, target))``.
    Raises RuntimeError when the number of inputs does not match the number
    of transforms.
    """

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, *inputs):
        # Unwrap a single sequence argument, e.g. fn((img, target)).
        # Bug fix: the check used to be `isinstance(inputs, ...)`, which is
        # always true because `inputs` is a tuple — so a single non-sequence
        # argument was wrongly unwrapped and crashed on len() below.
        if len(inputs) == 1 and isinstance(inputs[0], collections.abc.Sequence):
            inputs = inputs[0]

        if len(inputs) != len(self.transforms):
            raise RuntimeError(
                f"The number of inputs and transforms mismatches: {len(inputs)} != {len(self.transforms)}."
            )

        return tuple(transform(input) for transform, input in zip(self.transforms, inputs))
def caltech101_legacy_config_map(config):
    """Map a new-style caltech101 config onto legacy keyword arguments."""
    # The new dataset always returns the category and annotation
    return {**no_split(config), "target_type": ("category", "annotation")}
# MNIST-style legacy datasets expect their raw files under "<ClassName>/raw".
mnist_base_folder = base_folder(lambda bench: pathlib.Path(bench.legacy_cls.__name__) / "raw")
...
@@ -323,8 +404,21 @@ def qmnist_legacy_config_map(config):
...
@@ -323,8 +404,21 @@ def qmnist_legacy_config_map(config):
DATASET_BENCHMARKS
=
[
DATASET_BENCHMARKS
=
[
DatasetBenchmark
(
"caltech101"
,
prepare_legacy_root
=
base_folder
()),
DatasetBenchmark
(
DatasetBenchmark
(
"caltech256"
,
prepare_legacy_root
=
base_folder
()),
"caltech101"
,
legacy_config_map
=
caltech101_legacy_config_map
,
prepare_legacy_root
=
base_folder
(),
legacy_special_options_map
=
lambda
config
:
dict
(
download
=
True
,
transform
=
PILToTensor
(),
target_transform
=
JointTransform
(
torch
.
tensor
,
torch
.
tensor
),
),
),
DatasetBenchmark
(
"caltech256"
,
legacy_config_map
=
no_split
,
prepare_legacy_root
=
base_folder
(),
),
DatasetBenchmark
(
DatasetBenchmark
(
"celeba"
,
"celeba"
,
prepare_legacy_root
=
base_folder
(),
prepare_legacy_root
=
base_folder
(),
...
@@ -336,11 +430,11 @@ DATASET_BENCHMARKS = [
...
@@ -336,11 +430,11 @@ DATASET_BENCHMARKS = [
),
),
DatasetBenchmark
(
DatasetBenchmark
(
"cifar10"
,
"cifar10"
,
legacy_config_map
=
lambda
config
:
dict
(
train
=
config
.
split
==
"train"
),
legacy_config_map
=
bool_split
(
),
),
),
DatasetBenchmark
(
DatasetBenchmark
(
"cifar100"
,
"cifar100"
,
legacy_config_map
=
lambda
config
:
dict
(
train
=
config
.
split
==
"train"
),
legacy_config_map
=
bool_split
(
),
),
),
DatasetBenchmark
(
DatasetBenchmark
(
"emnist"
,
"emnist"
,
...
@@ -376,27 +470,56 @@ DATASET_BENCHMARKS = [
...
@@ -376,27 +470,56 @@ DATASET_BENCHMARKS = [
),
),
legacy_special_options_map
=
lambda
config
:
dict
(
legacy_special_options_map
=
lambda
config
:
dict
(
download
=
True
,
download
=
True
,
transforms
=
Standard
Transform
(
ToTensor
(),
torch
.
tensor
if
config
.
boundaries
else
ToTensor
()),
transforms
=
Joint
Transform
(
PIL
ToTensor
(),
torch
.
tensor
if
config
.
boundaries
else
PIL
ToTensor
()),
),
),
),
),
DatasetBenchmark
(
"voc"
,
legacy_cls
=
legacy_datasets
.
VOCDetection
),
DatasetBenchmark
(
"voc"
,
legacy_cls
=
legacy_datasets
.
VOCDetection
),
]
]
def parse_args(argv=None):
    """Parse the command line arguments of the benchmark script.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``name``, ``number``, ``temp_root``, and
        ``num_workers`` attributes.
    """
    parser = argparse.ArgumentParser(
        prog="torchvision.prototype.datasets.benchmark.py",
        description="Utility to benchmark new datasets against their legacy variants.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("name", help="Name of the dataset to benchmark.")
    parser.add_argument(
        "-n",
        "--number",
        type=int,
        default=5,
        help="Number of iterations of each benchmark.",
    )
    parser.add_argument(
        "-t",
        "--temp-root",
        type=pathlib.Path,
        help=(
            "Root of the temporary legacy root directories. Use this if your system default temporary directory is on "
            "another storage device as the raw data to avoid distortions due to differing I/O stats."
        ),
    )
    parser.add_argument(
        "-j",
        "--num-workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses used to load the data. Setting this to 0 will load all data in the main process "
            "and thus disable multi-processing."
        ),
    )

    # Bug fix: `argv or sys.argv[1:]` silently ignored an explicitly passed
    # empty list (it is falsy) and fell back to the real command line.
    # argparse already uses sys.argv[1:] when given None, so pass argv through.
    return parser.parse_args(argv)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parse_args
()
args
=
parse_args
()
try
:
try
:
main
(
args
.
name
,
number
=
args
.
number
)
main
(
args
.
name
,
number
=
args
.
number
,
temp_root
=
args
.
temp_root
,
num_workers
=
args
.
num_workers
)
except
Exception
as
error
:
except
Exception
as
error
:
msg
=
str
(
error
)
msg
=
str
(
error
)
print
(
msg
or
f
"Unspecified
{
type
(
error
)
}
was raised during execution."
,
file
=
sys
.
stderr
)
print
(
msg
or
f
"Unspecified
{
type
(
error
)
}
was raised during execution."
,
file
=
sys
.
stderr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment