Unverified Commit 09a9b6f7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

improve datasets benchmark (#4638)

parent 39d052c9
import argparse import argparse
import collections.abc
import contextlib import contextlib
import copy import copy
import inspect import inspect
...@@ -11,27 +12,41 @@ import sys ...@@ -11,27 +12,41 @@ import sys
import tempfile import tempfile
import time import time
import unittest.mock import unittest.mock
import warnings
import torch import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader_experimental import DataLoader2
from torchvision import datasets as legacy_datasets from torchvision import datasets as legacy_datasets
from torchvision.datasets.vision import StandardTransform
from torchvision.prototype import datasets as new_datasets from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import ToTensor from torchvision.transforms import PILToTensor
def main(name, *, number=5, temp_root=None, num_workers=0):
    """Run all benchmarks for the dataset `name` and print one result line each.

    Args:
        name: Name of the dataset; must match a `DatasetBenchmark` entry.
        number: Number of iterations per benchmark.
        temp_root: Directory in which temporary legacy roots are created.
        num_workers: Subprocesses used by the data loaders (0 = main process).

    Raises:
        ValueError: If no `DatasetBenchmark` is registered for `name`.
    """
    matches = [candidate for candidate in DATASET_BENCHMARKS if candidate.name == name]
    if not matches:
        raise ValueError(f"No DatasetBenchmark available for dataset '{name}'")
    benchmark = matches[0]

    print("legacy", "cold_start", Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=number))
    print("legacy", "warm_start", Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=number))
    print("legacy", "iter", Measurement.iterations_per_time(benchmark.legacy_iteration(temp_root, num_workers=num_workers), number=number))
    print("new", "cold_start", Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=number))
    print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter(num_workers=num_workers), number=number))
class DatasetBenchmark: class DatasetBenchmark:
...@@ -63,23 +78,29 @@ class DatasetBenchmark: ...@@ -63,23 +78,29 @@ class DatasetBenchmark:
self.prepare_legacy_root = prepare_legacy_root self.prepare_legacy_root = prepare_legacy_root
def new_dataset(self, *, num_workers=0):
    """Load the prototype dataset and wrap it in a ``DataLoader2``."""
    datapipe = new_datasets.load(self.name, **self.new_config)
    return DataLoader2(datapipe, num_workers=num_workers)
def new_cold_start(self, *, num_workers):
    """Return a benchmark fn that times construction plus the first sample."""

    def fn(timer):
        # Both dataset creation and pulling the first sample are timed.
        with timer:
            next(iter(self.new_dataset(num_workers=num_workers)))

    return fn
def new_iter(self, *, num_workers):
    """Return a benchmark fn that times one full pass and yields the sample count."""

    def fn(timer):
        # Construction happens outside the timer; only iteration is measured.
        dataset = self.new_dataset(num_workers=num_workers)
        with timer:
            num_samples = sum(1 for _ in dataset)
        return num_samples

    return fn
def suppress_output(self): def suppress_output(self):
@contextlib.contextmanager @contextlib.contextmanager
...@@ -90,12 +111,16 @@ class DatasetBenchmark: ...@@ -90,12 +111,16 @@ class DatasetBenchmark:
return context_manager() return context_manager()
def legacy_dataset(self, root, *, num_workers=0, download=None):
    """Instantiate the legacy dataset at `root` wrapped in a shuffling ``DataLoader``.

    Args:
        root: Root directory passed to the legacy dataset class.
        num_workers: Subprocesses for the ``DataLoader`` (0 = main process).
        download: Overrides the preset 'download' special option when not None.
    """
    special_options = dict(self.legacy_special_options)
    if download is not None and "download" in special_options:
        special_options["download"] = download

    with self.suppress_output():
        dataset = self.legacy_cls(str(root), **self.legacy_config, **special_options)
        return DataLoader(dataset, shuffle=True, num_workers=num_workers)
