"src/vscode:/vscode.git/clone" did not exist on "e3dfaf82ad5101ae1b70dc5647d1165de0e41359"
Unverified commit 09a9b6f7, authored by Philip Meier, committed by GitHub

improve datasets benchmark (#4638)

parent 39d052c9
import argparse
import collections.abc
import contextlib
import copy
import inspect
@@ -11,27 +12,41 @@ import sys
import tempfile
import time
import unittest.mock
import warnings
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader_experimental import DataLoader2
from torchvision import datasets as legacy_datasets
from torchvision.datasets.vision import StandardTransform
from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import PILToTensor


def main(name, *, number=5, temp_root=None, num_workers=0):
    for benchmark in DATASET_BENCHMARKS:
        if benchmark.name == name:
            break
    else:
        raise ValueError(f"No DatasetBenchmark available for dataset '{name}'")
print("legacy", "cold_start", Measurement.time(benchmark.legacy_cold_start, number=number))
print("legacy", "warm_start", Measurement.time(benchmark.legacy_warm_start, number=number))
print("legacy", "iter", Measurement.iterations_per_time(benchmark.legacy_iteration, number=number))
print("new", "cold_start", Measurement.time(benchmark.new_cold_start, number=number))
print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter, number=number))
print(
"legacy",
"cold_start",
Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=number),
)
print(
"legacy",
"warm_start",
Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=number),
)
print(
"legacy",
"iter",
Measurement.iterations_per_time(benchmark.legacy_iteration(temp_root, num_workers=num_workers), number=number),
)
print("new", "cold_start", Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=number))
print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter(num_workers=num_workers), number=number))
@@ -63,23 +78,29 @@ class DatasetBenchmark:
        self.prepare_legacy_root = prepare_legacy_root

    def new_dataset(self, *, num_workers=0):
        return DataLoader2(new_datasets.load(self.name, **self.new_config), num_workers=num_workers)
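
    # Context note (an aside, not in the commit): DataLoader2 is the experimental
    # loader from torch.utils.data.dataloader_experimental that consumes the
    # datapipe-based prototype datasets; wrapping both the legacy and the new
    # pipeline in a loader keeps the num_workers code path comparable.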
    def new_cold_start(self, *, num_workers):
        def fn(timer):
            with timer:
                dataset = self.new_dataset(num_workers=num_workers)
                next(iter(dataset))

        return fn

    def new_iter(self, *, num_workers):
        def fn(timer):
            dataset = self.new_dataset(num_workers=num_workers)
            num_samples = 0

            with timer:
                for _ in dataset:
                    num_samples += 1

            return num_samples

        return fn
    def suppress_output(self):
        @contextlib.contextmanager
@@ -90,12 +111,16 @@
        return context_manager()
    def legacy_dataset(self, root, *, num_workers=0, download=None):
        special_options = self.legacy_special_options.copy()
        if "download" in special_options and download is not None:
            special_options["download"] = download

        with self.suppress_output():
            return DataLoader(
                self.legacy_cls(str(root), **self.legacy_config, **special_options),
                shuffle=True,
                num_workers=num_workers,
            )
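
    # Assumption, not stated in the diff: shuffle=True makes the legacy loader visit
    # samples in random order, so its iteration numbers are not flattered by purely
    # sequential reads relative to the shuffled datapipe pipeline.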

    @contextlib.contextmanager
    def patch_download_and_integrity_checks(self):
@@ -127,43 +152,61 @@
        return file_names

    @contextlib.contextmanager
    def legacy_root(self, temp_root):
        new_root = new_datasets.home() / self.name
        legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))

        if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev:
            warnings.warn(
                "The temporary root directory for the legacy dataset was created on a different storage device than "
                "the raw data that is used by the new dataset. If the devices have different I/O stats, this will "
                "distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the "
                "temporary directories.",
                RuntimeWarning,
            )

        try:
            for file_name in self._find_resource_file_names():
                (legacy_root / file_name).symlink_to(new_root / file_name)

            if self.prepare_legacy_root:
                self.prepare_legacy_root(self, legacy_root)

            with self.patch_download_and_integrity_checks():
                yield legacy_root
        finally:
            shutil.rmtree(legacy_root)
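
    # The resource files are symlinked rather than copied, so both pipelines read the
    # same bytes from the same device; the os.stat(...).st_dev comparison above warns
    # when the temporary directory nevertheless lands on a different device.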
    def legacy_cold_start(self, temp_root, *, num_workers):
        def fn(timer):
            with self.legacy_root(temp_root) as root:
                with timer:
                    dataset = self.legacy_dataset(root, num_workers=num_workers)
                    next(iter(dataset))

        return fn

    def legacy_warm_start(self, temp_root, *, num_workers):
        def fn(timer):
            with self.legacy_root(temp_root) as root:
                self.legacy_dataset(root, num_workers=num_workers)

                with timer:
                    dataset = self.legacy_dataset(root, num_workers=num_workers, download=False)
                    next(iter(dataset))

        return fn

    def legacy_iteration(self, temp_root, *, num_workers):
        def fn(timer):
            with self.legacy_root(temp_root) as root:
                dataset = self.legacy_dataset(root, num_workers=num_workers)
                with timer:
                    for _ in dataset:
                        pass

            return len(dataset)

        return fn
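
    # In short: legacy_cold_start times the first construction of the legacy dataset
    # (including any extraction or preprocessing), legacy_warm_start constructs it once
    # beforehand and times a second construction with download=False, and
    # legacy_iteration times a full pass and reports samples per unit of time.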
    def _find_legacy_cls(self):
        legacy_clss = {
@@ -203,11 +246,11 @@ class DatasetBenchmark:
special_options["download"] = True
if "transform" in available_special_kwargs:
special_options["transform"] = ToTensor()
special_options["transform"] = PILToTensor()
if "target_transform" in available_special_kwargs:
special_options["target_transform"] = torch.tensor
elif "transforms" in available_special_kwargs:
special_options["transforms"] = StandardTransform(ToTensor(), ToTensor())
special_options["transforms"] = JointTransform(PILToTensor(), PILToTensor())
return special_options
@@ -271,6 +314,21 @@ class Measurement:
        return mean, std


def no_split(config):
    legacy_config = dict(config)
    del legacy_config["split"]
    return legacy_config


def bool_split(name="train"):
    def legacy_config_map(config):
        legacy_config = dict(config)
        legacy_config[name] = legacy_config.pop("split") == "train"
        return legacy_config

    return legacy_config_map
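
# Illustrative example (hypothetical configs, not part of the commit): the new datasets
# are configured with a `split` string, while legacy classes either take no split at
# all or a boolean `train` flag:
#
#   no_split({"split": "train"})     -> {}
#   bool_split()({"split": "train"}) -> {"train": True}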

def base_folder(rel_folder=None):
    if rel_folder is None:
@@ -295,6 +353,29 @@ def base_folder(rel_folder=None):
    return prepare_legacy_root

class JointTransform:
    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, *inputs):
        if len(inputs) == 1 and isinstance(inputs[0], collections.abc.Sequence):
            inputs = inputs[0]

        if len(inputs) != len(self.transforms):
            raise RuntimeError(
                f"The number of inputs and transforms mismatches: {len(inputs)} != {len(self.transforms)}."
            )

        return tuple(transform(input) for transform, input in zip(self.transforms, inputs))
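
# Minimal usage sketch (hypothetical inputs, not part of the commit): JointTransform
# applies the i-th transform to the i-th input, matching legacy datasets that accept a
# single joint `transforms` callable for (image, target) pairs:
#
#   joint = JointTransform(PILToTensor(), torch.tensor)
#   image_tensor, target_tensor = joint(pil_image, 42)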

def caltech101_legacy_config_map(config):
    legacy_config = no_split(config)
    # The new dataset always returns the category and annotation
    legacy_config["target_type"] = ("category", "annotation")
    return legacy_config


mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw")
@@ -323,8 +404,21 @@ def qmnist_legacy_config_map(config):
DATASET_BENCHMARKS = [
    DatasetBenchmark(
        "caltech101",
        legacy_config_map=caltech101_legacy_config_map,
        prepare_legacy_root=base_folder(),
        legacy_special_options_map=lambda config: dict(
            download=True,
            transform=PILToTensor(),
            target_transform=JointTransform(torch.tensor, torch.tensor),
        ),
    ),
    DatasetBenchmark(
        "caltech256",
        legacy_config_map=no_split,
        prepare_legacy_root=base_folder(),
    ),
    DatasetBenchmark(
        "celeba",
        prepare_legacy_root=base_folder(),
@@ -336,11 +430,11 @@ DATASET_BENCHMARKS = [
    ),
    DatasetBenchmark(
        "cifar10",
        legacy_config_map=bool_split(),
    ),
    DatasetBenchmark(
        "cifar100",
        legacy_config_map=bool_split(),
    ),
    DatasetBenchmark(
        "emnist",
@@ -376,27 +470,56 @@ DATASET_BENCHMARKS = [
        ),
        legacy_special_options_map=lambda config: dict(
            download=True,
            transforms=JointTransform(PILToTensor(), torch.tensor if config.boundaries else PILToTensor()),
        ),
    ),
    DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
]


def parse_args(argv=None):
    parser = argparse.ArgumentParser(
        prog="torchvision.prototype.datasets.benchmark.py",
        description="Utility to benchmark new datasets against their legacy variants.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("name", help="Name of the dataset to benchmark.")
    parser.add_argument(
        "-n",
        "--number",
        type=int,
        default=5,
        help="Number of iterations of each benchmark.",
    )
    parser.add_argument(
        "-t",
        "--temp-root",
        type=pathlib.Path,
        help=(
            "Root of the temporary legacy root directories. Use this if your system default temporary directory is "
            "on a different storage device than the raw data, to avoid distortions due to differing I/O stats."
        ),
    )
    parser.add_argument(
        "-j",
        "--num-workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses used to load the data. Setting this to 0 will load all data in the main process "
            "and thus disable multi-processing."
        ),
    )

    return parser.parse_args(argv or sys.argv[1:])


if __name__ == "__main__":
    args = parse_args()

    try:
        main(args.name, number=args.number, temp_root=args.temp_root, num_workers=args.num_workers)
    except Exception as error:
        msg = str(error)
        print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)