Unverified Commit 58f313bb authored by Philip Meier, committed by GitHub

add prototype imagenet dataset (#4640)

* add prototype imagenet dataset

* add missing checksums

* fix mypy

* add human readable categories

* cleanup

* sort categories ascending based on wnid

* remove accidentally added file

* cleanup category file generation

* fix mypy
parent 588c698d
@@ -2,6 +2,7 @@ from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
 from .cifar import Cifar10, Cifar100
 from .coco import Coco
+from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .sbd import SBD
 from .voc import VOC
New file: imagenet.py (the prototype ImageNet dataset):
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
path_comparator,
Enumerator,
getitem,
read_mat,
FrozenMapping,
)


class ImageNet(Dataset):
    @property
    def info(self) -> DatasetInfo:
        name = "imagenet"
        categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))

        return DatasetInfo(
            name,
            type=DatasetType.IMAGE,
            categories=categories,
            homepage="https://www.image-net.org/",
            valid_options=dict(split=("train", "val")),
            extra=dict(
                wnid_to_category=FrozenMapping(zip(wnids, categories)),
                category_to_wnid=FrozenMapping(zip(categories, wnids)),
            ),
        )

    @property
    def category_to_wnid(self) -> Dict[str, str]:
        return self.info.extra.category_to_wnid

    @property
    def wnid_to_category(self) -> Dict[str, str]:
        return self.info.extra.wnid_to_category

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        if config.split == "train":
            images = HttpResource(
                "ILSVRC2012_img_train.tar",
                sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
            )
        else:  # config.split == "val"
            images = HttpResource(
                "ILSVRC2012_img_val.tar",
                sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
            )

        devkit = HttpResource(
            "ILSVRC2012_devkit_t12.tar.gz",
            sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
        )

        return [images, devkit]

    _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
        path = pathlib.Path(data[0])
        wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
        category = self.wnid_to_category[wnid]
        label = self.categories.index(category)
        return (label, category, wnid), data

    _VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG")

    def _val_image_key(self, data: Tuple[str, Any]) -> int:
        path = pathlib.Path(data[0])
        return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id"))  # type: ignore[union-attr]

    def _collate_val_data(
        self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
    ) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
        label_data, image_data = data
        _, label = label_data
        category = self.categories[label]
        wnid = self.category_to_wnid[category]
        return (label, category, wnid), image_data

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        ann_data, image_data = data
        label, category, wnid = ann_data
        path, buffer = image_data

        return dict(
            path=path,
            image=decoder(buffer) if decoder else buffer,
            label=torch.tensor(label),
            category=category,
            wnid=wnid,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        images_dp, devkit_dp = resource_dps

        images_dp = TarArchiveReader(images_dp)

        if config.split == "train":
            # the train archive is a tar of tars
            dp = TarArchiveReader(images_dp)
            # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
            dp = Mapper(dp, self._collate_train_data)
        else:
            devkit_dp = TarArchiveReader(devkit_dp)
            devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
            devkit_dp = LineReader(devkit_dp, return_path=False)
            devkit_dp = Mapper(devkit_dp, int)
            devkit_dp = Enumerator(devkit_dp, 1)
            devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE)

            dp = KeyZipper(
                devkit_dp,
                images_dp,
                key_fn=getitem(0),
                ref_key_fn=self._val_image_key,
                buffer_size=INFINITE_BUFFER_SIZE,
            )
            dp = Mapper(dp, self._collate_val_data)

        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
    _WNID_MAP = {
        "n03126707": "construction crane",
        "n03710721": "tank suit",
    }

    def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]:
        resources = self.resources(self.default_config)
        devkit_dp = resources[1].to_datapipe(root / self.name)
        devkit_dp = TarArchiveReader(devkit_dp)
        devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))

        meta = next(iter(devkit_dp))[1]
        synsets = read_mat(meta, squeeze_me=True)["synsets"]
        categories_and_wnids = [
            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
            for _, wnid, category, _, num_children, *_ in synsets
            # if num_children > 0, we are looking at a superclass that has no direct instance
            if num_children == 0
        ]
        categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])

        return categories_and_wnids
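
For context, a minimal usage sketch (not part of the commit). It assumes the prototype datasets.load entry point accepts the split option as a keyword and dispatches to the ImageNet class registered above; since the resources carry no download URLs, the ILSVRC2012 archives must already sit in the dataset folder under datasets.home():

# Sketch only; `datasets.load` and its keyword handling are assumptions about the
# prototype API and are not shown in this diff.
from torchvision.prototype import datasets

# Expects ILSVRC2012_img_val.tar and ILSVRC2012_devkit_t12.tar.gz to already be
# present locally, since the HttpResource entries above only carry file names.
dataset = datasets.load("imagenet", split="val")

for sample in dataset:
    # Each sample is the dict built in _collate_and_decode_sample.
    print(sample["path"], sample["label"], sample["category"], sample["wnid"])
    break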
Category file generation script:

 import argparse
+import csv
 import sys
-import unittest.mock
-import warnings
-
-with warnings.catch_warnings():
-    warnings.filterwarnings("ignore", message=r"The categories file .+? does not exist.", category=UserWarning)
-
-    from torchvision.prototype import datasets
+from torchvision.prototype import datasets
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
@@ -16,21 +11,18 @@ def main(*names, force=False):
     root = datasets.home()

     for name in names:
-        file = BUILTIN_DIR / f"{name}.categories"
-        if file.exists() and not force:
+        path = BUILTIN_DIR / f"{name}.categories"
+        if path.exists() and not force:
             continue

         dataset = find(name)
         try:
-            with unittest.mock.patch(
-                "torchvision.prototype.datasets.utils._dataset.DatasetInfo._read_categories_file", return_value=[]
-            ):
-                categories = dataset._generate_categories(root)
+            categories = dataset._generate_categories(root)
         except NotImplementedError:
             continue

-        with open(file, "w") as fh:
-            fh.write("\n".join(categories) + "\n")
+        with open(path, "w", newline="") as file:
+            csv.writer(file).writerows(categories)


 def parse_args(argv=None):
...
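
For context, a small sketch (not part of the commit) of the category file format this change switches to: each row is a CSV record, and for ImageNet the rows are (category, wnid) pairs sorted by wnid. The example rows below are illustrative, not copied from the diff.

# Illustrative round trip of the two-column category CSV that the script above
# writes and that DatasetInfo.read_categories_file reads back.
import csv
import pathlib
import tempfile

# Example rows only; with the wnid sort above, n01440764 ("tench") is commonly the first entry.
rows = [("tench", "n01440764"), ("goldfish", "n01443537")]

with tempfile.TemporaryDirectory() as root:
    path = pathlib.Path(root) / "imagenet.categories"
    with open(path, "w", newline="") as file:
        csv.writer(file).writerows(rows)              # mirrors main() above

    with open(path, newline="") as file:
        categories, wnids = zip(*csv.reader(file))    # mirrors read_categories_file + ImageNet.info

assert categories == ("tench", "goldfish")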
torchvision/prototype/datasets/utils/_dataset.py:

 import abc
+import csv
 import enum
 import io
-import os
 import pathlib
-import textwrap
-import warnings
-from collections.abc import Mapping
 from typing import (
     Any,
     Callable,
@@ -14,8 +11,6 @@ from typing import (
     Optional,
     Sequence,
     Union,
-    NoReturn,
-    Iterable,
     Tuple,
 )
@@ -26,76 +21,17 @@ from torchvision.prototype.datasets.utils._internal import (
     sequence_to_str,
 )
+from ._internal import FrozenBunch, make_repr
 from ._resource import OnlineResource


-def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
-    def to_str(sep: str) -> str:
-        return sep.join([f"{key}={value}" for key, value in items])
-
-    prefix = f"{name}("
-    postfix = ")"
-    body = to_str(", ")
-
-    line_length = int(os.environ.get("COLUMNS", 80))
-    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
-    multiline_body = len(str(body).splitlines()) > 1
-    if not (body_too_long or multiline_body):
-        return prefix + body + postfix
-
-    body = textwrap.indent(to_str(",\n"), " " * 2)
-    return f"{prefix}\n{body}\n{postfix}"


 class DatasetType(enum.Enum):
     RAW = enum.auto()
     IMAGE = enum.auto()


-class DatasetConfig(Mapping):
-    def __init__(self, *args, **kwargs):
-        data = dict(*args, **kwargs)
-        self.__dict__["__data__"] = data
-        self.__dict__["__final_hash__"] = hash(tuple(data.items()))
-
-    def __getitem__(self, name: str) -> Any:
-        return self.__dict__["__data__"][name]
-
-    def __iter__(self):
-        return iter(self.__dict__["__data__"].keys())
-
-    def __len__(self):
-        return len(self.__dict__["__data__"])
-
-    def __getattr__(self, name: str) -> Any:
-        try:
-            return self[name]
-        except KeyError as error:
-            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
-
-    def __setitem__(self, key: Any, value: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __setattr__(self, key: Any, value: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __delitem__(self, key: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __delattr__(self, item: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __hash__(self) -> int:
-        return self.__dict__["__final_hash__"]
-
-    def __eq__(self, other: Any) -> bool:
-        if not isinstance(other, DatasetConfig):
-            return NotImplemented
-
-        return hash(self) == hash(other)
-
-    def __repr__(self) -> str:
-        return make_repr(type(self).__name__, self.items())
+class DatasetConfig(FrozenBunch):
+    pass


 class DatasetInfo:
@@ -109,6 +45,7 @@ class DatasetInfo:
         homepage: Optional[str] = None,
         license: Optional[str] = None,
         valid_options: Optional[Dict[str, Sequence]] = None,
+        extra: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.name = name.lower()
         self.type = DatasetType[type.upper()] if isinstance(type, str) else type
@@ -118,7 +55,8 @@ class DatasetInfo:
         elif isinstance(categories, int):
             categories = [str(label) for label in range(categories)]
         elif isinstance(categories, (str, pathlib.Path)):
-            categories = self._read_categories_file(pathlib.Path(categories).expanduser().resolve())
+            path = pathlib.Path(categories).expanduser().resolve()
+            categories, *_ = zip(*self.read_categories_file(path))
         self.categories = tuple(categories)

         self.citation = citation
@@ -137,16 +75,12 @@ class DatasetInfo:
         )
         self._valid_options: Dict[str, Sequence] = valid_options

-    @staticmethod
-    def _read_categories_file(path: pathlib.Path) -> List[str]:
-        if not path.exists() or not path.is_file():
-            warnings.warn(
-                f"The categories file {path} does not exist. Continuing without loaded categories.", UserWarning
-            )
-            return []
-
-        with open(path, "r") as file:
-            return [line.strip() for line in file]
+        self.extra = FrozenBunch(extra or dict())
+
+    @staticmethod
+    def read_categories_file(path: pathlib.Path) -> List[List[str]]:
+        with open(path, "r", newline="") as file:
+            return [row for row in csv.reader(file)]

     @property
     def default_config(self) -> DatasetConfig:
@@ -231,5 +165,5 @@ class Dataset(abc.ABC):
         resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
         return self._make_datapipe(resource_dps, config=config, decoder=decoder)

-    def _generate_categories(self, root: pathlib.Path) -> Sequence[str]:
+    def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
         raise NotImplementedError
torchvision/prototype/datasets/utils/_internal.py:

 import collections.abc
+import csv
 import difflib
 import enum
 import gzip
 import io
 import lzma
+import os
 import os.path
 import pathlib
-from typing import Collection, Sequence, Callable, Union, Any, Tuple, TypeVar, Iterator, Dict, Optional
+import textwrap
+from collections.abc import Mapping
+from typing import (
+    Collection,
+    Sequence,
+    Callable,
+    Union,
+    Any,
+    Tuple,
+    TypeVar,
+    Iterator,
+    Dict,
+    Optional,
+    NoReturn,
+    Iterable,
+)

 import numpy as np
 import PIL.Image
@@ -18,6 +35,10 @@ __all__ = [
     "BUILTIN_DIR",
     "sequence_to_str",
     "add_suggestion",
+    "make_repr",
+    "FrozenMapping",
+    "FrozenBunch",
+    "create_categories_file",
     "read_mat",
     "image_buffer_from_array",
     "SequenceIterator",
@@ -62,6 +83,82 @@ def add_suggestion(
     return f"{msg.strip()} {hint}"


+def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
+    def to_str(sep: str) -> str:
+        return sep.join([f"{key}={value}" for key, value in items])
+
+    prefix = f"{name}("
+    postfix = ")"
+    body = to_str(", ")
+
+    line_length = int(os.environ.get("COLUMNS", 80))
+    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
+    multiline_body = len(str(body).splitlines()) > 1
+    if not (body_too_long or multiline_body):
+        return prefix + body + postfix
+
+    body = textwrap.indent(to_str(",\n"), " " * 2)
+    return f"{prefix}\n{body}\n{postfix}"
+
+
+class FrozenMapping(Mapping):
+    def __init__(self, *args, **kwargs):
+        data = dict(*args, **kwargs)
+        self.__dict__["__data__"] = data
+        self.__dict__["__final_hash__"] = hash(tuple(data.items()))
+
+    def __getitem__(self, name: str) -> Any:
+        return self.__dict__["__data__"][name]
+
+    def __iter__(self):
+        return iter(self.__dict__["__data__"].keys())
+
+    def __len__(self):
+        return len(self.__dict__["__data__"])
+
+    def __setitem__(self, key: Any, value: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __delitem__(self, key: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __hash__(self) -> int:
+        return self.__dict__["__final_hash__"]
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, FrozenMapping):
+            return NotImplemented
+
+        return hash(self) == hash(other)
+
+    def __repr__(self):
+        return repr(self.__dict__["__data__"])
+
+
+class FrozenBunch(FrozenMapping):
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError as error:
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
+
+    def __setattr__(self, key: Any, value: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __delattr__(self, item: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __repr__(self) -> str:
+        return make_repr(type(self).__name__, self.items())
+
+
+def create_categories_file(
+    root: Union[str, pathlib.Path], name: str, categories: Sequence[Union[str, Sequence[str]]], **fmtparams: Any
+) -> None:
+    with open(pathlib.Path(root) / f"{name}.categories", "w", newline="") as file:
+        csv.writer(file, **fmtparams).writerows(categories)
+
+
 def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
     try:
         import scipy.io as sio
...
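
For context, a short sketch (not part of the commit) of how the new frozen containers behave, using only the definitions added above:

from torchvision.prototype.datasets.utils._internal import FrozenBunch, FrozenMapping

# Values must themselves be hashable, so nested mappings are wrapped in FrozenMapping,
# just like the `extra` dict in the ImageNet dataset above.
info_extra = FrozenBunch(wnid_to_category=FrozenMapping({"n01440764": "tench"}))

# FrozenBunch adds attribute access on top of FrozenMapping's read-only dict interface.
assert info_extra.wnid_to_category == info_extra["wnid_to_category"]

# The hash is precomputed from the items, so equal contents compare and hash equal.
assert FrozenMapping(a=1) == FrozenMapping(a=1)

# Any attempt to mutate raises.
try:
    info_extra.wnid_to_category = FrozenMapping()
except RuntimeError:
    pass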