@contextlib.contextmanager @contextlib.contextmanager
def patch_download_and_integrity_checks(self): def patch_download_and_integrity_checks(self):
...@@ -127,43 +152,61 @@ class DatasetBenchmark: ...@@ -127,43 +152,61 @@ class DatasetBenchmark:
return file_names return file_names
@contextlib.contextmanager
def legacy_root(self, temp_root):
    """Yield a temporary legacy root populated with symlinks to the raw data.

    The directory is removed again when the context exits, even on error.
    """
    new_root = new_datasets.home() / self.name
    tmp_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))

    # If the symlink targets live on another storage device, differing I/O
    # characteristics would skew the legacy-vs-new comparison.
    if os.stat(new_root).st_dev != os.stat(tmp_root).st_dev:
        warnings.warn(
            "The temporary root directory for the legacy dataset was created on a different storage device than "
            "the raw data that is used by the new dataset. If the devices have different I/O stats, this will "
            "distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the "
            "temporary directories.",
            RuntimeWarning,
        )

    try:
        for file_name in self._find_resource_file_names():
            (tmp_root / file_name).symlink_to(new_root / file_name)

        if self.prepare_legacy_root:
            self.prepare_legacy_root(self, tmp_root)

        with self.patch_download_and_integrity_checks():
            yield tmp_root
    finally:
        shutil.rmtree(tmp_root)
def legacy_cold_start(self, temp_root, *, num_workers):
    """Return a benchmark fn timing legacy construction plus the first sample."""

    def fn(timer):
        with self.legacy_root(temp_root) as root:
            # Construction and the first sample are both inside the timer.
            with timer:
                next(iter(self.legacy_dataset(root, num_workers=num_workers)))

    return fn
def legacy_warm_start(self, temp_root, *, num_workers):
    """Return a benchmark fn timing construction after resources are prepared.

    A first, untimed instantiation performs download/extraction so only the
    warm construction plus first sample is measured.
    """

    def fn(timer):
        with self.legacy_root(temp_root) as root:
            self.legacy_dataset(root, num_workers=num_workers)  # untimed warm-up
            with timer:
                next(iter(self.legacy_dataset(root, num_workers=num_workers, download=False)))

    return fn
def legacy_iteration(self, temp_root, *, num_workers):
    """Return a benchmark fn timing a full pass over the legacy loader."""

    def fn(timer):
        with self.legacy_root(temp_root) as root:
            loader = self.legacy_dataset(root, num_workers=num_workers)
            with timer:
                for _batch in loader:
                    pass
            return len(loader)

    return fn
