Unverified Commit aedd3979 authored by Philip Meier, committed by GitHub

return features instead of vanilla tensors from prototype datasets (#4864)

* return features instead of vanilla tensors from prototype datasets

* fix tests

* remove inplace

* add explanation for __init_subclass__

* fix label for test split

* relax test

* remove pixels
parent 775129be
test/builtin_dataset_mocks.py
@@ -14,6 +14,7 @@ from datasets_utils import create_image_folder, make_tar, make_zip
 from torch.testing import make_tensor as _make_tensor
 from torchdata.datapipes.iter import IterDataPipe
 from torchvision.prototype import datasets
+from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.utils._internal import add_suggestion
@@ -99,20 +100,8 @@ class DatasetMocks:
         self._cache[(name, config)] = mock_resources, mock_info
         return mock_resources, mock_info
 
-    def _decoder(self, dataset_type):
-        def to_bytes(file):
-            try:
-                return file.read()
-            finally:
-                file.close()
-
-        if dataset_type == datasets.utils.DatasetType.RAW:
-            return datasets.decoder.raw
-        else:
-            return to_bytes
-
     def load(
-        self, name: str, decoder=DEFAULT_TEST_DECODER, split="train", **options: Any
+        self, name: str, decoder=DEFAULT_DECODER, split="train", **options: Any
     ) -> Tuple[IterDataPipe, Dict[str, Any]]:
         dataset = find(name)
         config = dataset.info.make_config(split=split, **options)
@@ -120,7 +109,7 @@ class DatasetMocks:
         datapipe = dataset._make_datapipe(
             [resource.to_datapipe() for resource in resources],
             config=config,
-            decoder=self._decoder(dataset.info.type) if decoder is DEFAULT_TEST_DECODER else decoder,
+            decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
         )
         return datapipe, mock_info
test/test_prototype_builtin_datasets.py
-import functools
 import io
 
 import builtin_dataset_mocks
 import pytest
 from torchdata.datapipes.iter import IterDataPipe
-from torchvision.prototype import datasets
+from torchvision.prototype import datasets, features
+from torchvision.prototype.datasets._api import DEFAULT_DECODER
 from torchvision.prototype.utils._internal import sequence_to_str
 
-_loaders = []
-_datasets = []
+def to_bytes(file):
+    try:
+        return file.read()
+    finally:
+        file.close()
 
-# TODO: this can be replaced by torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
-TMP = [
-    "mnist",
-    "fashionmnist",
-    "kmnist",
+def dataset_parametrization(*names, decoder=to_bytes):
+    if not names:
+        # TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
+        names = (
+            "mnist",
+            "fashionmnist",
+            "kmnist",
@@ -23,30 +29,27 @@
-    "caltech256",
-    "caltech101",
-    "imagenet",
-]
-
-for name in TMP:
-    loader = functools.partial(builtin_dataset_mocks.load, name)
-    _loaders.append(pytest.param(loader, id=name))
-
-    info = datasets.info(name)
-    _datasets.extend(
-        [
-            pytest.param(*loader(**config), id=f"{name}-{'-'.join([str(value) for value in config.values()])}")
-            for config in info._configs
-        ]
-    )
-
-loaders = pytest.mark.parametrize("loader", _loaders)
-builtin_datasets = pytest.mark.parametrize(("dataset", "mock_info"), _datasets)
+            "caltech256",
+            "caltech101",
+            "imagenet",
+        )
+
+    params = []
+    for name in names:
+        for config in datasets.info(name)._configs:
+            id = f"{name}-{'-'.join([str(value) for value in config.values()])}"
+            dataset, mock_info = builtin_dataset_mocks.load(name, decoder=decoder, **config)
+            params.append(pytest.param(dataset, mock_info, id=id))
+
+    return pytest.mark.parametrize(("dataset", "mock_info"), params)
 
 
 class TestCommon:
-    @builtin_datasets
+    @dataset_parametrization()
     def test_smoke(self, dataset, mock_info):
         if not isinstance(dataset, IterDataPipe):
             raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
 
-    @builtin_datasets
+    @dataset_parametrization()
    def test_sample(self, dataset, mock_info):
         try:
             sample = next(iter(dataset))
@@ -59,7 +62,7 @@ class TestCommon:
         if not sample:
             raise AssertionError("Sample dictionary is empty.")
 
-    @builtin_datasets
+    @dataset_parametrization()
     def test_num_samples(self, dataset, mock_info):
         num_samples = 0
         for _ in dataset:
@@ -67,7 +70,7 @@ class TestCommon:
 
         assert num_samples == mock_info["num_samples"]
 
-    @builtin_datasets
+    @dataset_parametrization()
     def test_decoding(self, dataset, mock_info):
         undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
         if undecoded_features:
@@ -76,6 +79,12 @@ class TestCommon:
                 f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
             )
 
+    @dataset_parametrization(decoder=DEFAULT_DECODER)
+    def test_at_least_one_feature(self, dataset, mock_info):
+        sample = next(iter(dataset))
+        if not any(isinstance(value, features.Feature) for value in sample.values()):
+            raise AssertionError("The sample contained no feature.")
+
 
 class TestQMNIST:
     @pytest.mark.parametrize(
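Note on the new helper: dataset_parametrization() builds one pytest param per (dataset, config) pair, so a test opts in simply by decorating itself. A minimal sketch of standalone usage, assuming the mock setup above ("mnist" is one of the built-in names; the asserted key is an assumption about the sample layout):

    @dataset_parametrization("mnist")
    def test_sample_has_image(dataset, mock_info):
        # hypothetical test: every sample dict is assumed to carry an "image" entry
        sample = next(iter(dataset))
        assert "image" in sample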
torchvision/prototype/datasets/_api.py
@@ -49,9 +49,9 @@ def info(name: str) -> DatasetInfo:
     return find(name).info
 
-default = object()
+DEFAULT_DECODER = object()
 
-DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
+DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
     DatasetType.RAW: raw,
     DatasetType.IMAGE: pil,
 }
@@ -60,15 +60,15 @@
 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = default,  # type: ignore[assignment]
+    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     name = name.lower()
 
     dataset = find(name)
-    if decoder is default:
-        decoder = DEFAULT_DECODER.get(dataset.info.type)
+    if decoder is DEFAULT_DECODER:
+        decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
 
     config = dataset.info.make_config(split=split, **options)
     root = os.path.join(home(), name)
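The renames above make the sentinel pattern explicit: DEFAULT_DECODER is now the unique "nothing was passed" marker, while DEFAULT_DECODER_MAP holds the actual per-type defaults. A self-contained sketch of the same pattern, with generic names rather than torchvision's:

    _SENTINEL = object()

    def _default_decoder(name):
        # stand-in for DEFAULT_DECODER_MAP.get(dataset.info.type)
        return lambda buffer: buffer.read()

    def load(name, *, decoder=_SENTINEL):
        if decoder is _SENTINEL:  # caller passed nothing: fall back to the per-type default
            decoder = _default_decoder(name)
        # an explicit decoder=None still means "no decoding, keep raw streams"
        return decoder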
torchvision/prototype/datasets/_builtin/caltech.py
@@ -22,6 +22,7 @@ from torchvision.prototype.datasets.utils import (
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
+from torchvision.prototype.features import Label, BoundingBox
 
 
 class Caltech101(Dataset):
@@ -95,8 +96,8 @@ class Caltech101(Dataset):
         image = decoder(image_buffer) if decoder else image_buffer
 
         ann = read_mat(ann_buffer)
-        bbox = torch.as_tensor(ann["box_coord"].astype(np.int64))
-        contour = torch.as_tensor(ann["obj_contour"])
+        bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
+        contour = torch.tensor(ann["obj_contour"].T)
 
         return dict(
             category=category,
@@ -171,9 +172,9 @@ class Caltech256(Dataset):
         dir_name = pathlib.Path(path).parent.name
         label_str, category = dir_name.split(".")
-        label = torch.tensor(int(label_str))
+        label = Label(int(label_str), category=category)
 
-        return dict(label=label, category=category, image=decoder(buffer) if decoder else buffer)
+        return dict(label=label, image=decoder(buffer) if decoder else buffer)
 
     def _make_datapipe(
         self,
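The [[2, 0, 3, 1]] fancy index deserves a note: assuming Caltech-101 stores box_coord as (top, bottom, left, right), the reindexing reorders it into the (left, top, right, bottom) layout that format="xyxy" declares. A quick check with made-up coordinates:

    import numpy as np

    box_coord = np.array([[10, 110, 20, 220]])  # assumed order: top, bottom, left, right
    xyxy = box_coord.squeeze()[[2, 0, 3, 1]]
    print(xyxy)  # [ 20  10 220 110] -> left, top, right, bottom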
torchvision/prototype/datasets/_builtin/cifar.py
@@ -28,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
     image_buffer_from_array,
     path_comparator,
 )
+from torchvision.prototype.features import Label, Image
 
 __all__ = ["Cifar10", "Cifar100"]
@@ -65,17 +66,16 @@ class _CifarBase(Dataset):
     ) -> Dict[str, Any]:
         image_array, category_idx = data
 
-        category = self.categories[category_idx]
-        label = torch.tensor(category_idx)
-
-        image: Union[torch.Tensor, io.BytesIO]
+        image: Union[Image, io.BytesIO]
         if decoder is raw:
-            image = torch.from_numpy(image_array)
+            image = Image(image_array)
         else:
             image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
-            image = decoder(image_buffer) if decoder else image_buffer
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
 
-        return dict(label=label, category=category, image=image)
+        label = Label(category_idx, category=self.categories[category_idx])
+
+        return dict(image=image, label=label)
 
     def _make_datapipe(
         self,
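Roughly what the switch to Label buys (values made up; behavior inferred from the Feature machinery later in this diff): the sample entry stays a tensor, but it now carries its category with it, which is why the separate "category" key can be dropped.

    from torchvision.prototype.features import Label

    label = Label(3, category="cat")
    int(label)      # 3 -- still behaves like a scalar tensor
    label.category  # 'cat' -- the meta data travels with the sample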
torchvision/prototype/datasets/_builtin/imagenet.py
@@ -21,9 +21,22 @@ from torchvision.prototype.datasets.utils._internal import (
     getitem,
     read_mat,
 )
+from torchvision.prototype.features import Label, DEFAULT
 from torchvision.prototype.utils._internal import FrozenMapping
 
 
+class ImageNetLabel(Label):
+    wnid: Optional[str]
+
+    @classmethod
+    def _parse_meta_data(
+        cls,
+        category: Optional[str] = DEFAULT,  # type: ignore[assignment]
+        wnid: Optional[str] = DEFAULT,  # type: ignore[assignment]
+    ) -> Dict[str, Tuple[Any, Any]]:
+        return dict(category=(category, None), wnid=(wnid, None))
+
+
 class ImageNet(Dataset):
     def _make_info(self) -> DatasetInfo:
         name = "imagenet"
@@ -78,12 +91,12 @@ class ImageNet(Dataset):
     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
 
-    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
+    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
         path = pathlib.Path(data[0])
         wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
         category = self.wnid_to_category[wnid]
-        label = self.categories.index(category)
-        return (label, category, wnid), data
+        label = ImageNetLabel(self.categories.index(category), category=category, wnid=wnid)
+        return label, data
 
     _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
@@ -93,31 +106,27 @@ class ImageNet(Dataset):
     def _collate_val_data(
         self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
-    ) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
+    ) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
         label_data, image_data = data
         _, label = label_data
         category = self.categories[label]
         wnid = self.category_to_wnid[category]
-        return (label, category, wnid), image_data
+        return ImageNetLabel(label, category=category, wnid=wnid), image_data
 
-    def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, None, None], Tuple[str, io.IOBase]]:
-        return (None, None, None), data
+    def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
+        return None, data
 
     def _collate_and_decode_sample(
         self,
-        data: Tuple[Tuple[Optional[int], Optional[str], Optional[str]], Tuple[str, io.IOBase]],
+        data: Tuple[Optional[ImageNetLabel], Tuple[str, io.IOBase]],
         *,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        ann_data, image_data = data
-        label, category, wnid = ann_data
-        path, buffer = image_data
+        label, (path, buffer) = data
         return dict(
             path=path,
             image=decoder(buffer) if decoder else buffer,
             label=label,
-            category=category,
-            wnid=wnid,
         )
 
     def _make_datapipe(
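A hypothetical instantiation of the new subclass, to make the _parse_meta_data override concrete (the id and strings are invented, not from the commit):

    label = ImageNetLabel(207, category="golden retriever", wnid="n02099601")
    label.category  # 'golden retriever'
    label.wnid      # 'n02099601' -- both exposed as properties by Feature.__init_subclass__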
torchvision/prototype/datasets/decoder.py
@@ -13,4 +13,7 @@ def raw(buffer: io.IOBase) -> torch.Tensor:
 
 def pil(buffer: io.IOBase) -> features.Image:
-    return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
+    try:
+        return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
+    finally:
+        buffer.close()
torchvision/prototype/features/__init__.py
 from ._bounding_box import BoundingBoxFormat, BoundingBox
-from ._feature import Feature
+from ._feature import Feature, DEFAULT
 from ._image import Image, ColorSpace
 from ._label import Label
torchvision/prototype/features/_bounding_box.py
@@ -115,7 +115,10 @@ class BoundingBox(Feature):
         data = cls._TO_XYXY_MAP[format](data)
         data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data)
         *_, w, h = to_parts(data)
-        return int(h.ceil()), int(w.ceil())
+        if data.dtype.is_floating_point:
+            w = w.ceil()
+            h = h.ceil()
+        return int(h), int(w)
 
     @classmethod
     def from_parts(
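The new guard matters because Tensor.ceil is only defined for floating point dtypes, so the old unconditional h.ceil() blew up on integer boxes, which need no rounding anyway. A minimal repro:

    import torch

    torch.tensor([10.2]).ceil()  # tensor([11.])
    # torch.tensor([10]).ceil() raises a RuntimeError on integer dtypes, hence the guard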
torchvision/prototype/features/_feature.py
@@ -16,13 +16,29 @@ class Feature(torch.Tensor):
     _meta_data: Dict[str, Any]
 
     def __init_subclass__(cls):
-        if not hasattr(cls, "_META_ATTRS"):
-            cls._META_ATTRS = {
-                attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")
-            }
-
-        for attr in cls._META_ATTRS:
-            if not hasattr(cls, attr):
-                setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr]))
+        # In order to help static type checkers, we require subclasses of `Feature` to add the meta data attributes
+        # as static class annotations:
+        #
+        # >>> class Foo(Feature):
+        # ...     bar: str
+        # ...     baz: Optional[str]
+        #
+        # Internally, this information is used twofold:
+        #
+        # 1. A class annotation is contained in `cls.__annotations__` but not in `cls.__dict__`. We use this difference
+        #    to automatically detect the meta data attributes and expose them as `@property`'s for convenient runtime
+        #    access. This happens in this method.
+        # 2. The information extracted in 1. is also used at creation (`__new__`) to perform an input parsing for
+        #    unknown arguments.
+        meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
+        for super_cls in cls.__mro__[1:]:
+            if super_cls is Feature:
+                break
+
+            meta_attrs.update(super_cls._META_ATTRS)
+
+        cls._META_ATTRS = meta_attrs
+        for attr in meta_attrs:
+            setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr]))
 
     def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs):
@@ -33,7 +49,7 @@ class Feature(torch.Tensor):
                 add_suggestion(
                     f"{cls.__name__}() got unexpected keyword '{unknown_meta_attr}'.",
                     word=unknown_meta_attr,
-                    possibilities=cls._META_ATTRS,
+                    possibilities=sorted(cls._META_ATTRS),
                 )
             )
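To make the __mro__ walk concrete, a small sketch (Foo/bar/baz are the illustrative names from the comment above, not real torchvision classes):

    from typing import Optional

    class Foo(Feature):
        bar: str

    class SubFoo(Foo):  # picks up 'bar' from Foo via the __mro__ loop
        baz: Optional[str]

    Foo._META_ATTRS     # {'bar'}
    SubFoo._META_ATTRS  # {'bar', 'baz'}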
torchvision/prototype/utils/_internal.py
@@ -29,7 +29,7 @@ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if len(seq) == 1:
         return f"'{seq[0]}'"
 
-    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""
+    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'"""
 
 def add_suggestion(
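The only change is dropping the trailing period, so callers such as test_decoding above can embed the result mid-sentence. For example:

    sequence_to_str(["image", "label"], separate_last="and ")
    # before: "'image', and 'label'."    after: "'image', and 'label'"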