Commit 09a9b6f7 (unverified)
Authored Oct 19, 2021 by Philip Meier, committed by GitHub on Oct 19, 2021
improve datasets benchmark (#4638)
parent 39d052c9

Showing 1 changed file with 189 additions and 66 deletions:
torchvision/prototype/datasets/benchmark.py (+189, -66)

torchvision/prototype/datasets/benchmark.py @ 09a9b6f7
 import argparse
+import collections.abc
 import contextlib
 import copy
 import inspect
@@ -11,27 +12,41 @@ import sys
 import tempfile
 import time
 import unittest.mock
+import warnings

 import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader_experimental import DataLoader2
 from torchvision import datasets as legacy_datasets
-from torchvision.datasets.vision import StandardTransform
 from torchvision.prototype import datasets as new_datasets
-from torchvision.transforms import ToTensor
+from torchvision.transforms import PILToTensor


-def main(name, *, number):
+def main(name, *, number=5, temp_root=None, num_workers=0):
     for benchmark in DATASET_BENCHMARKS:
         if benchmark.name == name:
             break
     else:
         raise ValueError(f"No DatasetBenchmark available for dataset '{name}'")

-    print("legacy", "cold_start", Measurement.time(benchmark.legacy_cold_start, number=number))
-    print("legacy", "warm_start", Measurement.time(benchmark.legacy_warm_start, number=number))
-    print("legacy", "iter", Measurement.iterations_per_time(benchmark.legacy_iteration, number=number))
-    print("new", "cold_start", Measurement.time(benchmark.new_cold_start, number=number))
-    print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter, number=number))
+    print(
+        "legacy",
+        "cold_start",
+        Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=number),
+    )
+    print(
+        "legacy",
+        "warm_start",
+        Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=number),
+    )
+    print(
+        "legacy",
+        "iter",
+        Measurement.iterations_per_time(benchmark.legacy_iteration(temp_root, num_workers=num_workers), number=number),
+    )
+    print("new", "cold_start", Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=number))
+    print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter(num_workers=num_workers), number=number))
 class DatasetBenchmark:
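Throughout this diff, the benchmark methods change from taking a `timer` directly to being factories: they bind `temp_root` and `num_workers` first and return the function that `Measurement` then times, as seen in the updated `main` above. A minimal, self-contained sketch of that pattern; `Timer` and `measure` here are hypothetical stand-ins, not the `Measurement` class from this file:

import time


class Timer:
    # Records the wall-clock time of its body.
    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        self.elapsed = time.perf_counter() - self.start


def make_benchmark(*, num_workers):
    # Factory in the style of new_iter / legacy_iteration below: bind the
    # configuration now, defer the timed work until the runner calls fn.
    def fn(timer):
        with timer:
            sum(range(100_000))  # stand-in for iterating a dataset

    return fn


def measure(fn, *, number=5):
    # Hypothetical stand-in for Measurement.time(fn, number=number).
    timings = []
    for _ in range(number):
        timer = Timer()
        fn(timer)
        timings.append(timer.elapsed)
    return min(timings)


print(measure(make_benchmark(num_workers=0)))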
@@ -63,23 +78,29 @@ class DatasetBenchmark:
         self.prepare_legacy_root = prepare_legacy_root

-    def new_dataset(self):
-        return new_datasets.load(self.name, **self.new_config)
+    def new_dataset(self, *, num_workers=0):
+        return DataLoader2(new_datasets.load(self.name, **self.new_config), num_workers=num_workers)

-    def new_cold_start(self, timer):
-        with timer:
-            dataset = self.new_dataset()
-            next(iter(dataset))
+    def new_cold_start(self, *, num_workers):
+        def fn(timer):
+            with timer:
+                dataset = self.new_dataset(num_workers=num_workers)
+                next(iter(dataset))
+
+        return fn

-    def new_iter(self, timer):
-        dataset = self.new_dataset()
-        num_samples = 0
-        with timer:
-            for _ in dataset:
-                num_samples += 1
-        return num_samples
+    def new_iter(self, *, num_workers):
+        def fn(timer):
+            dataset = self.new_dataset(num_workers=num_workers)
+            num_samples = 0
+            with timer:
+                for _ in dataset:
+                    num_samples += 1
+            return num_samples
+
+        return fn
     def suppress_output(self):
         @contextlib.contextmanager
@@ -90,12 +111,16 @@ class DatasetBenchmark:
         return context_manager()

-    def legacy_dataset(self, root, *, download=None):
+    def legacy_dataset(self, root, *, num_workers=0, download=None):
         special_options = self.legacy_special_options.copy()
         if "download" in special_options and download is not None:
             special_options["download"] = download
         with self.suppress_output():
-            return self.legacy_cls(str(root), **self.legacy_config, **special_options)
+            return DataLoader(
+                self.legacy_cls(str(root), **self.legacy_config, **special_options),
+                shuffle=True,
+                num_workers=num_workers,
+            )

     @contextlib.contextmanager
     def patch_download_and_integrity_checks(self):
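For a like-for-like comparison, legacy datasets are now consumed through a regular `torch.utils.data.DataLoader` (with `shuffle=True` and the shared `num_workers`), mirroring the `DataLoader2` wrapping of the prototype datasets in `new_dataset` above. A minimal sketch of the legacy-side wrapping, with a toy map-style dataset standing in for `self.legacy_cls(...)`:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for a legacy torchvision dataset (map-style: __getitem__/__len__).
dataset = TensorDataset(torch.arange(8, dtype=torch.float32).reshape(8, 1))

# Same wrapping as legacy_dataset() above: shuffling plus optional worker processes.
loader = DataLoader(dataset, shuffle=True, num_workers=0)

for (sample,) in loader:
    print(sample.shape)  # torch.Size([1, 1]): default batch_size=1, feature dim 1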
@@ -127,43 +152,61 @@ class DatasetBenchmark:
         return file_names

     @contextlib.contextmanager
-    def legacy_root(self):
+    def legacy_root(self, temp_root):
         new_root = new_datasets.home() / self.name
-        legacy_root = pathlib.Path(tempfile.mkdtemp())
+        legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))
+
+        if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev:
+            warnings.warn(
+                "The temporary root directory for the legacy dataset was created on a different storage device than "
+                "the raw data that is used by the new dataset. If the devices have different I/O stats, this will "
+                "distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the "
+                "temporary directories.",
+                RuntimeWarning,
+            )

-        for file_name in self._find_resource_file_names():
-            (legacy_root / file_name).symlink_to(new_root / file_name)
+        try:
+            for file_name in self._find_resource_file_names():
+                (legacy_root / file_name).symlink_to(new_root / file_name)

-        if self.prepare_legacy_root:
-            self.prepare_legacy_root(self, legacy_root)
+            if self.prepare_legacy_root:
+                self.prepare_legacy_root(self, legacy_root)

-        with self.patch_download_and_integrity_checks():
-            try:
+            with self.patch_download_and_integrity_checks():
                 yield legacy_root
-            finally:
-                shutil.rmtree(legacy_root)
+        finally:
+            shutil.rmtree(legacy_root)

-    def legacy_cold_start(self, timer):
-        with self.legacy_root() as root:
-            with timer:
-                dataset = self.legacy_dataset(root)
-                next(iter(dataset))
+    def legacy_cold_start(self, temp_root, *, num_workers):
+        def fn(timer):
+            with self.legacy_root(temp_root) as root:
+                with timer:
+                    dataset = self.legacy_dataset(root, num_workers=num_workers)
+                    next(iter(dataset))
+
+        return fn

-    def legacy_warm_start(self, timer):
-        with self.legacy_root() as root:
-            self.legacy_dataset(root)
-            with timer:
-                dataset = self.legacy_dataset(root, download=False)
-                next(iter(dataset))
+    def legacy_warm_start(self, temp_root, *, num_workers):
+        def fn(timer):
+            with self.legacy_root(temp_root) as root:
+                self.legacy_dataset(root, num_workers=num_workers)
+                with timer:
+                    dataset = self.legacy_dataset(root, num_workers=num_workers, download=False)
+                    next(iter(dataset))
+
+        return fn

-    def legacy_iteration(self, timer):
-        with self.legacy_root() as root:
-            dataset = self.legacy_dataset(root)
-            with timer:
-                for _ in dataset:
-                    pass
-        return len(dataset)
+    def legacy_iteration(self, temp_root, *, num_workers):
+        def fn(timer):
+            with self.legacy_root(temp_root) as root:
+                dataset = self.legacy_dataset(root, num_workers=num_workers)
+                with timer:
+                    for _ in dataset:
+                        pass
+            return len(dataset)
+
+        return fn

     def _find_legacy_cls(self):
         legacy_clss = {
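`legacy_root` now warns when the temporary symlink directory and the raw data end up on different storage devices, since differing I/O characteristics would distort the timings. The check compares the `st_dev` field returned by `os.stat`, which identifies the device a path lives on; a standalone illustration with stand-in paths:

import os
import tempfile

# Stand-ins for new_datasets.home() / name and the mkdtemp() result.
raw_root = os.path.expanduser("~")
tmp_root = tempfile.mkdtemp()

# Equal st_dev values mean both paths live on the same device, so symlinked
# reads during the benchmark hit the same storage as the raw data.
if os.stat(raw_root).st_dev != os.stat(tmp_root).st_dev:
    print("different devices; consider passing --temp-root")
else:
    print("same device; I/O characteristics are comparable")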
@@ -203,11 +246,11 @@ class DatasetBenchmark:
             special_options["download"] = True
         if "transform" in available_special_kwargs:
-            special_options["transform"] = ToTensor()
+            special_options["transform"] = PILToTensor()
             if "target_transform" in available_special_kwargs:
                 special_options["target_transform"] = torch.tensor
         elif "transforms" in available_special_kwargs:
-            special_options["transforms"] = StandardTransform(ToTensor(), ToTensor())
+            special_options["transforms"] = JointTransform(PILToTensor(), PILToTensor())

         return special_options
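The swap from `ToTensor()` to `PILToTensor()` is not just a rename: `ToTensor` converts a PIL image to a `float32` tensor rescaled to [0, 1], while `PILToTensor` keeps the raw `uint8` values, so the benchmark no longer times the dtype conversion and rescaling. A quick comparison, assuming a torchvision build that ships `PILToTensor`:

import PIL.Image
from torchvision.transforms import PILToTensor, ToTensor

image = PIL.Image.new("RGB", (4, 4), color=(255, 0, 0))

print(ToTensor()(image).dtype)     # torch.float32, values scaled to [0.0, 1.0]
print(PILToTensor()(image).dtype)  # torch.uint8, original pixel values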
@@ -271,6 +314,21 @@ class Measurement:
         return mean, std


+def no_split(config):
+    legacy_config = dict(config)
+    del legacy_config["split"]
+    return legacy_config
+
+
+def bool_split(name="train"):
+    def legacy_config_map(config):
+        legacy_config = dict(config)
+        legacy_config[name] = legacy_config.pop("split") == "train"
+        return legacy_config
+
+    return legacy_config_map
+
+
 def base_folder(rel_folder=None):
     if rel_folder is None:
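`no_split` and `bool_split` factor out the two most common mappings from a new-style config (which carries a `split` key) to legacy constructor kwargs. With the helpers above in scope, and plain dicts standing in for config objects:

print(no_split({"split": "train", "year": "2012"}))  # {'year': '2012'}
print(bool_split()({"split": "train"}))              # {'train': True}
print(bool_split("training")({"split": "test"}))     # {'training': False}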
@@ -295,6 +353,29 @@ def base_folder(rel_folder=None):
     return prepare_legacy_root


+class JointTransform:
+    def __init__(self, *transforms):
+        self.transforms = transforms
+
+    def __call__(self, *inputs):
+        if len(inputs) == 1 and isinstance(inputs, collections.abc.Sequence):
+            inputs = inputs[0]
+
+        if len(inputs) != len(self.transforms):
+            raise RuntimeError(
+                f"The number of inputs and transforms mismatches: {len(inputs)} != {len(self.transforms)}."
+            )
+
+        return tuple(transform(input) for transform, input in zip(self.transforms, inputs))
+
+
+def caltech101_legacy_config_map(config):
+    legacy_config = no_split(config)
+    # The new dataset always returns the category and annotation
+    legacy_config["target_type"] = ("category", "annotation")
+    return legacy_config
+
+
 mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw")
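`JointTransform` replaces the removed `StandardTransform` import: it applies the i-th transform to the i-th input and returns the results as a tuple, and a lone sequence argument is unpacked first. With the class above in scope:

import torch

joint = JointTransform(torch.tensor, torch.tensor)
print(joint(1, 2))    # (tensor(1), tensor(2))
print(joint((1, 2)))  # same result: the single sequence argument is unpacked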
@@ -323,8 +404,21 @@ def qmnist_legacy_config_map(config):
 DATASET_BENCHMARKS = [
-    DatasetBenchmark("caltech101", prepare_legacy_root=base_folder()),
-    DatasetBenchmark("caltech256", prepare_legacy_root=base_folder()),
+    DatasetBenchmark(
+        "caltech101",
+        legacy_config_map=caltech101_legacy_config_map,
+        prepare_legacy_root=base_folder(),
+        legacy_special_options_map=lambda config: dict(
+            download=True,
+            transform=PILToTensor(),
+            target_transform=JointTransform(torch.tensor, torch.tensor),
+        ),
+    ),
+    DatasetBenchmark(
+        "caltech256",
+        legacy_config_map=no_split,
+        prepare_legacy_root=base_folder(),
+    ),
     DatasetBenchmark(
         "celeba",
         prepare_legacy_root=base_folder(),
@@ -336,11 +430,11 @@ DATASET_BENCHMARKS = [
     ),
     DatasetBenchmark(
         "cifar10",
-        legacy_config_map=lambda config: dict(train=config.split == "train"),
+        legacy_config_map=bool_split(),
     ),
     DatasetBenchmark(
         "cifar100",
-        legacy_config_map=lambda config: dict(train=config.split == "train"),
+        legacy_config_map=bool_split(),
     ),
     DatasetBenchmark(
         "emnist",
@@ -376,27 +470,56 @@ DATASET_BENCHMARKS = [
         ),
         legacy_special_options_map=lambda config: dict(
             download=True,
-            transforms=StandardTransform(ToTensor(), torch.tensor if config.boundaries else ToTensor()),
+            transforms=JointTransform(PILToTensor(), torch.tensor if config.boundaries else PILToTensor()),
         ),
     ),
     DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
 ]


-def parse_args(args=None):
-    parser = argparse.ArgumentParser()
-    parser.add_argument("name", type=str)
-    parser.add_argument("--number", "-n", type=int, default=5, help="Number of iterations of each benchmark")
+def parse_args(argv=None):
+    parser = argparse.ArgumentParser(
+        prog="torchvision.prototype.datasets.benchmark.py",
+        description="Utility to benchmark new datasets against their legacy variants.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument("name", help="Name of the dataset to benchmark.")
+    parser.add_argument(
+        "-n",
+        "--number",
+        type=int,
+        default=5,
+        help="Number of iterations of each benchmark.",
+    )
+    parser.add_argument(
+        "-t",
+        "--temp-root",
+        type=pathlib.Path,
+        help=(
+            "Root of the temporary legacy root directories. Use this if your system default temporary directory is on "
+            "another storage device as the raw data to avoid distortions due to differing I/O stats."
+        ),
+    )
+    parser.add_argument(
+        "-j",
+        "--num-workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses used to load the data. Setting this to 0 will load all data in the main process "
+            "and thus disable multi-processing."
+        ),
+    )

-    return parser.parse_args(args or sys.argv[1:])
+    return parser.parse_args(argv or sys.argv[1:])


 if __name__ == "__main__":
     args = parse_args()

     try:
-        main(args.name, number=args.number)
+        main(args.name, number=args.number, temp_root=args.temp_root, num_workers=args.num_workers)
     except Exception as error:
         msg = str(error)
         print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
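With the new flags wired through, an equivalent in-code invocation might look like this; the dataset name and paths are illustrative, mirroring the `__main__` block above:

# Equivalent to the CLI call: benchmark.py cifar10 -n 5 -t /tmp/bench -j 4
args = parse_args(["cifar10", "--number", "5", "--temp-root", "/tmp/bench", "--num-workers", "4"])
main(args.name, number=args.number, temp_root=args.temp_root, num_workers=args.num_workers)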