def _find_legacy_cls(self): def _find_legacy_cls(self):
legacy_clss = { legacy_clss = {
...@@ -203,11 +246,11 @@ class DatasetBenchmark: ...@@ -203,11 +246,11 @@ class DatasetBenchmark:
special_options["download"] = True special_options["download"] = True
if "transform" in available_special_kwargs: if "transform" in available_special_kwargs:
special_options["transform"] = ToTensor() special_options["transform"] = PILToTensor()
if "target_transform" in available_special_kwargs: if "target_transform" in available_special_kwargs:
special_options["target_transform"] = torch.tensor special_options["target_transform"] = torch.tensor
elif "transforms" in available_special_kwargs: elif "transforms" in available_special_kwargs:
special_options["transforms"] = StandardTransform(ToTensor(), ToTensor()) special_options["transforms"] = JointTransform(PILToTensor(), PILToTensor())
return special_options return special_options
...@@ -271,6 +314,21 @@ class Measurement: ...@@ -271,6 +314,21 @@ class Measurement:
return mean, std return mean, std
def no_split(config):
    """Return a copy of `config` without its 'split' entry.

    Raises:
        KeyError: If `config` has no 'split' entry.
    """
    legacy_config = dict(config)
    legacy_config.pop("split")
    return legacy_config
def bool_split(name="train"):
    """Build a config map replacing 'split' with a boolean entry called `name`.

    The boolean is True exactly when the split is "train".
    """

    def legacy_config_map(config):
        legacy_config = dict(config)
        split = legacy_config.pop("split")
        legacy_config[name] = split == "train"
        return legacy_config

    return legacy_config_map
def base_folder(rel_folder=None): def base_folder(rel_folder=None):
if rel_folder is None: if rel_folder is None:
...@@ -295,6 +353,29 @@ def base_folder(rel_folder=None): ...@@ -295,6 +353,29 @@ def base_folder(rel_folder=None):
return prepare_legacy_root return prepare_legacy_root
class JointTransform:
    """Apply one transform per input and return the results as a tuple.

    Can be called either with unpacked inputs, ``transform(img, target)``, or
    with a single packed sequence, ``transform((img, target))``.
    """

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, *inputs):
        # Unpack a single sequence argument so transform((a, b)) behaves like
        # transform(a, b). The check must be on inputs[0]: `inputs` itself is
        # always a tuple and thus always a Sequence, which previously caused a
        # lone non-sequence input to be unpacked by mistake.
        if len(inputs) == 1 and isinstance(inputs[0], collections.abc.Sequence):
            inputs = inputs[0]

        if len(inputs) != len(self.transforms):
            raise RuntimeError(
                f"The number of inputs and transforms mismatches: {len(inputs)} != {len(self.transforms)}."
            )

        return tuple(transform(input) for transform, input in zip(self.transforms, inputs))
def caltech101_legacy_config_map(config):
    """Map a new-style caltech101 config onto legacy keyword arguments."""
    # The new dataset always returns the category and annotation, so the
    # legacy dataset is configured to do the same.
    return {**no_split(config), "target_type": ("category", "annotation")}
mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw") mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw")
...@@ -323,8 +404,21 @@ def qmnist_legacy_config_map(config): ...@@ -323,8 +404,21 @@ def qmnist_legacy_config_map(config):
DATASET_BENCHMARKS = [ DATASET_BENCHMARKS = [
DatasetBenchmark("caltech101", prepare_legacy_root=base_folder()), DatasetBenchmark(
DatasetBenchmark("caltech256", prepare_legacy_root=base_folder()), "caltech101",
legacy_config_map=caltech101_legacy_config_map,
prepare_legacy_root=base_folder(),
legacy_special_options_map=lambda config: dict(
download=True,
transform=PILToTensor(),
target_transform=JointTransform(torch.tensor, torch.tensor),
),
),
DatasetBenchmark(
"caltech256",
legacy_config_map=no_split,
prepare_legacy_root=base_folder(),
),
DatasetBenchmark( DatasetBenchmark(
"celeba", "celeba",
prepare_legacy_root=base_folder(), prepare_legacy_root=base_folder(),
...@@ -336,11 +430,11 @@ DATASET_BENCHMARKS = [ ...@@ -336,11 +430,11 @@ DATASET_BENCHMARKS = [
), ),
DatasetBenchmark( DatasetBenchmark(
"cifar10", "cifar10",
legacy_config_map=lambda config: dict(train=config.split == "train"), legacy_config_map=bool_split(),
), ),
DatasetBenchmark( DatasetBenchmark(
"cifar100", "cifar100",
legacy_config_map=lambda config: dict(train=config.split == "train"), legacy_config_map=bool_split(),
), ),
DatasetBenchmark( DatasetBenchmark(
"emnist", "emnist",
...@@ -376,27 +470,56 @@ DATASET_BENCHMARKS = [ ...@@ -376,27 +470,56 @@ DATASET_BENCHMARKS = [
), ),
legacy_special_options_map=lambda config: dict( legacy_special_options_map=lambda config: dict(
download=True, download=True,
transforms=StandardTransform(ToTensor(), torch.tensor if config.boundaries else ToTensor()), transforms=JointTransform(PILToTensor(), torch.tensor if config.boundaries else PILToTensor()),
), ),
), ),
DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection), DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
] ]
def parse_args(argv=None):
    """Parse command line arguments for the benchmark utility.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with `name`, `number`, `temp_root`, `num_workers`.
    """
    parser = argparse.ArgumentParser(
        prog="torchvision.prototype.datasets.benchmark.py",
        description="Utility to benchmark new datasets against their legacy variants.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("name", help="Name of the dataset to benchmark.")
    parser.add_argument(
        "-n",
        "--number",
        type=int,
        default=5,
        help="Number of iterations of each benchmark.",
    )
    parser.add_argument(
        "-t",
        "--temp-root",
        type=pathlib.Path,
        help=(
            "Root of the temporary legacy root directories. Use this if your system default temporary directory is on "
            "another storage device as the raw data to avoid distortions due to differing I/O stats."
        ),
    )
    parser.add_argument(
        "-j",
        "--num-workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses used to load the data. Setting this to 0 will load all data in the main process "
            "and thus disable multi-processing."
        ),
    )

    # `argv or sys.argv[1:]` would silently fall back to sys.argv for an
    # explicitly passed empty list; compare against None instead.
    return parser.parse_args(argv if argv is not None else sys.argv[1:])
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
try: try:
main(args.name, number=args.number) main(args.name, number=args.number, temp_root=args.temp_root, num_workers=args.num_workers)
except Exception as error: except Exception as error:
msg = str(error) msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr) print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment