Unverified commit d32bc4ba, authored by Philip Meier and committed by GitHub

Revamp prototype features and transforms (#5407)

* revamp prototype features (#5283)

* remove decoding from prototype datasets (#5287)

* remove decoder from prototype datasets

* remove unused imports

* cleanup

* fix readme

* use OneHotLabel in SEMEION

* improve voc implementation

* revert unrelated changes

* fix semeion mock data

* fix pcam

* re-add functional transforms API to prototype (#5295)

* re-add functional transforms

* cleanup

* add missing imports

* remove __torch_function__ dispatch

* re-add repr

* re-add empty line

* add test for scriptability

* remove function copy

* change import from functional tensor transforms to just functional

* fix import

* fix test

* fix prototype features and functional transforms after review (#5377)

* fix prototype functional transforms after review

* address features review

* make mypy more strict on prototype features

* make mypy more strict for prototype transforms

* fix annotation

* fix kernel tests

* add automatic feature type dispatch to functional transforms (#5323)

* add auto dispatch

* fix missing arguments error message

* remove pil kernel for erase

* automate feature-specific parameter detection

* fix typos

* cleanup dispatcher call

* remove __torch_function__ from transform dispatch

* remove auto-generation

* revert unrelated changes

* remove implements decorator

* change register parameter order

* change order of transforms for readability

* add documentation for __torch_function__

* fix mypy

* inline check for support

* refactor kernel registering process

* refactor dispatch to be a regular decorator

* split kernels and dispatchers

* remove sentinels

* replace pass with ...

* appease mypy

* make single kernel dispatchers more concise

* make dispatcher signatures more generic

* make kernel checking more strict

* revert doc changes

* address Francisco's comments

* remove inplace

* rename kernel test module

* fix inplace

* remove special casing for pil and vanilla tensors

* address comments

* update docs

* cleanup features / transforms feature branch (#5406)

* mark candidates for removal

* align signature of resize_bounding_box with corresponding image kernel

* fix documentation of Feature

* remove interpolation mode and antialias option from resize_segmentation_mask

* remove or privatize functionality in features / datasets / transforms
parent f2f490b1
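
The user-visible effect of this revamp is that the `decoder` machinery is gone: samples now carry feature objects (`Image`, `Label`, `EncodedImage`, ...) and decoding becomes an explicit, user-driven step. A minimal sketch of the new loading flow under that assumption; `load` and the sample keys follow the diffs below, while the manual PIL decode is illustrative only, not an API introduced by this PR:

    import io

    import PIL.Image

    from torchvision.prototype import datasets

    # no more decoder= argument; options such as the split are passed as keywords
    dp = datasets.load("oxford-iiit-pet", split="trainval")

    sample = next(iter(dp))
    encoded = sample["image"]  # features.EncodedImage: uint8 tensor of the raw file bytes
    image = PIL.Image.open(io.BytesIO(encoded.numpy().tobytes()))  # illustrative manual decode
    label = sample["label"]    # features.Label with .categories attached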
torchvision/prototype/datasets/_builtin/mnist.py

 import abc
 import functools
-import io
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence
+from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence

 import torch
 from torchdata.datapipes.iter import (
@@ -13,24 +12,21 @@ from torchdata.datapipes.iter import (
     Mapper,
     Zipper,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetType,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
 )
 from torchvision.prototype.datasets.utils._internal import (
-    image_buffer_from_array,
     Decompressor,
     INFINITE_BUFFER_SIZE,
-    fromfile,
     hint_sharding,
     hint_shuffling,
 )
 from torchvision.prototype.features import Image, Label
+from torchvision.prototype.utils._internal import fromfile

 __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]

@@ -105,31 +101,15 @@ class _MNISTBase(Dataset):
     def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
         return None, None

-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         image, label = data
-
-        if decoder is raw:
-            image = Image(image)
-        else:
-            image_buffer = image_buffer_from_array(image.numpy())
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)])
-
-        return dict(image=image, label=label)
+        return dict(
+            image=Image(image),
+            label=Label(label, dtype=torch.int64, categories=self.categories),
+        )

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, labels_dp = resource_dps

         start, stop = self.start_and_stop(config)
@@ -143,14 +123,13 @@ class _MNISTBase(Dataset):
         dp = Zipper(images_dp, labels_dp)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._prepare_sample, config=config))


 class MNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "mnist",
-            type=DatasetType.RAW,
             categories=10,
             homepage="http://yann.lecun.com/exdb/mnist",
             valid_options=dict(
@@ -183,7 +162,6 @@ class FashionMNIST(MNIST):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "fashionmnist",
-            type=DatasetType.RAW,
             categories=(
                 "T-shirt/top",
                 "Trouser",
@@ -215,7 +193,6 @@ class KMNIST(MNIST):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "kmnist",
-            type=DatasetType.RAW,
             categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
             homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
             valid_options=dict(
@@ -236,7 +213,6 @@ class EMNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "emnist",
-            type=DatasetType.RAW,
             categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
             homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist",
             valid_options=dict(
@@ -291,13 +267,7 @@ class EMNIST(_MNISTBase):
         46: 9,
     }

-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
         # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example,
@@ -310,14 +280,10 @@ class EMNIST(_MNISTBase):
         image, label = data
         label += self._LABEL_OFFSETS.get(int(label), 0)
         data = (image, label)
-        return super()._collate_and_decode(data, config=config, decoder=decoder)
+        return super()._prepare_sample(data, config=config)

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, labels_dp = Demultiplexer(
@@ -327,14 +293,13 @@ class EMNIST(_MNISTBase):
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return super()._make_datapipe([images_dp, labels_dp], config=config, decoder=decoder)
+        return super()._make_datapipe([images_dp, labels_dp], config=config)


 class QMNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "qmnist",
-            type=DatasetType.RAW,
             categories=10,
             homepage="https://github.com/facebookresearch/qmnist",
             valid_options=dict(
@@ -376,16 +341,10 @@ class QMNIST(_MNISTBase):
         return start, stop

-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         image, ann = data
         label, *extra_anns = ann
-        sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)
+        sample = super()._prepare_sample((image, label), config=config)

         sample.update(
             dict(
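
For the MNIST family this means `_prepare_sample` returns ready-made feature tensors instead of an encoded buffer. A standalone sketch of the resulting sample structure, using a random stand-in for the decoded IDX payload:

    import torch

    from torchvision.prototype.features import Image, Label

    categories = [str(c) for c in range(10)]
    raw = torch.randint(0, 256, (28, 28), dtype=torch.uint8)  # stand-in for one decoded IDX image

    sample = dict(
        image=Image(raw),
        label=Label(3, dtype=torch.int64, categories=categories),
    )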
torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

 import enum
-import functools
-import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO

-import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser
 from torchvision.prototype.datasets.utils import (
     Dataset,
@@ -12,7 +9,6 @@ from torchvision.prototype.datasets.utils import (
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -22,7 +18,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     path_comparator,
 )
-from torchvision.prototype.features import Label
+from torchvision.prototype.features import Label, EncodedImage


 class OxfordIITPetDemux(enum.IntEnum):
@@ -34,7 +30,6 @@ class OxfordIITPet(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "oxford-iiit-pet",
-            type=DatasetType.IMAGE,
             homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
             valid_options=dict(
                 split=("trainval", "test"),
@@ -66,18 +61,8 @@ class OxfordIITPet(Dataset):
     def _filter_segmentations(self, data: Tuple[str, Any]) -> bool:
         return not pathlib.Path(data[0]).name.startswith(".")

-    def _decode_classification_data(self, data: Dict[str, str]) -> Dict[str, Any]:
-        label_idx = int(data["label"]) - 1
-        return dict(
-            label=Label(label_idx, category=self.info.categories[label_idx]),
-            species="cat" if data["species"] == "1" else "dog",
-        )
-
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[Dict[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    def _prepare_sample(
+        self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]
     ) -> Dict[str, Any]:
         ann_data, image_data = data
         classification_data, segmentation_data = ann_data
@@ -85,19 +70,16 @@ class OxfordIITPet(Dataset):
         image_path, image_buffer = image_data

         return dict(
-            self._decode_classification_data(classification_data),
+            label=Label(int(classification_data["label"]) - 1, categories=self.categories),
+            species="cat" if classification_data["species"] == "1" else "dog",
             segmentation_path=segmentation_path,
-            segmentation=decoder(segmentation_buffer) if decoder else segmentation_buffer,
+            segmentation=EncodedImage.from_file(segmentation_buffer),
             image_path=image_path,
-            image=decoder(image_buffer) if decoder else image_buffer,
+            image=EncodedImage.from_file(image_buffer),
         )

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps
@@ -137,7 +119,7 @@ class OxfordIITPet(Dataset):
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)

     def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
         return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
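
`EncodedImage.from_file` reads a binary buffer into a uint8 tensor of still-encoded bytes, so no image decoding happens inside the datapipe. A usage sketch, assuming some JPEG on disk (the path is hypothetical):

    from torchvision.prototype.features import EncodedImage

    with open("cat.jpg", "rb") as buffer:         # hypothetical path
        encoded = EncodedImage.from_file(buffer)  # uint8 tensor holding the raw JPEG bytes
    # decoding into pixel data is deferred to a later, explicit step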
torchvision/prototype/datasets/_builtin/pcam.py

 import io
 from collections import namedtuple
-from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator
+from typing import Any, Dict, List, Optional, Tuple, Iterator

-import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
 from torchvision.prototype import features
 from torchvision.prototype.datasets.utils import (
@@ -10,7 +9,6 @@ from torchvision.prototype.datasets.utils import (
     DatasetConfig,
     DatasetInfo,
     OnlineResource,
-    DatasetType,
     GDriveResource,
 )
 from torchvision.prototype.datasets.utils._internal import (
@@ -46,7 +44,6 @@ class PCAM(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "pcam",
-            type=DatasetType.RAW,
             homepage="https://github.com/basveeling/pcam",
             categories=2,
             valid_options=dict(split=("train", "test", "val")),
@@ -98,7 +95,7 @@ class PCAM(Dataset):
             for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
         ]

-    def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
         image, target = data  # They're both numpy arrays at this point

         return {
@@ -107,11 +104,7 @@ class PCAM(Dataset):
         }

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, targets_dp = resource_dps
@@ -122,4 +115,4 @@ class PCAM(Dataset):
         dp = Zipper(images_dp, targets_dp)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, self._collate_and_decode)
+        return Mapper(dp, self._prepare_sample)
torchvision/prototype/datasets/_builtin/sbd.py

-import functools
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO

 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -20,7 +17,6 @@ from torchvision.prototype.datasets.utils import (
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -31,20 +27,17 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import Feature
+from torchvision.prototype.features import _Feature, EncodedImage


 class SBD(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "sbd",
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
             valid_options=dict(
                 split=("train", "val", "train_noval"),
-                boundaries=(True, False),
-                segmentation=(False, True),
             ),
         )
@@ -75,50 +68,21 @@ class SBD(Dataset):
         else:
             return None

-    def _decode_ann(
-        self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        raw_anns = data["GTcls"][0]
-        raw_boundaries = raw_anns["Boundaries"][0]
-        raw_segmentation = raw_anns["Segmentation"][0]
-
-        # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-        boundaries = (
-            Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
-            if decode_boundaries
-            else None
-        )
-        segmentation = Feature(raw_segmentation) if decode_segmentation else None
-
-        return boundaries, segmentation
-
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data

-        image = decoder(image_buffer) if decoder else image_buffer
-
-        if config.boundaries or config.segmentation:
-            boundaries, segmentation = self._decode_ann(
-                read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation
-            )
-        else:
-            boundaries = segmentation = None
+        anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"]

         return dict(
             image_path=image_path,
-            image=image,
+            image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
-            boundaries=boundaries,
-            segmentation=segmentation,
+            # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
+            boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
+            segmentation=_Feature(anns["Segmentation"].item()),
         )

     def _make_datapipe(
@@ -126,7 +90,6 @@ class SBD(Dataset):
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp, extra_split_dp = resource_dps
@@ -138,10 +101,10 @@ class SBD(Dataset):
             buffer_size=INFINITE_BUFFER_SIZE,
             drop_none=True,
         )

         if config.split == "train_noval":
             split_dp = extra_split_dp

-        split_dp = Filter(split_dp, path_comparator("stem", config.split))
+        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
@@ -155,7 +118,7 @@ class SBD(Dataset):
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         resources = self.resources(self.default_config)
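
The SBD `.mat` annotations store per-class boundary maps as SciPy sparse CSC matrices, which PyTorch cannot represent, hence the `toarray()` densification before wrapping in `_Feature`. A self-contained sketch of that conversion with made-up matrices:

    import numpy as np
    import torch
    from scipy.sparse import csc_matrix

    # stand-ins for anns["Boundaries"].item(): one sparse map per class
    boundaries = [csc_matrix(np.eye(4, dtype=np.uint8)) for _ in range(3)]

    dense = np.stack([b.toarray() for b in boundaries])  # shape (num_classes, H, W)
    tensor = torch.as_tensor(dense)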
torchvision/prototype/datasets/_builtin/semeion.py

-import functools
-import io
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple

 import torch
 from torchdata.datapipes.iter import (
@@ -8,24 +6,21 @@ from torchdata.datapipes.iter import (
     Mapper,
     CSVParser,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
-from torchvision.prototype.features import Image, Label
+from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
+from torchvision.prototype.features import Image, OneHotLabel


 class SEMEION(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "semeion",
-            type=DatasetType.RAW,
             categories=10,
             homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
         )
@@ -37,34 +32,22 @@ class SEMEION(Dataset):
         )
         return [data]

-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, ...],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
-        image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16)
-        label_data = [int(label) for label in data[256:] if label]
-
-        if decoder is raw:
-            image = Image(image_data.unsqueeze(0))
-        else:
-            image_buffer = image_buffer_from_array(image_data.numpy())
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
-        return dict(image=image, label=Label(label_idx, category=self.info.categories[label_idx]))
+    def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
+        image_data, label_data = data[:256], data[256:-1]
+
+        return dict(
+            image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)),
+            label=OneHotLabel([int(label) for label in label_data], categories=self.categories),
+        )

     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
-        return dp
+        return Mapper(dp, self._prepare_sample)
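
SEMEION's targets are one-hot rows, which the revamped dataset preserves as a `OneHotLabel` instead of collapsing them to an index at load time. A sketch of both directions, assuming the digit 3:

    from torchvision.prototype.features import OneHotLabel

    categories = [str(c) for c in range(10)]
    one_hot = OneHotLabel([0, 0, 0, 1, 0, 0, 0, 0, 0, 0], categories=categories)

    index = int(one_hot.argmax())  # 3, if a plain class index is needed downstream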
torchvision/prototype/datasets/_builtin/svhn.py

-import functools
-import io
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple, BinaryIO

 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
     UnBatcher,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     read_mat,
     hint_sharding,
     hint_shuffling,
-    image_buffer_from_array,
 )
 from torchvision.prototype.features import Label, Image
@@ -31,7 +25,6 @@ class SVHN(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "svhn",
-            type=DatasetType.RAW,
             dependencies=("scipy",),
             categories=10,
             homepage="http://ufldl.stanford.edu/housenumbers/",
@@ -52,7 +45,7 @@ class SVHN(Dataset):
         return [data]

-    def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np.ndarray, np.ndarray]]:
+    def _read_images_and_labels(self, data: Tuple[str, BinaryIO]) -> List[Tuple[np.ndarray, np.ndarray]]:
         _, buffer = data
         content = read_mat(buffer)
         return list(
@@ -62,23 +55,12 @@ class SVHN(Dataset):
             )
         )

-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[np.ndarray, np.ndarray],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]:
         image_array, label_array = data

-        if decoder is raw:
-            image = Image(image_array.transpose((2, 0, 1)))
-        else:
-            image_buffer = image_buffer_from_array(image_array)
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
         return dict(
-            image=image,
-            label=Label(int(label_array) % 10),
+            image=Image(image_array.transpose((2, 0, 1))),
+            label=Label(int(label_array) % 10, categories=self.categories),
         )

     def _make_datapipe(
@@ -86,11 +68,10 @@ class SVHN(Dataset):
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Mapper(dp, self._read_images_and_labels)
         dp = UnBatcher(dp)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
torchvision/prototype/datasets/_builtin/voc.categories (new file)

+aeroplane
+bicycle
+bird
+boat
+bottle
+bus
+car
+cat
+chair
+cow
+diningtable
+dog
+horse
+motorbike
+person
+pottedplant
+sheep
+sofa
+train
+tvmonitor
torchvision/prototype/datasets/_builtin/voc.py

 import functools
-import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable
 from xml.etree import ElementTree

-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -20,7 +18,6 @@ from torchvision.prototype.datasets.utils import (
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
@@ -30,7 +27,7 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import BoundingBox
+from torchvision.prototype.features import BoundingBox, Label, EncodedImage


 class VOCDatasetInfo(DatasetInfo):
@@ -50,7 +47,6 @@ class VOC(Dataset):
     def _make_info(self) -> DatasetInfo:
         return VOCDatasetInfo(
             "voc",
-            type=DatasetType.IMAGE,
             homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
             valid_options=dict(
                 split=("train", "val", "trainval", "test"),
@@ -99,40 +95,52 @@ class VOC(Dataset):
         else:
             return None

-    def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor:
-        result = VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())  # type: ignore[arg-type]
-        objects = result["annotation"]["object"]
-        bboxes = [obj["bndbox"] for obj in objects]
-        bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes]
-        return BoundingBox(bboxes)
+    def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+
+    def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        anns = self._parse_detection_ann(buffer)
+        instances = anns["object"]
+        return dict(
+            bounding_boxes=BoundingBox(
+                [
+                    [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
+                    for instance in instances
+                ],
+                format="xyxy",
+                image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
+            ),
+            labels=Label(
+                [self.categories.index(instance["name"]) for instance in instances], categories=self.categories
+            ),
+        )
+
+    def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        return dict(segmentation=EncodedImage.from_file(buffer))

-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
+        data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
         *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]],
     ) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data

-        image = decoder(image_buffer) if decoder else image_buffer
-
-        if config.task == "detection":
-            ann = self._decode_detection_ann(ann_buffer)
-        else:  # config.task == "segmentation":
-            ann = decoder(ann_buffer) if decoder else ann_buffer  # type: ignore[assignment]
-
-        return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann)
+        return dict(
+            prepare_ann_fn(ann_buffer),
+            image_path=image_path,
+            image=EncodedImage.from_file(image_buffer),
+            ann_path=ann_path,
+        )

     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         split_dp, images_dp, anns_dp = Demultiplexer(
@@ -158,4 +166,25 @@ class VOC(Dataset):
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
+        return Mapper(
+            dp,
+            functools.partial(
+                self._prepare_sample,
+                prepare_ann_fn=self._prepare_detection_ann
+                if config.task == "detection"
+                else self._prepare_segmentation_ann,
+            ),
+        )
+
+    def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+        return self._classify_archive(data, config=config) == 2
+
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
+        config = self.info.make_config(task="detection")
+
+        resource = self.resources(config)[0]
+        dp = resource.load(pathlib.Path(root) / self.name)
+        dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config))
+        dp = Mapper(dp, self._parse_detection_ann, input_col=1)
+
+        return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
torchvision/prototype/datasets/_folder.py

 import functools
-import io
 import os
 import os.path
 import pathlib
-from typing import Callable, Optional, Collection
-from typing import Union, Tuple, List, Dict, Any
+from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any

-import torch
-from torchdata.datapipes.iter import IterDataPipe, FileLister, FileOpener, Mapper, Shuffler, Filter
-from torchvision.prototype.datasets.decoder import pil
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding
+from torchdata.datapipes.iter import IterDataPipe, FileLister, Mapper, Filter, FileOpener
+from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
+from torchvision.prototype.features import Label, EncodedImage, EncodedData

 __all__ = ["from_data_folder", "from_image_folder"]

@@ -20,29 +17,24 @@ def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
     return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")


-def _collate_and_decode_data(
-    data: Tuple[str, io.IOBase],
+def _prepare_sample(
+    data: Tuple[str, BinaryIO],
     *,
     root: pathlib.Path,
     categories: List[str],
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
 ) -> Dict[str, Any]:
     path, buffer = data
-    data = decoder(buffer) if decoder else buffer
     category = pathlib.Path(path).relative_to(root).parts[0]
-    label = torch.tensor(categories.index(category))
     return dict(
         path=path,
-        data=data,
-        label=label,
-        category=category,
+        data=EncodedData.from_file(buffer),
+        label=Label.from_category(category, categories=categories),
     )


 def from_data_folder(
     root: Union[str, pathlib.Path],
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
     valid_extensions: Optional[Collection[str]] = None,
     recursive: bool = True,
 ) -> Tuple[IterDataPipe, List[str]]:
@@ -52,26 +44,22 @@ def from_data_folder(
     dp = FileLister(str(root), recursive=recursive, masks=masks)
     dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
     dp = hint_sharding(dp)
-    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+    dp = hint_shuffling(dp)
     dp = FileOpener(dp, mode="rb")
-    return (
-        Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
-        categories,
-    )
+    return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories


 def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
-    sample["image"] = sample.pop("data")
+    sample["image"] = EncodedImage(sample.pop("data").data)
     return sample


 def from_image_folder(
     root: Union[str, pathlib.Path],
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
     valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
     **kwargs: Any,
 ) -> Tuple[IterDataPipe, List[str]]:
     valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())]
-    dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs)
+    dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs)
     return Mapper(dp, _data_to_image_key), categories
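
A usage sketch of the reworked folder helpers, assuming a `root/<category>/<image file>` layout and that both functions remain re-exported from `torchvision.prototype.datasets`:

    from torchvision.prototype.datasets import from_image_folder

    dp, categories = from_image_folder("path/to/root")  # the decoder= argument is gone

    for sample in dp:
        sample["image"]  # features.EncodedImage with the raw file bytes
        sample["label"]  # features.Label built via Label.from_category
        break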
torchvision/prototype/datasets/decoder.py (deleted)

-import io
-
-import PIL.Image
-import torch
-from torchvision.prototype import features
-from torchvision.transforms.functional import pil_to_tensor
-
-__all__ = ["raw", "pil"]
-
-
-def raw(buffer: io.IOBase) -> torch.Tensor:
-    raise RuntimeError("This is just a sentinel and should never be called.")
-
-
-def pil(buffer: io.IOBase) -> features.Image:
-    return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
torchvision/prototype/datasets/utils/__init__.py

-from . import _internal
-from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
+from . import _internal  # usort: skip
+from ._dataset import DatasetConfig, DatasetInfo, Dataset
 from ._query import SampleQuery
 from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
torchvision/prototype/datasets/utils/_dataset.py

 import abc
 import csv
-import enum
 import importlib
-import io
 import itertools
 import os
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple, Collection
+from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection

-import torch
 from torch.utils.data import IterDataPipe
-from torchvision.prototype.utils._internal import FrozenBunch, make_repr
-from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
+from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str

 from .._home import use_sharded_dataset
 from ._internal import BUILTIN_DIR, _make_sharded_datapipe
 from ._resource import OnlineResource


-class DatasetType(enum.Enum):
-    RAW = enum.auto()
-    IMAGE = enum.auto()
-
-
 class DatasetConfig(FrozenBunch):
     # This needs to be Frozen because we often pass configs as partial(func, config=config)
     # and partial() requires the parameters to be hashable.
@@ -34,7 +25,6 @@ class DatasetInfo:
         self,
         name: str,
         *,
-        type: Union[str, DatasetType],
         dependencies: Collection[str] = (),
         categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
         citation: Optional[str] = None,
@@ -44,7 +34,6 @@ class DatasetInfo:
         extra: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.name = name.lower()
-        self.type = DatasetType[type.upper()] if isinstance(type, str) else type

         self.dependecies = dependencies
@@ -163,7 +152,6 @@ class Dataset(abc.ABC):
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         pass
@@ -175,7 +163,6 @@ class Dataset(abc.ABC):
         root: Union[str, pathlib.Path],
         *,
         config: Optional[DatasetConfig] = None,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
         skip_integrity_check: bool = False,
     ) -> IterDataPipe[Dict[str, Any]]:
         if not config:
@@ -190,7 +177,7 @@ class Dataset(abc.ABC):
         resource_dps = [
             resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
         ]
-        return self._make_datapipe(resource_dps, config=config, decoder=decoder)
+        return self._make_datapipe(resource_dps, config=config)

     def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
         raise NotImplementedError
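
With the `decoder` parameter gone, `Dataset.load` only wires the resource datapipes into `_make_datapipe`. A hedged sketch of driving it directly with a builtin dataset object; the private import path and data root are illustrative only, the public entry point being `torchvision.prototype.datasets.load`:

    import pathlib

    from torchvision.prototype.datasets._builtin import MNIST  # private module, for illustration

    dataset = MNIST()
    dp = dataset.load(pathlib.Path("~/data").expanduser(), config=dataset.default_config)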
torchvision/prototype/datasets/utils/_internal.py

 import enum
 import functools
 import gzip
-import io
 import lzma
-import mmap
 import os
 import os.path
 import pathlib
 import pickle
-import platform
 from typing import BinaryIO
 from typing import (
     Sequence,
@@ -25,27 +22,24 @@ from typing import (
 )
 from typing import cast

-import numpy as np
-import PIL.Image
 import torch
 import torch.distributed as dist
 import torch.utils.data
 from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
 from torchdata.datapipes.utils import StreamWrapper
+from torchvision.prototype.utils._internal import fromfile

 __all__ = [
     "INFINITE_BUFFER_SIZE",
     "BUILTIN_DIR",
     "read_mat",
-    "image_buffer_from_array",
     "MappingIterator",
     "Enumerator",
     "getitem",
     "path_accessor",
     "path_comparator",
     "Decompressor",
-    "fromfile",
     "read_flo",
     "hint_sharding",
 ]
@@ -59,7 +53,7 @@ INFINITE_BUFFER_SIZE = 1_000_000_000
 BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"


-def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
+def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
     try:
         import scipy.io as sio
     except ImportError as error:
@@ -71,14 +65,6 @@ def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
     return sio.loadmat(buffer, **kwargs)


-def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
-    image = PIL.Image.fromarray(array)
-    buffer = io.BytesIO()
-    image.save(buffer, format=format)
-    buffer.seek(0)
-    return buffer
-
-
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
     def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
         self.datapipe = datapipe
@@ -142,17 +128,17 @@ class CompressionType(enum.Enum):
     LZMA = "lzma"


-class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
+class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]):
     types = CompressionType

-    _DECOMPRESSORS = {
-        types.GZIP: lambda file: gzip.GzipFile(fileobj=file),
-        types.LZMA: lambda file: lzma.LZMAFile(file),
+    _DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], BinaryIO]] = {
+        types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)),
+        types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)),
     }

     def __init__(
         self,
-        datapipe: IterDataPipe[Tuple[str, io.IOBase]],
+        datapipe: IterDataPipe[Tuple[str, BinaryIO]],
         *,
         type: Optional[Union[str, CompressionType]] = None,
     ) -> None:
@@ -174,7 +160,7 @@ class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]):
         else:
             raise RuntimeError("FIXME")

-    def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
+    def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]:
         for path, file in self.datapipe:
             type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[type]
@@ -257,69 +243,6 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[str, Any]]:
     return dp


-def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
-    # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
-    return bytearray(file.read(-1 if count == -1 else count * item_size))
-
-
-def fromfile(
-    file: BinaryIO,
-    *,
-    dtype: torch.dtype,
-    byte_order: str,
-    count: int = -1,
-) -> torch.Tensor:
-    """Construct a tensor from a binary file.
-
-    .. note::
-
-        This function is similar to :func:`numpy.fromfile` with two notable differences:
-
-        1. This function only accepts an open binary file, but not a path to it.
-        2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
-           concept.
-
-    .. note::
-
-        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
-        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
-
-    Args:
-        file (IO): Open binary file.
-        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
-        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
-        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
-    """
-    byte_order = "<" if byte_order == "little" else ">"
-    char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
-    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
-    np_dtype = byte_order + char + str(item_size)
-
-    buffer: Union[memoryview, bytearray]
-    if platform.system() != "Windows":
-        # PyTorch does not support tensors with underlying read-only memory. In case
-        # - the file has a .fileno(),
-        # - the file was opened for updating, i.e. 'r+b' or 'w+b',
-        # - the file is seekable
-        # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
-        # to a mutable location afterwards.
-        try:
-            buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
-
-            # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
-            file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
-        except (PermissionError, io.UnsupportedOperation):
-            buffer = _read_mutable_buffer_fallback(file, count, item_size)
-    else:
-        # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
-        # so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
-        buffer = _read_mutable_buffer_fallback(file, count, item_size)
-
-    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
-    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
-    # successive .astype() call.
-    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
-
-
 def read_flo(file: BinaryIO) -> torch.Tensor:
     if file.read(4) != b"PIEH":
         raise ValueError("Magic number incorrect. Invalid .flo file")
@@ -329,9 +252,9 @@ def read_flo(file: BinaryIO) -> torch.Tensor:
     return flow.reshape((height, width, 2)).permute((2, 0, 1))


-def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
+def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
     return ShardingFilter(datapipe)


-def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
+def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
     return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE)
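
`fromfile` itself moved to `torchvision.prototype.utils._internal` (see the new import near the top of this file); its contract is unchanged, per the docstring above. A usage sketch reading an IDX-style file, with a hypothetical path:

    import torch

    from torchvision.prototype.utils._internal import fromfile

    with open("train-images-idx3-ubyte", "rb") as f:  # hypothetical IDX file
        header = fromfile(f, dtype=torch.int32, byte_order="big", count=4)  # magic + dims
        data = fromfile(f, dtype=torch.uint8, byte_order="big", count=-1)   # rest of the file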
torchvision/prototype/features/__init__.py

-from ._bounding_box import BoundingBoxFormat, BoundingBox
-from ._feature import Feature, DEFAULT
-from ._image import Image, ColorSpace
-from ._label import Label
+from ._bounding_box import BoundingBox, BoundingBoxFormat
+from ._encoded import EncodedData, EncodedImage, EncodedVideo
+from ._feature import _Feature
+from ._image import ColorSpace, Image
+from ._label import Label, OneHotLabel
+from ._segmentation_mask import SegmentationMask
import enum from __future__ import annotations
import functools
from typing import Callable, Union, Tuple, Dict, Any, Optional, cast from typing import Any, Tuple, Union, Optional
import torch import torch
from torchvision.prototype.utils._internal import StrEnum from torchvision.prototype.utils._internal import StrEnum
from ._feature import Feature, DEFAULT from ._feature import _Feature
class BoundingBoxFormat(StrEnum): class BoundingBoxFormat(StrEnum):
# this is just for test purposes XYXY = StrEnum.auto()
_SENTINEL = -1 XYWH = StrEnum.auto()
XYXY = enum.auto() CXCYWH = StrEnum.auto()
XYWH = enum.auto()
CXCYWH = enum.auto()
class BoundingBox(_Feature):
    format: BoundingBoxFormat
    image_size: Tuple[int, int]
    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        format: Union[BoundingBoxFormat, str],
        image_size: Tuple[int, int],
    ) -> BoundingBox:
        bounding_box = super().__new__(cls, data, dtype=dtype, device=device)

        if isinstance(format, str):
            format = BoundingBoxFormat[format]
        bounding_box._metadata.update(dict(format=format, image_size=image_size))

        return bounding_box
    def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
        # promote this out of the prototype state

        # import at runtime to avoid cyclic imports
        from torchvision.prototype.transforms.kernels import convert_bounding_box_format

        if isinstance(format, str):
            format = BoundingBoxFormat[format]

        return BoundingBox.new_like(
            self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
        )
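A minimal usage sketch of the reworked class; the coordinate values are illustrative:

from torchvision.prototype.features import BoundingBox, BoundingBoxFormat

# a single box in XYWH format on a 100x100 image
box = BoundingBox([10.0, 20.0, 30.0, 40.0], format="XYWH", image_size=(100, 100))

# to_format dispatches to the convert_bounding_box_format kernel and rewraps
# the result via new_like, so the format metadata stays consistent
xyxy = box.to_format("XYXY")  # data becomes [10., 20., 40., 60.]
assert xyxy.format == BoundingBoxFormat.XYXY
assert xyxy.image_size == (100, 100)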
import os
import sys
from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any
import PIL.Image
import torch
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from ._feature import _Feature
from ._image import Image
D = TypeVar("D", bound="EncodedData")
class EncodedData(_Feature):
@classmethod
def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
# TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
return super()._to_tensor(data, dtype=dtype, device=device)
@classmethod
def from_file(cls: Type[D], file: BinaryIO) -> D:
return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
@classmethod
def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D:
with open(path, "rb") as file:
return cls.from_file(file)
class EncodedImage(EncodedData):
# TODO: Use @functools.cached_property if we can depend on Python 3.8
@property
def image_size(self) -> Tuple[int, int]:
if not hasattr(self, "_image_size"):
with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
self._image_size = image.height, image.width
return self._image_size
def decode(self) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.kernels import decode_image_with_pil
return Image(decode_image_with_pil(self))
class EncodedVideo(EncodedData):
pass
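A hedged sketch of the intended round trip; the file path is hypothetical:

from torchvision.prototype.features import EncodedImage

encoded = EncodedImage.from_path("path/to/image.jpg")  # raw bytes as a 1-D uint8 tensor
print(encoded.image_size)  # (height, width), read lazily from the PIL header without decoding
image = encoded.decode()   # full decode through decode_image_with_pil into an Image feature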
from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping

import torch
from torch._C import _TensorBase, DisableTorchFunction

F = TypeVar("F", bound="_Feature")


class _Feature(torch.Tensor):
    _META_ATTRS: Set[str] = set()
    _metadata: Dict[str, Any]

    def __init_subclass__(cls) -> None:
        """
        For convenient copying of metadata, we store it inside a dictionary rather than multiple individual
        attributes. By adding the metadata attributes as class annotations on subclasses of :class:`_Feature`,
        this method adds properties that provide the same convenient access as regular attributes.

        >>> class Foo(_Feature):
        ...     bar: str
        ...     baz: Optional[str]
        >>> foo = Foo()
        >>> foo.bar
        >>> foo.baz

        This has the additional benefit that autocomplete engines and static type checkers are aware of the
        metadata.
        """
        meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
        for super_cls in cls.__mro__[1:]:
            if super_cls is _Feature:
                break

            meta_attrs.update(cast(Type[_Feature], super_cls)._META_ATTRS)

        cls._META_ATTRS = meta_attrs
        for name in meta_attrs:
            setattr(cls, name, property(cast(Callable[[F], Any], lambda self, name=name: self._metadata[name])))
    def __new__(
        cls: Type[F],
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> F:
        if isinstance(device, str):
            device = torch.device(device)
        feature = cast(
            F,
            torch.Tensor._make_subclass(
                cast(_TensorBase, cls),
                cls._to_tensor(data, dtype=dtype, device=device),
                # requires_grad
                False,
            ),
        )
        feature._metadata = dict()
        return feature

    @classmethod
    def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
        return torch.as_tensor(data, dtype=dtype, device=device)

    @classmethod
    def new_like(
        cls: Type[F],
        other: F,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str]] = None,
        **metadata: Any,
    ) -> F:
        _metadata = other._metadata.copy()
        _metadata.update(metadata)
        return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)
    @classmethod
    def __torch_function__(
@@ -89,12 +84,37 @@ class Feature(torch.Tensor):
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the ``__torch_function__`` protocol works, see
        https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        The default behavior of :class:`~torch.Tensor` is to retain a custom tensor type. For the :class:`Feature`
        use case, this has two downsides:

        1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e.
           ``return cls(func(*args, **kwargs))``, will fail for them.
        2. For most operations, there is no way of knowing if the input type is still valid for the output.

        For these reasons, the automatic output wrapping is turned off for most operators. Exceptions to this are:

        - :meth:`torch.Tensor.clone`
        - :meth:`torch.Tensor.to`
        """
        kwargs = kwargs or dict()
        with DisableTorchFunction():
            output = func(*args, **kwargs)

        if func is torch.Tensor.clone:
            return cls.new_like(args[0], output)
        elif func is torch.Tensor.to:
            return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
        else:
            return output

    def __repr__(self) -> str:
        return cast(str, torch.Tensor.__repr__(self)).replace("tensor", type(self).__name__)
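A hedged sketch of what this dispatch policy means in practice, using Label as a representative subclass:

import torch
from torchvision.prototype.features import Label

label = Label(5, categories=["cat", "dog"])

# clone and Tensor.to keep the feature type and its metadata ...
assert isinstance(label.clone(), Label)
assert isinstance(label.to(torch.int64), Label)

# ... while other operators deliberately fall back to a plain tensor, since
# the metadata may no longer describe the output
assert not isinstance(label + 1, Label)
assert isinstance(label + 1, torch.Tensor)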
from __future__ import annotations

import warnings
from typing import Any, Optional, Union, Tuple, cast

import torch
from torchvision.prototype.utils._internal import StrEnum
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes, make_grid

from ._bounding_box import BoundingBox
from ._feature import _Feature
class ColorSpace(StrEnum):
    OTHER = 0
    GRAYSCALE = 1
    RGB = 3
class Image(_Feature):
    color_space: ColorSpace

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        color_space: Optional[Union[ColorSpace, str]] = None,
    ) -> Image:
        image = super().__new__(cls, data, dtype=dtype, device=device)

        if color_space is None:
            color_space = cls.guess_color_space(image)
            if color_space == ColorSpace.OTHER:
                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
        elif isinstance(color_space, str):
            color_space = ColorSpace[color_space]
        image._metadata.update(dict(color_space=color_space))

        return image
    @classmethod
    def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
        tensor = super()._to_tensor(data, dtype=dtype, device=device)
        if tensor.ndim < 2:
            raise ValueError("Images need at least two dimensions.")
        elif tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
        return tensor
    @property
    def image_size(self) -> Tuple[int, int]:
        return cast(Tuple[int, int], self.shape[-2:])

    @property
    def num_channels(self) -> int:
        return self.shape[-3]
    @staticmethod
    def guess_color_space(data: torch.Tensor) -> ColorSpace:
@@ -50,3 +74,13 @@ class Image(Feature):
            return ColorSpace.RGB
        else:
            return ColorSpace.OTHER
def show(self) -> None:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
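A short sketch of the constructor semantics; shapes and values are illustrative:

import torch
from torchvision.prototype.features import ColorSpace, Image

# three channels: the color space is guessed as RGB when not passed explicitly
image = Image(torch.rand(3, 4, 4))
assert image.color_space == ColorSpace.RGB
assert image.image_size == (4, 4)
assert image.num_channels == 3

# a 2-D input is promoted to a single-channel image
gray = Image(torch.rand(4, 4), color_space="GRAYSCALE")
assert gray.shape == (1, 4, 4)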
from __future__ import annotations

from typing import Any, Optional, Sequence, cast

import torch
from torchvision.prototype.utils._internal import apply_recursively

from ._feature import _Feature


class Label(_Feature):
    categories: Optional[Sequence[str]]
    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        like: Optional[Label] = None,
        categories: Optional[Sequence[str]] = None,
    ) -> Label:
        label = super().__new__(cls, data, dtype=dtype, device=device)

        label._metadata.update(dict(categories=categories))

        return label

    @classmethod
    def from_category(cls, category: str, *, categories: Sequence[str]) -> Label:
        return cls(categories.index(category), categories=categories)

    def to_categories(self) -> Any:
        if not self.categories:
            raise RuntimeError("Label does not have categories attached.")

        return apply_recursively(lambda idx: cast(Sequence[str], self.categories)[idx], self.tolist())
class OneHotLabel(_Feature):
    categories: Optional[Sequence[str]]

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        like: Optional[Label] = None,
        categories: Optional[Sequence[str]] = None,
    ) -> OneHotLabel:
        one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)

        if categories is not None and len(categories) != one_hot_label.shape[-1]:
            raise ValueError("The number of categories does not match the size of the last dimension.")

        one_hot_label._metadata.update(dict(categories=categories))

        return one_hot_label
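A hedged usage sketch tying the two label types together; the category names are made up:

import torch
from torchvision.prototype.features import Label, OneHotLabel

categories = ["cat", "dog"]

label = Label.from_category("dog", categories=categories)
assert int(label) == 1
assert label.to_categories() == "dog"  # maps indices back through apply_recursively

# the last dimension must match the number of categories
one_hot = OneHotLabel(torch.tensor([0.0, 1.0]), categories=categories)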
from ._feature import _Feature
class SegmentationMask(_Feature):
pass