Unverified Commit 39cf02a6 authored by Philip Meier, committed by GitHub

improve COCO prototype (#4650)

* improve COCO prototype

* test 2017 annotations

* add option to include captions

* fix categories and add tests

* cleanup

* add correct image size to bounding boxes

* fix annotation collation

* appease mypy

* add benchmark

* always use image as reference

* another refactor

* add support for segmentations

* add support for segmentations

* fix CI dependencies
parent 3d8723d5
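The user-visible change is a new `annotations` option on the prototype COCO dataset. A rough usage sketch, assuming the prototype `datasets.load(name, **config)` entry point of this torchvision revision:

```python
from torchvision.prototype import datasets

# annotations="instances" adds segmentations, areas, crowds, bounding_boxes,
# labels, and ann_ids to each sample; "captions" adds captions and ann_ids;
# annotations=None yields plain path/image samples.
dataset = datasets.load("coco", split="val", year="2017", annotations="instances")
sample = next(iter(dataset))
print(sorted(sample.keys()))
```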
@@ -351,7 +351,7 @@ jobs:
      - install_torchvision
      - install_prototype_dependencies
      - pip_install:
-          args: scipy
+          args: scipy pycocotools
          descr: Install optional dependencies
      - run:
          name: Enable prototype tests
......
import functools
import gzip
+import json
import lzma
import pathlib
import pickle
@@ -8,6 +9,7 @@ from collections import defaultdict
from typing import Any, Dict, Tuple

import numpy as np
+import PIL.Image
import pytest
import torch
from datasets_utils import create_image_folder, make_tar, make_zip
@@ -18,7 +20,9 @@ from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DEC
from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import add_suggestion

make_tensor = functools.partial(_make_tensor, device="cpu")
+make_scalar = functools.partial(make_tensor, ())

__all__ = ["load"]
@@ -490,3 +494,113 @@ def imagenet(info, root, config):
    make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")

    return num_samples
+
+
+class CocoMockData:
+    @classmethod
+    def _make_images_archive(cls, root, name, *, num_samples):
+        image_paths = create_image_folder(
+            root, name, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_samples
+        )
+
+        images_meta = []
+        for path in image_paths:
+            with PIL.Image.open(path) as image:
+                width, height = image.size
+            images_meta.append(dict(file_name=path.name, id=int(path.stem), width=width, height=height))
+
+        make_zip(root, f"{name}.zip")
+
+        return images_meta
+
+    @classmethod
+    def _make_annotations_json(
+        cls,
+        root,
+        name,
+        *,
+        images_meta,
+        fn,
+    ):
+        num_anns_per_image = torch.randint(1, 5, (len(images_meta),))
+        num_anns_total = int(num_anns_per_image.sum())
+        ann_ids_iter = iter(torch.arange(num_anns_total)[torch.randperm(num_anns_total)])
+
+        anns_meta = []
+        for image_meta, num_anns in zip(images_meta, num_anns_per_image):
+            for _ in range(num_anns):
+                ann_id = int(next(ann_ids_iter))
+                anns_meta.append(dict(fn(ann_id, image_meta), id=ann_id, image_id=image_meta["id"]))
+        anns_meta.sort(key=lambda ann: ann["id"])
+
+        with open(root / name, "w") as file:
+            json.dump(dict(images=images_meta, annotations=anns_meta), file)
+
+        return num_anns_per_image
+
+    @staticmethod
+    def _make_instances_data(ann_id, image_meta):
+        def make_rle_segmentation():
+            height, width = image_meta["height"], image_meta["width"]
+            numel = height * width
+            counts = []
+            while sum(counts) <= numel:
+                counts.append(int(torch.randint(5, 8, ())))
+            if sum(counts) > numel:
+                counts[-1] -= sum(counts) - numel
+            return dict(counts=counts, size=[height, width])
+
+        return dict(
+            segmentation=make_rle_segmentation(),
+            bbox=make_tensor((4,), dtype=torch.float32, low=0).tolist(),
+            iscrowd=True,
+            area=float(make_scalar(dtype=torch.float32)),
+            category_id=int(make_scalar(dtype=torch.int64)),
+        )
+
+    @staticmethod
+    def _make_captions_data(ann_id, image_meta):
+        return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.")
+
+    @classmethod
+    def _make_annotations(cls, root, name, *, images_meta):
+        num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64)
+        for annotations, fn in (
+            ("instances", cls._make_instances_data),
+            ("captions", cls._make_captions_data),
+        ):
+            num_anns_per_image += cls._make_annotations_json(
+                root, f"{annotations}_{name}.json", images_meta=images_meta, fn=fn
+            )
+
+        return int(num_anns_per_image.sum())
+
+    @classmethod
+    def generate(
+        cls,
+        root,
+        *,
+        year,
+        num_samples,
+    ):
+        annotations_dir = root / "annotations"
+        annotations_dir.mkdir()
+
+        for split in ("train", "val"):
+            config_name = f"{split}{year}"
+
+            images_meta = cls._make_images_archive(root, config_name, num_samples=num_samples)
+            cls._make_annotations(
+                annotations_dir,
+                config_name,
+                images_meta=images_meta,
+            )
+
+        make_zip(root, f"annotations_trainval{year}.zip", annotations_dir)
+
+        return num_samples
+
+
+@dataset_mocks.register_mock_data_fn
+def coco(info, root, config):
+    return CocoMockData.generate(root, year=config.year, num_samples=5)
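For reference, a sketch of what one of the generated `instances_*.json` files looks like (all values are randomized by the mock; the RLE counts are runs of 5-7 that get clipped so they sum to `width * height`):

```python
# Illustrative output of CocoMockData._make_annotations_json for a single
# 4x4 image; not real COCO data.
mock_instances_json = {
    "images": [{"file_name": "000000000000.jpg", "id": 0, "width": 4, "height": 4}],
    "annotations": [
        {
            "segmentation": {"counts": [7, 6, 3], "size": [4, 4]},  # uncompressed RLE, sums to 16
            "bbox": [0.4, 1.2, 2.5, 3.1],  # xywh
            "iscrowd": True,
            "area": 7.75,
            "category_id": 17,
            "id": 0,
            "image_id": 0,
        }
    ],
}
```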
@@ -866,6 +866,13 @@ def _split_files_or_dirs(root, *files_or_dirs):

def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
    archive = pathlib.Path(root) / name
+    if not files_or_dirs:
+        dir = archive.parent / archive.name.replace("".join(archive.suffixes), "")
+        if dir.exists() and dir.is_dir():
+            files_or_dirs = (dir,)
+        else:
+            raise ValueError("No file or dir provided.")
+
    files, dirs = _split_files_or_dirs(root, *files_or_dirs)

    with opener(archive) as fh:
......
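The new fallback makes the archive helpers usable without listing their contents: when no files or directories are passed, the directory to archive is inferred by stripping the suffixes from the archive name. A sketch, assuming `make_zip(root, name, *files_or_dirs)` forwards to `_make_archive` like the existing helpers do:

```python
import pathlib

from datasets_utils import make_zip  # torchvision test helper

root = pathlib.Path("/tmp/coco_mock")  # hypothetical mock root with root/"train2017" populated

# New: with no files_or_dirs, _make_archive archives the existing directory
# root / "train2017" derived from the archive name "train2017.zip".
make_zip(root, "train2017.zip")

# Passing the directory explicitly still works as before.
make_zip(root, "annotations_trainval2017.zip", root / "annotations")
```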
@@ -13,6 +13,17 @@ def to_bytes(file):
    return file.read()


+def config_id(name, config):
+    parts = [name]
+    for name, value in config.items():
+        if isinstance(value, bool):
+            part = ("" if value else "no_") + name
+        else:
+            part = str(value)
+        parts.append(part)
+    return "-".join(parts)
+
+
def dataset_parametrization(*names, decoder=to_bytes):
    if not names:
        # TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
@@ -27,16 +38,17 @@ def dataset_parametrization(*names, decoder=to_bytes):
            "caltech256",
            "caltech101",
            "imagenet",
+            "coco",
        )

-    params = []
-    for name in names:
-        for config in datasets.info(name)._configs:
-            id = f"{name}-{'-'.join([str(value) for value in config.values()])}"
-            dataset, mock_info = builtin_dataset_mocks.load(name, decoder=decoder, **config)
-            params.append(pytest.param(dataset, mock_info, id=id))
-
-    return pytest.mark.parametrize(("dataset", "mock_info"), params)
+    return pytest.mark.parametrize(
+        ("dataset", "mock_info"),
+        [
+            pytest.param(*builtin_dataset_mocks.load(name, decoder=decoder, **config), id=config_id(name, config))
+            for name in names
+            for config in datasets.info(name)._configs
+        ],
+    )


class TestCommon:
......
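Worked examples of the ids `config_id` produces: booleans turn into `name`/`no_name` parts, everything else is stringified, which keeps the ids stable even when a config value is `None`:

```python
# Given the config_id helper above:
assert config_id("coco", dict(split="train", year="2017", annotations="instances")) == "coco-train-2017-instances"
assert config_id("coco", dict(split="train", year="2017", annotations=None)) == "coco-train-2017-None"
assert config_id("sbd", dict(split="train", boundaries=False)) == "sbd-train-no_boundaries"
```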
__background__,N/A
person,person
bicycle,vehicle
car,vehicle
motorcycle,vehicle
airplane,vehicle
bus,vehicle
train,vehicle
truck,vehicle
boat,vehicle
traffic light,outdoor
fire hydrant,outdoor
N/A,N/A
stop sign,outdoor
parking meter,outdoor
bench,outdoor
bird,animal
cat,animal
dog,animal
horse,animal
sheep,animal
cow,animal
elephant,animal
bear,animal
zebra,animal
giraffe,animal
N/A,N/A
backpack,accessory
umbrella,accessory
N/A,N/A
N/A,N/A
handbag,accessory
tie,accessory
suitcase,accessory
frisbee,sports
skis,sports
snowboard,sports
sports ball,sports
kite,sports
baseball bat,sports
baseball glove,sports
skateboard,sports
surfboard,sports
tennis racket,sports
bottle,kitchen
N/A,N/A
wine glass,kitchen
cup,kitchen
fork,kitchen
knife,kitchen
spoon,kitchen
bowl,kitchen
banana,food
apple,food
sandwich,food
orange,food
broccoli,food
carrot,food
hot dog,food
pizza,food
donut,food
cake,food
chair,furniture
couch,furniture
potted plant,furniture
bed,furniture
N/A,N/A
dining table,furniture
N/A,N/A
N/A,N/A
toilet,furniture
N/A,N/A
tv,electronic
laptop,electronic
mouse,electronic
remote,electronic
keyboard,electronic
cell phone,electronic
microwave,appliance
oven,appliance
toaster,appliance
sink,appliance
refrigerator,appliance
N/A,N/A
book,indoor
clock,indoor
vase,indoor
scissors,indoor
teddy bear,indoor
hair drier,indoor
toothbrush,indoor
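In this file the first column is the category and the second its super category; the row index is the label, so `__background__` maps to 0 and the `N/A` rows keep the ids of the 80 instance categories dense. A small sketch of how it can be read back (the path is assumed, not shown in the diff):

```python
import csv
import pathlib

path = pathlib.Path("torchvision/prototype/datasets/_builtin/coco.categories")  # assumed location
with open(path, newline="") as file:
    categories, super_categories = zip(*csv.reader(file))

assert categories[0] == "__background__"
assert categories[3] == "car" and super_categories[3] == "vehicle"  # label 3 -> car / vehicle
```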
import io
import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+import re
+from collections import OrderedDict
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import torch
from torchdata.datapipes.iter import (
@@ -26,24 +28,44 @@ from torchvision.prototype.datasets.utils import (
from torchvision.prototype.datasets.utils._internal import (
    MappingIterator,
    INFINITE_BUFFER_SIZE,
+    BUILTIN_DIR,
    getitem,
    path_accessor,
-    path_comparator,
)
+from torchvision.prototype.features import BoundingBox, Label
+from torchvision.prototype.features._feature import DEFAULT
+from torchvision.prototype.utils._internal import FrozenMapping

-HERE = pathlib.Path(__file__).parent

+class CocoLabel(Label):
+    super_category: Optional[str]
+
+    @classmethod
+    def _parse_meta_data(
+        cls,
+        category: Optional[str] = DEFAULT,  # type: ignore[assignment]
+        super_category: Optional[str] = DEFAULT,  # type: ignore[assignment]
+    ) -> Dict[str, Tuple[Any, Any]]:
+        return dict(category=(category, None), super_category=(super_category, None))
+
+
class Coco(Dataset):
    def _make_info(self) -> DatasetInfo:
+        name = "coco"
+        categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
+
        return DatasetInfo(
-            "coco",
+            name,
            type=DatasetType.IMAGE,
+            dependencies=("pycocotools",),
+            categories=categories,
            homepage="https://cocodataset.org/",
            valid_options=dict(
                split=("train", "val"),
                year=("2017", "2014"),
+                annotations=(*self._ANN_DECODERS.keys(), None),
            ),
+            extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))),
        )

    _IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
@@ -73,6 +95,62 @@ class Coco(Dataset):
        )
        return [images, meta]

+    def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size: Tuple[int, int]) -> torch.Tensor:
+        from pycocotools import mask
+
+        if is_crowd:
+            segmentation = mask.frPyObjects(segmentation, *image_size)
+        else:
+            segmentation = mask.merge(mask.frPyObjects(segmentation, *image_size))
+
+        return torch.from_numpy(mask.decode(segmentation)).to(torch.bool)
+
+    def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
+        image_size = (image_meta["height"], image_meta["width"])
+        labels = [ann["category_id"] for ann in anns]
+        categories = [self.info.categories[label] for label in labels]
+        return dict(
+            # TODO: create a segmentation feature
+            segmentations=torch.stack(
+                [
+                    self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
+                    for ann in anns
+                ]
+            ),
+            areas=torch.tensor([ann["area"] for ann in anns]),
+            crowds=torch.tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            bounding_boxes=BoundingBox(
+                [ann["bbox"] for ann in anns],
+                format="xywh",
+                image_size=image_size,
+            ),
+            labels=[
+                CocoLabel(
+                    label,
+                    category=category,
+                    super_category=self.info.extra.category_to_super_category[category],
+                )
+                for label, category in zip(labels, categories)
+            ],
+            ann_ids=[ann["id"] for ann in anns],
+        )
+
+    def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
+        return dict(
+            captions=[ann["caption"] for ann in anns],
+            ann_ids=[ann["id"] for ann in anns],
+        )
+
+    _ANN_DECODERS = OrderedDict([("instances", _decode_instances_anns), ("captions", _decode_captions_ann)])
+
+    _META_FILE_PATTERN = re.compile(
+        fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
+    )
+
+    def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool:
+        match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
+        return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations)
+
    def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
        key, _ = data
        if key == "images":
@@ -82,28 +160,27 @@ class Coco(Dataset):
        else:
            return None

-    def _decode_ann(self, ann: Dict[str, Any]) -> Dict[str, Any]:
-        area = torch.tensor(ann["area"])
-        iscrowd = bool(ann["iscrowd"])
-        bbox = torch.tensor(ann["bbox"])
-        id = ann["id"]
-        return dict(area=area, iscrowd=iscrowd, bbox=bbox, id=id)
+    def _collate_and_decode_image(
+        self, data: Tuple[str, io.IOBase], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
+    ) -> Dict[str, Any]:
+        path, buffer = data
+        return dict(path=path, image=decoder(buffer) if decoder else buffer)

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, io.IOBase]],
        *,
+        annotations: Optional[str],
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        ann_data, image_data = data
        anns, image_meta = ann_data
-        path, buffer = image_data

-        anns = [self._decode_ann(ann) for ann in anns]
-        image = decoder(buffer) if decoder else buffer
+        sample = self._collate_and_decode_image(image_data, decoder=decoder)
+        if annotations:
+            sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))

-        return dict(anns=anns, id=image_meta["id"], path=path, image=image)
+        return sample

    def _make_datapipe(
        self,
@@ -114,8 +191,18 @@ class Coco(Dataset):
    ) -> IterDataPipe[Dict[str, Any]]:
        images_dp, meta_dp = resource_dps

+        images_dp = ZipArchiveReader(images_dp)
+
+        if config.annotations is None:
+            dp = Shuffler(images_dp)
+            return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
+
        meta_dp = ZipArchiveReader(meta_dp)
-        meta_dp = Filter(meta_dp, path_comparator("name", f"instances_{config.split}{config.year}.json"))
+        meta_dp = Filter(
+            meta_dp,
+            self._filter_meta_files,
+            fn_kwargs=dict(split=config.split, year=config.year, annotations=config.annotations),
+        )
        meta_dp = JsonParser(meta_dp)
        meta_dp = Mapper(meta_dp, getitem(1))
        meta_dp = MappingIterator(meta_dp)
@@ -129,24 +216,20 @@ class Coco(Dataset):
        images_meta_dp = Mapper(images_meta_dp, getitem(1))
        images_meta_dp = UnBatcher(images_meta_dp)
+        images_meta_dp = Shuffler(images_meta_dp)

        anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
        anns_meta_dp = UnBatcher(anns_meta_dp)
+        anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)

-        anns_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
-        # drop images without annotations
-        anns_dp = Filter(anns_dp, bool)
-        anns_dp = Shuffler(anns_dp, buffer_size=INFINITE_BUFFER_SIZE)
        anns_dp = IterKeyZipper(
-            anns_dp,
+            anns_meta_dp,
            images_meta_dp,
            key_fn=getitem(0, "image_id"),
            ref_key_fn=getitem("id"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )

-        images_dp = ZipArchiveReader(images_dp)
        dp = IterKeyZipper(
            anns_dp,
            images_dp,
@@ -154,4 +237,35 @@ class Coco(Dataset):
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(
+            dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder)
+        )
+
+    def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
+        config = self.default_config
+        resources = self.resources(config)
+
+        dp = resources[1].to_datapipe(pathlib.Path(root) / self.name)
+        dp = ZipArchiveReader(dp)
+        dp = Filter(
+            dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances")
+        )
+        dp = JsonParser(dp)
+
+        _, meta = next(iter(dp))
+        # List[Tuple[super_category, id, category]]
+        label_data = [cast(Tuple[str, int, str], tuple(info.values())) for info in meta["categories"]]
+
+        # COCO actually defines 91 categories, but only 80 of them have instances. Still, the category_id refers to
+        # the full set. To keep the labels dense, we fill the gaps with N/A. Note that there are only 10 gaps, so the
+        # total number of categories is 90 rather than 91.
+        _, ids, _ = zip(*label_data)
+        missing_ids = set(range(1, max(ids) + 1)) - set(ids)
+        label_data.extend([("N/A", id, "N/A") for id in missing_ids])
+
+        # We also add a background category to be used during segmentation.
+        label_data.append(("N/A", 0, "__background__"))
+
+        super_categories, _, categories = zip(*sorted(label_data, key=lambda info: info[1]))
+
+        return cast(Tuple[Tuple[str, str]], tuple(zip(categories, super_categories)))
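`_segmentation_to_mask` leans on the `pycocotools.mask` API: crowd annotations store a single uncompressed RLE dict that `frPyObjects` compresses directly, while non-crowd annotations are polygon lists that come back as one RLE per polygon and therefore need `mask.merge`. A minimal standalone sketch of the crowd path (the values match the mock data above):

```python
import numpy as np
from pycocotools import mask

height, width = 4, 4
# Uncompressed RLE: counts alternate runs of 0s and 1s in column-major order
# and must sum to height * width.
rle = mask.frPyObjects({"counts": [7, 6, 3], "size": [height, width]}, height, width)
m = mask.decode(rle)
assert m.shape == (height, width) and m.dtype == np.uint8
```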
@@ -3,7 +3,6 @@
import argparse
import collections.abc
import contextlib
-import copy
import inspect
import itertools
import os
@@ -20,6 +19,7 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataloader_experimental import DataLoader2
from torchvision import datasets as legacy_datasets
+from torchvision.datasets.utils import extract_archive
from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import PILToTensor
@@ -27,6 +27,7 @@ from torchvision.transforms import PILToTensor

def main(
    name,
    *,
+    variant=None,
    legacy=True,
    new=True,
    start=True,
@@ -36,46 +37,57 @@ def main(
    temp_root=None,
    num_workers=0,
):
-    for benchmark in DATASET_BENCHMARKS:
-        if benchmark.name == name:
-            break
-    else:
-        raise ValueError(f"No DatasetBenchmark available for dataset '{name}'")
-
-    if legacy and start:
-        print(
-            "legacy",
-            "cold_start",
-            Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts),
-        )
-        print(
-            "legacy",
-            "warm_start",
-            Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts),
-        )
-
-    if legacy and iteration:
-        print(
-            "legacy",
-            "iteration",
-            Measurement.iterations_per_time(
-                benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples)
-            ),
-        )
-
-    if new and start:
-        print(
-            "new",
-            "cold_start",
-            Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts),
-        )
-
-    if new and iteration:
-        print(
-            "new",
-            "iteration",
-            Measurement.iterations_per_time(benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)),
-        )
+    benchmarks = [
+        benchmark
+        for benchmark in DATASET_BENCHMARKS
+        if benchmark.name == name and (variant is None or benchmark.variant == variant)
+    ]
+    if not benchmarks:
+        msg = f"No DatasetBenchmark available for dataset '{name}'"
+        if variant is not None:
+            msg += f" and variant '{variant}'"
+        raise ValueError(msg)
+
+    for benchmark in benchmarks:
+        print("#" * 80)
+        print(f"{benchmark.name}" + (f" ({benchmark.variant})" if benchmark.variant is not None else ""))
+
+        if legacy and start:
+            print(
+                "legacy",
+                "cold_start",
+                Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts),
+            )
+            print(
+                "legacy",
+                "warm_start",
+                Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts),
+            )
+
+        if legacy and iteration:
+            print(
+                "legacy",
+                "iteration",
+                Measurement.iterations_per_time(
+                    benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples)
+                ),
+            )
+
+        if new and start:
+            print(
+                "new",
+                "cold_start",
+                Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts),
+            )
+
+        if new and iteration:
+            print(
+                "new",
+                "iteration",
+                Measurement.iterations_per_time(
+                    benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)
+                ),
+            )

class DatasetBenchmark:
@@ -83,6 +95,7 @@ class DatasetBenchmark:
        self,
        name: str,
        *,
+        variant=None,
        legacy_cls=None,
        new_config=None,
        legacy_config_map=None,
@@ -90,6 +103,7 @@ class DatasetBenchmark:
        prepare_legacy_root=None,
    ):
        self.name = name
+        self.variant = variant
        self.new_raw_dataset = new_datasets._api.find(name)
        self.legacy_cls = legacy_cls or self._find_legacy_cls()
@@ -97,14 +111,11 @@ class DatasetBenchmark:
        if new_config is None:
            new_config = self.new_raw_dataset.default_config
        elif isinstance(new_config, dict):
-            new_config = new_datasets.utils.DatasetConfig(new_config)
+            new_config = self.new_raw_dataset.info.make_config(**new_config)
        self.new_config = new_config

-        self.legacy_config = (legacy_config_map or dict)(copy.copy(new_config))
-        self.legacy_special_options = (legacy_special_options_map or self._legacy_special_options_map)(
-            copy.copy(new_config)
-        )
+        self.legacy_config_map = legacy_config_map
+        self.legacy_special_options_map = legacy_special_options_map or self._legacy_special_options_map
        self.prepare_legacy_root = prepare_legacy_root

    def new_dataset(self, *, num_workers=0):
@@ -142,12 +153,15 @@ class DatasetBenchmark:
        return context_manager()

    def legacy_dataset(self, root, *, num_workers=0, download=None):
-        special_options = self.legacy_special_options.copy()
+        legacy_config = self.legacy_config_map(self, root) if self.legacy_config_map else dict()
+
+        special_options = self.legacy_special_options_map(self)
        if "download" in special_options and download is not None:
            special_options["download"] = download
+
        with self.suppress_output():
            return DataLoader(
-                self.legacy_cls(str(root), **self.legacy_config, **special_options),
+                self.legacy_cls(legacy_config.pop("root", str(root)), **legacy_config, **special_options),
                shuffle=True,
                num_workers=num_workers,
            )
@@ -260,16 +274,17 @@ class DatasetBenchmark:
            "download",
        }

-    def _legacy_special_options_map(self, config):
+    @staticmethod
+    def _legacy_special_options_map(benchmark):
        available_parameters = set()

-        for cls in self.legacy_cls.__mro__:
+        for cls in benchmark.legacy_cls.__mro__:
            if cls is legacy_datasets.VisionDataset:
                break

            available_parameters.update(inspect.signature(cls.__init__).parameters)

-        available_special_kwargs = self._SPECIAL_KWARGS.intersection(available_parameters)
+        available_special_kwargs = benchmark._SPECIAL_KWARGS.intersection(available_parameters)

        special_options = dict()
@@ -345,15 +360,15 @@ class Measurement:
        return mean, std

-def no_split(config):
-    legacy_config = dict(config)
+def no_split(benchmark, root):
+    legacy_config = dict(benchmark.new_config)
    del legacy_config["split"]
    return legacy_config

def bool_split(name="train"):
-    def legacy_config_map(config):
-        legacy_config = dict(config)
+    def legacy_config_map(benchmark, root):
+        legacy_config = dict(benchmark.new_config)
        legacy_config[name] = legacy_config.pop("split") == "train"
        return legacy_config
@@ -400,8 +415,8 @@ class JointTransform:
        return tuple(transform(input) for transform, input in zip(self.transforms, inputs))

-def caltech101_legacy_config_map(config):
-    legacy_config = no_split(config)
+def caltech101_legacy_config_map(benchmark, root):
+    legacy_config = no_split(benchmark, root)
    # The new dataset always returns the category and annotation
    legacy_config["target_type"] = ("category", "annotation")
    return legacy_config
@@ -410,8 +425,8 @@ def caltech101_legacy_config_map(config):

mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw")

-def mnist_legacy_config_map(config):
-    return dict(train=config.split == "train")
+def mnist_legacy_config_map(benchmark, root):
+    return dict(train=benchmark.new_config.split == "train")

def emnist_prepare_legacy_root(benchmark, root):
@@ -420,20 +435,36 @@ def emnist_prepare_legacy_root(benchmark, root):
    return folder

-def emnist_legacy_config_map(config):
-    legacy_config = mnist_legacy_config_map(config)
-    legacy_config["split"] = config.image_set.replace("_", "").lower()
+def emnist_legacy_config_map(benchmark, root):
+    legacy_config = mnist_legacy_config_map(benchmark, root)
+    legacy_config["split"] = benchmark.new_config.image_set.replace("_", "").lower()
    return legacy_config

-def qmnist_legacy_config_map(config):
-    legacy_config = mnist_legacy_config_map(config)
-    legacy_config["what"] = config.split
+def qmnist_legacy_config_map(benchmark, root):
+    legacy_config = mnist_legacy_config_map(benchmark, root)
+    legacy_config["what"] = benchmark.new_config.split
    # The new dataset always returns the full label
    legacy_config["compat"] = False
    return legacy_config

+def coco_legacy_config_map(benchmark, root):
+    images, _ = benchmark.new_raw_dataset.resources(benchmark.new_config)
+    return dict(
+        root=str(root / pathlib.Path(images.file_name).stem),
+        annFile=str(
+            root / "annotations" / f"{benchmark.variant}_{benchmark.new_config.split}{benchmark.new_config.year}.json"
+        ),
+    )
+
+
+def coco_prepare_legacy_root(benchmark, root):
+    images, annotations = benchmark.new_raw_dataset.resources(benchmark.new_config)
+    extract_archive(str(root / images.file_name))
+    extract_archive(str(root / annotations.file_name))
+
+
DATASET_BENCHMARKS = [
    DatasetBenchmark(
        "caltech101",
@@ -453,8 +484,8 @@ DATASET_BENCHMARKS = [
    DatasetBenchmark(
        "celeba",
        prepare_legacy_root=base_folder(),
-        legacy_config_map=lambda config: dict(
-            split="valid" if config.split == "val" else config.split,
+        legacy_config_map=lambda benchmark: dict(
+            split="valid" if benchmark.new_config.split == "val" else benchmark.new_config.split,
            # The new dataset always returns all annotations
            target_type=("attr", "identity", "bbox", "landmarks"),
        ),
@@ -495,17 +526,37 @@ DATASET_BENCHMARKS = [
    DatasetBenchmark(
        "sbd",
        legacy_cls=legacy_datasets.SBDataset,
-        legacy_config_map=lambda config: dict(
-            image_set=config.split,
-            mode="boundaries" if config.boundaries else "segmentation",
+        legacy_config_map=lambda benchmark: dict(
+            image_set=benchmark.new_config.split,
+            mode="boundaries" if benchmark.new_config.boundaries else "segmentation",
        ),
-        legacy_special_options_map=lambda config: dict(
+        legacy_special_options_map=lambda benchmark: dict(
            download=True,
-            transforms=JointTransform(PILToTensor(), torch.tensor if config.boundaries else PILToTensor()),
+            transforms=JointTransform(
+                PILToTensor(), torch.tensor if benchmark.new_config.boundaries else PILToTensor()
+            ),
        ),
    ),
    DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
    DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet),
+    DatasetBenchmark(
+        "coco",
+        variant="instances",
+        legacy_cls=legacy_datasets.CocoDetection,
+        new_config=dict(split="train", annotations="instances"),
+        legacy_config_map=coco_legacy_config_map,
+        prepare_legacy_root=coco_prepare_legacy_root,
+        legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
+    ),
+    DatasetBenchmark(
+        "coco",
+        variant="captions",
+        legacy_cls=legacy_datasets.CocoCaptions,
+        new_config=dict(split="train", annotations="captions"),
+        legacy_config_map=coco_legacy_config_map,
+        prepare_legacy_root=coco_prepare_legacy_root,
+        legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
+    ),
]
@@ -517,6 +568,9 @@ def parse_args(argv=None):
    )

    parser.add_argument("name", help="Name of the dataset to benchmark.")
+    parser.add_argument(
+        "--variant", help="Variant of the dataset. If omitted all available variants will be benchmarked."
+    )

    parser.add_argument(
        "-n",
@@ -591,6 +645,7 @@ if __name__ == "__main__":
    try:
        main(
            args.name,
+            variant=args.variant,
            legacy=args.legacy,
            new=args.new,
            start=args.start,
......
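With the `variant` field, `coco` is registered twice and `main` filters on both name and variant; omitting the variant now benchmarks every registered variant in turn. A sketch of the equivalent programmatic calls (the benchmark script's CLI name is not shown in this diff):

```python
# Equivalent to passing --variant on the command line:
main("coco", variant="instances")  # only the CocoDetection comparison
main("coco", variant="captions")   # only the CocoCaptions comparison
main("coco")                       # both variants, each preceded by a "#" banner
```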
@@ -118,7 +118,7 @@ class BoundingBox(Feature):
        if data.dtype.is_floating_point:
            w = w.ceil()
            h = h.ceil()

-        return int(h), int(w)
+        return int(h.max()), int(w.max())

    @classmethod
    def from_parts(
......
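The one-line `BoundingBox` fix matters because the new COCO samples store a whole batch of boxes in one feature: `int(h)` only works for a one-element tensor, while the image size implied by a batch has to cover every box, hence the `max()`. A worked example, assuming `h` and `w` are derived from the xywh parts as `y + h` and `x + w` (the surrounding lines are not shown in the diff):

```python
import torch

# Two xywh boxes inside the same image.
xywh = torch.tensor([[0.0, 0.0, 4.0, 2.0], [3.0, 5.0, 4.0, 4.0]])
x, y, w, h = xywh.unbind(-1)
h, w = (y + h).ceil(), (x + w).ceil()

# int(h) would raise for this 2-element tensor; taking the max yields an
# image size that covers both boxes.
assert (int(h.max()), int(w.max())) == (9, 7)
```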