"vscode:/vscode.git/clone" did not exist on "4ec38d496db69833eb0a6f144ebbd6f751cd3912"
Unverified commit 58f313bb, authored by Philip Meier, committed by GitHub

add prototype imagenet dataset (#4640)

* add prototype imagenet dataset

* add missing checksums

* fix mypy

* add human readable categories

* cleanup

* sort categories ascending based on wnid

* remove accidentally added file

* cleanup category file generation

* fix mypy
parent 588c698d
@@ -2,6 +2,7 @@ from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
 from .cifar import Cifar10, Cifar100
 from .coco import Coco
+from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .sbd import SBD
 from .voc import VOC
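With the registration in place, the dataset becomes reachable through the prototype API. A minimal usage sketch, assuming the `load` entry point of the prototype `datasets` package behaves as on this branch and that the ImageNet archives were placed manually under the datasets home directory (ImageNet requires a manual download):

# Hypothetical usage sketch; the sample keys match _collate_and_decode_sample below.
from torchvision.prototype import datasets

dataset = datasets.load("imagenet", split="val")
for sample in dataset:
    print(sample["path"], sample["label"], sample["category"], sample["wnid"])
    break

The new dataset module itself follows.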
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    BUILTIN_DIR,
    path_comparator,
    Enumerator,
    getitem,
    read_mat,
    FrozenMapping,
)
class ImageNet(Dataset):
    @property
    def info(self) -> DatasetInfo:
        name = "imagenet"
        categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))

        return DatasetInfo(
            name,
            type=DatasetType.IMAGE,
            categories=categories,
            homepage="https://www.image-net.org/",
            valid_options=dict(split=("train", "val")),
            extra=dict(
                wnid_to_category=FrozenMapping(zip(wnids, categories)),
                category_to_wnid=FrozenMapping(zip(categories, wnids)),
            ),
        )
    @property
    def category_to_wnid(self) -> Dict[str, str]:
        return self.info.extra.category_to_wnid

    @property
    def wnid_to_category(self) -> Dict[str, str]:
        return self.info.extra.wnid_to_category

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        if config.split == "train":
            images = HttpResource(
                "ILSVRC2012_img_train.tar",
                sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
            )
        else:  # config.split == "val"
            images = HttpResource(
                "ILSVRC2012_img_val.tar",
                sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
            )

        devkit = HttpResource(
            "ILSVRC2012_devkit_t12.tar.gz",
            sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
        )

        return [images, devkit]
    _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
        path = pathlib.Path(data[0])
        wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
        category = self.wnid_to_category[wnid]
        label = self.categories.index(category)
        return (label, category, wnid), data

    _VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG")

    def _val_image_key(self, data: Tuple[str, Any]) -> int:
        path = pathlib.Path(data[0])
        return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id"))  # type: ignore[union-attr]
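As an aside, the two filename patterns can be exercised in isolation; the member names below are illustrative:

# Train members are named <wnid>_<index>.JPEG inside the per-class tars,
# validation members carry a 1-based, zero-padded numeric id instead.
assert ImageNet._TRAIN_IMAGE_NAME_PATTERN.match("n01440764_10026.JPEG").group("wnid") == "n01440764"
assert int(ImageNet._VAL_IMAGE_NAME_PATTERN.match("ILSVRC2012_val_00000001.JPEG").group("id")) == 1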
    def _collate_val_data(
        self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
    ) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
        label_data, image_data = data
        _, label = label_data
        category = self.categories[label]
        wnid = self.category_to_wnid[category]
        return (label, category, wnid), image_data

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        ann_data, image_data = data
        label, category, wnid = ann_data
        path, buffer = image_data

        return dict(
            path=path,
            image=decoder(buffer) if decoder else buffer,
            label=torch.tensor(label),
            category=category,
            wnid=wnid,
        )
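Every sample yielded downstream is thus a plain dict. Sketched below for a hypothetical validation sample; the concrete values are illustrative, and with decoder=None the image stays an undecoded buffer:

sample = {
    "path": ".../ILSVRC2012_val_00000001.JPEG",  # illustrative
    "image": ...,             # torch.Tensor if a decoder was passed, else the raw file buffer
    "label": ...,             # 0-dim torch.Tensor holding the class index
    "category": "sea snake",  # human-readable category (illustrative)
    "wnid": "n01751748",      # matching WordNet ID (illustrative)
}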
    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        images_dp, devkit_dp = resource_dps

        images_dp = TarArchiveReader(images_dp)

        if config.split == "train":
            # the train archive is a tar of tars
            dp = TarArchiveReader(images_dp)
            # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
            dp = Mapper(dp, self._collate_train_data)
        else:
            devkit_dp = TarArchiveReader(devkit_dp)
            devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
            devkit_dp = LineReader(devkit_dp, return_path=False)
            devkit_dp = Mapper(devkit_dp, int)
            devkit_dp = Enumerator(devkit_dp, 1)
            devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE)

            dp = KeyZipper(
                devkit_dp,
                images_dp,
                key_fn=getitem(0),
                ref_key_fn=self._val_image_key,
                buffer_size=INFINITE_BUFFER_SIZE,
            )
            dp = Mapper(dp, self._collate_val_data)

        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
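The validation branch merits a gloss: the ground-truth file holds one integer label per line in image-id order, the labels are enumerated starting at 1 to match the ids embedded in the filenames, and KeyZipper joins the two streams on that id. A pure-Python sketch of the same join, on made-up data:

ground_truth = [3, 1, 2]  # one label per line, made up
labels = dict(enumerate(ground_truth, 1))  # what Enumerator(devkit_dp, 1) yields as (id, label)

# Images may arrive in any order; the id parsed from the name is the join key.
for name in ["ILSVRC2012_val_00000002.JPEG", "ILSVRC2012_val_00000001.JPEG"]:
    image_id = int(name[len("ILSVRC2012_val_"):-len(".JPEG")])
    print(name, "->", labels[image_id])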
    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
    # and n03126707 are labeled 'crane', where the former denotes the bird and the latter the construction equipment.
    _WNID_MAP = {
        "n03126707": "construction crane",
        "n03710721": "tank suit",
    }

    def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]:
        resources = self.resources(self.default_config)

        devkit_dp = resources[1].to_datapipe(root / self.name)
        devkit_dp = TarArchiveReader(devkit_dp)
        devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))

        meta = next(iter(devkit_dp))[1]
        synsets = read_mat(meta, squeeze_me=True)["synsets"]
        categories_and_wnids = [
            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
            for _, wnid, category, _, num_children, *_ in synsets
            # if num_children > 0, we are looking at a superclass that has no direct instances
            if num_children == 0
        ]
        categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])

        return categories_and_wnids
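The derivation rule, take the first comma-separated name unless the wnid needs manual disambiguation, can be replayed on a toy subset of the synsets (rows made up, schema reduced to the fields actually used):

toy_wnid_map = {"n03126707": "construction crane"}
toy_synsets = [  # (wnid, category, num_children); values illustrative
    ("n02012849", "crane", 0),
    ("n03126707", "crane", 0),
    ("n01440764", "tench, Tinca tinca", 0),
    ("n00001740", "entity", 2),  # superclass, dropped by the num_children filter
]
pairs = sorted(
    (
        (toy_wnid_map.get(wnid, category.split(",", 1)[0]), wnid)
        for wnid, category, num_children in toy_synsets
        if num_children == 0
    ),
    key=lambda pair: pair[1],
)
# [('tench', 'n01440764'), ('crane', 'n02012849'), ('construction crane', 'n03126707')]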
 import argparse
+import csv
 import sys
-import unittest.mock
-import warnings
-
-with warnings.catch_warnings():
-    warnings.filterwarnings("ignore", message=r"The categories file .+? does not exist.", category=UserWarning)
-    from torchvision.prototype import datasets
+
+from torchvision.prototype import datasets
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
@@ -16,21 +11,18 @@ def main(*names, force=False):
     root = datasets.home()

     for name in names:
-        file = BUILTIN_DIR / f"{name}.categories"
-        if file.exists() and not force:
+        path = BUILTIN_DIR / f"{name}.categories"
+        if path.exists() and not force:
             continue

         dataset = find(name)

         try:
-            with unittest.mock.patch(
-                "torchvision.prototype.datasets.utils._dataset.DatasetInfo._read_categories_file", return_value=[]
-            ):
-                categories = dataset._generate_categories(root)
+            categories = dataset._generate_categories(root)
         except NotImplementedError:
             continue

-        with open(file, "w") as fh:
-            fh.write("\n".join(categories) + "\n")
+        with open(path, "w", newline="") as file:
+            csv.writer(file).writerows(categories)
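Persisting categories as CSV instead of newline-joined strings lets multi-column rows like ImageNet's (category, wnid) pairs round-trip cleanly. A self-contained sketch of the write/read pair (csv.writer on this side, csv.reader in DatasetInfo.read_categories_file on the loading side):

import csv
import io

rows = [("tench", "n01440764"), ("goldfish", "n01443537")]  # illustrative
buffer = io.StringIO()
csv.writer(buffer).writerows(rows)
buffer.seek(0)
assert list(csv.reader(buffer)) == [["tench", "n01440764"], ["goldfish", "n01443537"]]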
def parse_args(argv=None):
......
 import abc
+import csv
 import enum
 import io
-import os
 import pathlib
-import textwrap
-import warnings
-from collections.abc import Mapping
 from typing import (
     Any,
     Callable,
@@ -14,8 +11,6 @@ from typing import (
     Optional,
     Sequence,
     Union,
-    NoReturn,
-    Iterable,
     Tuple,
 )
@@ -26,76 +21,17 @@ from torchvision.prototype.datasets.utils._internal import (
     sequence_to_str,
 )
+from ._internal import FrozenBunch, make_repr
 from ._resource import OnlineResource


-def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
-    def to_str(sep: str) -> str:
-        return sep.join([f"{key}={value}" for key, value in items])
-
-    prefix = f"{name}("
-    postfix = ")"
-    body = to_str(", ")
-
-    line_length = int(os.environ.get("COLUMNS", 80))
-    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
-    multiline_body = len(str(body).splitlines()) > 1
-
-    if not (body_too_long or multiline_body):
-        return prefix + body + postfix
-
-    body = textwrap.indent(to_str(",\n"), " " * 2)
-    return f"{prefix}\n{body}\n{postfix}"


 class DatasetType(enum.Enum):
     RAW = enum.auto()
     IMAGE = enum.auto()


-class DatasetConfig(Mapping):
-    def __init__(self, *args, **kwargs):
-        data = dict(*args, **kwargs)
-        self.__dict__["__data__"] = data
-        self.__dict__["__final_hash__"] = hash(tuple(data.items()))
-
-    def __getitem__(self, name: str) -> Any:
-        return self.__dict__["__data__"][name]
-
-    def __iter__(self):
-        return iter(self.__dict__["__data__"].keys())
-
-    def __len__(self):
-        return len(self.__dict__["__data__"])
-
-    def __getattr__(self, name: str) -> Any:
-        try:
-            return self[name]
-        except KeyError as error:
-            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
-
-    def __setitem__(self, key: Any, value: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __setattr__(self, key: Any, value: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __delitem__(self, key: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __delattr__(self, item: Any) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __hash__(self) -> int:
-        return self.__dict__["__final_hash__"]
-
-    def __eq__(self, other: Any) -> bool:
-        if not isinstance(other, DatasetConfig):
-            return NotImplemented
-        return hash(self) == hash(other)
-
-    def __repr__(self) -> str:
-        return make_repr(type(self).__name__, self.items())
+class DatasetConfig(FrozenBunch):
+    pass
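Rebasing DatasetConfig onto FrozenBunch preserves its public behavior: interchangeable item and attribute access, content-based equality and hashing, and immutability. A short sketch:

config = DatasetConfig(split="train")
assert config.split == config["split"] == "train"
assert config == DatasetConfig(split="train")  # equality is based on the content hash
try:
    config.split = "val"
except RuntimeError:
    pass  # mutation is rejected, as before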
class DatasetInfo:
@@ -109,6 +45,7 @@ class DatasetInfo:
         homepage: Optional[str] = None,
         license: Optional[str] = None,
         valid_options: Optional[Dict[str, Sequence]] = None,
+        extra: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.name = name.lower()
         self.type = DatasetType[type.upper()] if isinstance(type, str) else type
@@ -118,7 +55,8 @@
         elif isinstance(categories, int):
             categories = [str(label) for label in range(categories)]
         elif isinstance(categories, (str, pathlib.Path)):
-            categories = self._read_categories_file(pathlib.Path(categories).expanduser().resolve())
+            path = pathlib.Path(categories).expanduser().resolve()
+            categories, *_ = zip(*self.read_categories_file(path))
         self.categories = tuple(categories)

         self.citation = citation
@@ -137,16 +75,12 @@
         )
         self._valid_options: Dict[str, Sequence] = valid_options

-    @staticmethod
-    def _read_categories_file(path: pathlib.Path) -> List[str]:
-        if not path.exists() or not path.is_file():
-            warnings.warn(
-                f"The categories file {path} does not exist. Continuing without loaded categories.", UserWarning
-            )
-            return []
+        self.extra = FrozenBunch(extra or dict())

-        with open(path, "r") as file:
-            return [line.strip() for line in file]
+    @staticmethod
+    def read_categories_file(path: pathlib.Path) -> List[List[str]]:
+        with open(path, "r", newline="") as file:
+            return [row for row in csv.reader(file)]
     @property
     def default_config(self) -> DatasetConfig:
@@ -231,5 +165,5 @@ class Dataset(abc.ABC):
         resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
         return self._make_datapipe(resource_dps, config=config, decoder=decoder)

-    def _generate_categories(self, root: pathlib.Path) -> Sequence[str]:
+    def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
         raise NotImplementedError
 import collections.abc
 import csv
 import difflib
 import enum
 import gzip
 import io
 import lzma
 import os
 import os.path
 import pathlib
-from typing import Collection, Sequence, Callable, Union, Any, Tuple, TypeVar, Iterator, Dict, Optional
+import textwrap
+from collections.abc import Mapping
+from typing import (
+    Collection,
+    Sequence,
+    Callable,
+    Union,
+    Any,
+    Tuple,
+    TypeVar,
+    Iterator,
+    Dict,
+    Optional,
+    NoReturn,
+    Iterable,
+)

 import numpy as np
 import PIL.Image
@@ -18,6 +35,10 @@ __all__ = [
     "BUILTIN_DIR",
     "sequence_to_str",
     "add_suggestion",
+    "make_repr",
+    "FrozenMapping",
+    "FrozenBunch",
+    "create_categories_file",
     "read_mat",
     "image_buffer_from_array",
     "SequenceIterator",
@@ -62,6 +83,82 @@ def add_suggestion(
     return f"{msg.strip()} {hint}"


+def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
+    def to_str(sep: str) -> str:
+        return sep.join([f"{key}={value}" for key, value in items])
+
+    prefix = f"{name}("
+    postfix = ")"
+    body = to_str(", ")
+
+    line_length = int(os.environ.get("COLUMNS", 80))
+    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
+    multiline_body = len(str(body).splitlines()) > 1
+
+    if not (body_too_long or multiline_body):
+        return prefix + body + postfix
+
+    body = textwrap.indent(to_str(",\n"), " " * 2)
+    return f"{prefix}\n{body}\n{postfix}"
+class FrozenMapping(Mapping):
+    def __init__(self, *args, **kwargs):
+        data = dict(*args, **kwargs)
+        self.__dict__["__data__"] = data
+        self.__dict__["__final_hash__"] = hash(tuple(data.items()))
+
+    def __getitem__(self, name: str) -> Any:
+        return self.__dict__["__data__"][name]
+
+    def __iter__(self):
+        return iter(self.__dict__["__data__"].keys())
+
+    def __len__(self):
+        return len(self.__dict__["__data__"])
+
+    def __setitem__(self, key: Any, value: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __delitem__(self, key: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __hash__(self) -> int:
+        return self.__dict__["__final_hash__"]
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, FrozenMapping):
+            return NotImplemented
+        return hash(self) == hash(other)
+
+    def __repr__(self):
+        return repr(self.__dict__["__data__"])
+
+
+class FrozenBunch(FrozenMapping):
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError as error:
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
+
+    def __setattr__(self, key: Any, value: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __delattr__(self, item: Any) -> NoReturn:
+        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
+
+    def __repr__(self) -> str:
+        return make_repr(type(self).__name__, self.items())
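FrozenMapping computes its hash eagerly from the items, so every value must itself be hashable at construction time; FrozenBunch adds attribute access on top of the mapping interface. A quick sketch:

fm = FrozenMapping(a=1, b=2)
assert fm["a"] == 1 and set(fm) == {"a", "b"} and len(fm) == 2
assert hash(fm) == hash(FrozenMapping(a=1, b=2))  # equality and hash by content

fb = FrozenBunch(wnid="n01440764")
assert fb.wnid == fb["wnid"]
try:
    del fb.wnid
except RuntimeError:
    pass  # immutable, like item and attribute assignment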
+def create_categories_file(
+    root: Union[str, pathlib.Path], name: str, categories: Sequence[Union[str, Sequence[str]]], **fmtparams: Any
+) -> None:
+    with open(pathlib.Path(root) / f"{name}.categories", "w", newline="") as file:
+        csv.writer(file, **fmtparams).writerows(categories)


 def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
     try:
         import scipy.io as sio
......