Unverified commit d32bc4ba authored by Philip Meier, committed by GitHub

Revamp prototype features and transforms (#5407)

* revamp prototype features (#5283)

* remove decoding from prototype datasets (#5287)

* remove decoder from prototype datasets

* remove unused imports

* cleanup

* fix readme

* use OneHotLabel in SEMEION

* improve voc implementation

* revert unrelated changes

* fix semeion mock data

* fix pcam

* readd functional transforms API to prototype (#5295)

* readd functional transforms

* cleanup

* add missing imports

* remove __torch_function__ dispatch

* readd repr

* readd empty line

* add test for scriptability

* remove function copy

* change import from functional tensor transforms to just functional

* fix import

* fix test

* fix prototype features and functional transforms after review (#5377)

* fix prototype functional transforms after review

* address features review

* make mypy more strict on prototype features

* make mypy more strict for prototype transforms

* fix annotation

* fix kernel tests

* add automatic feature type dispatch to functional transforms (#5323)

* add auto dispatch

* fix missing arguments error message

* remove pil kernel for erase

* automate feature specific parameter detection

* fix typos

* cleanup dispatcher call

* remove __torch_function__ from transform dispatch

* remove auto-generation

* revert unrelated changes

* remove implements decorator

* change register parameter order

* change order of transforms for readability

* add documentation for __torch_function__

* fix mypy

* inline check for support

* refactor kernel registering process

* refactor dispatch to be a regular decorator

* split kernels and dispatchers

* remove sentinels

* replace pass with ...

* appease mypy

* make single kernel dispatchers more concise

* make dispatcher signatures more generic

* make kernel checking more strict

* revert doc changes

* address Francisco's comments

* remove inplace

* rename kernel test module

* fix inplace

* remove special casing for pil and vanilla tensors

* address comments

* update docs

* cleanup features / transforms feature branch (#5406)

* mark candidates for removal

* align signature of resize_bounding_box with corresponding image kernel

* fix documentation of Feature

* remove interpolation mode and antialias option from resize_segmentation_mask

* remove or privatize functionality in features / datasets / transforms
parent f2f490b1
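Taken together, the changes below replace the `decoder`-callable plumbing with feature tensors that carry their own metadata. A minimal sketch of the revamped API, assuming the modules land as shown in this diff (the tensor values are illustrative only):

```python
import torch
from torchvision.prototype.features import BoundingBox, Image, Label

# Datasets now return feature tensors directly instead of accepting a `decoder`
# callable; metadata such as categories travels with the tensor.
image = Image(torch.rand(3, 32, 32))  # color space is guessed as RGB from the 3 channels
label = Label(7, categories=[str(c) for c in range(10)])
bbox = BoundingBox([10, 20, 30, 40], format="xyxy", image_size=(32, 32))
```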
import abc
import functools
import io
import operator
import pathlib
import string
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence
import torch
from torchdata.datapipes.iter import (
@@ -13,24 +12,21 @@ from torchdata.datapipes.iter import (
Mapper,
Zipper,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetType,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
image_buffer_from_array,
Decompressor,
INFINITE_BUFFER_SIZE,
fromfile,
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Image, Label
from torchvision.prototype.utils._internal import fromfile
__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
@@ -105,31 +101,15 @@ class _MNISTBase(Dataset):
def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
return None, None
def _collate_and_decode(
self,
data: Tuple[torch.Tensor, torch.Tensor],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
image, label = data
if decoder is raw:
image = Image(image)
else:
image_buffer = image_buffer_from_array(image.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)])
return dict(image=image, label=label)
return dict(
image=Image(image),
label=Label(label, dtype=torch.int64, categories=self.categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
images_dp, labels_dp = resource_dps
start, stop = self.start_and_stop(config)
@@ -143,14 +123,13 @@ class _MNISTBase(Dataset):
dp = Zipper(images_dp, labels_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))
return Mapper(dp, functools.partial(self._prepare_sample, config=config))
class MNIST(_MNISTBase):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"mnist",
type=DatasetType.RAW,
categories=10,
homepage="http://yann.lecun.com/exdb/mnist",
valid_options=dict(
@@ -183,7 +162,6 @@ class FashionMNIST(MNIST):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fashionmnist",
type=DatasetType.RAW,
categories=(
"T-shirt/top",
"Trouser",
@@ -215,7 +193,6 @@ class KMNIST(MNIST):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"kmnist",
type=DatasetType.RAW,
categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
valid_options=dict(
@@ -236,7 +213,6 @@ class EMNIST(_MNISTBase):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"emnist",
type=DatasetType.RAW,
categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist",
valid_options=dict(
@@ -291,13 +267,7 @@ class EMNIST(_MNISTBase):
46: 9,
}
def _collate_and_decode(
self,
data: Tuple[torch.Tensor, torch.Tensor],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
# In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
# That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
# i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example,
@@ -310,14 +280,10 @@ class EMNIST(_MNISTBase):
image, label = data
label += self._LABEL_OFFSETS.get(int(label), 0)
data = (image, label)
return super()._collate_and_decode(data, config=config, decoder=decoder)
return super()._prepare_sample(data, config=config)
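A hedged sketch of the offset correction above; the only `_LABEL_OFFSETS` entry visible in this excerpt is `46: 9`, so the table here illustrates the mechanism rather than the full mapping:

```python
_LABEL_OFFSETS = {46: 9}  # excerpt; the real table contains more entries

def correct_label(raw_label: int) -> int:
    # Shift the dense raw labels past the gaps left by the merged lowercase letters.
    return raw_label + _LABEL_OFFSETS.get(raw_label, 0)

correct_label(46)  # 55: shifted into the sparse 62-category space
correct_label(0)   # 0: labels without an offset pass through unchanged
```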
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, labels_dp = Demultiplexer(
@@ -327,14 +293,13 @@ class EMNIST(_MNISTBase):
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
return super()._make_datapipe([images_dp, labels_dp], config=config, decoder=decoder)
return super()._make_datapipe([images_dp, labels_dp], config=config)
class QMNIST(_MNISTBase):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"qmnist",
type=DatasetType.RAW,
categories=10,
homepage="https://github.com/facebookresearch/qmnist",
valid_options=dict(
@@ -376,16 +341,10 @@ class QMNIST(_MNISTBase):
return start, stop
def _collate_and_decode(
self,
data: Tuple[torch.Tensor, torch.Tensor],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
image, ann = data
label, *extra_anns = ann
sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)
sample = super()._prepare_sample((image, label), config=config)
sample.update(
dict(
...
import enum
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser
from torchvision.prototype.datasets.utils import (
Dataset,
@@ -12,7 +9,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
@@ -22,7 +18,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_accessor,
path_comparator,
)
from torchvision.prototype.features import Label
from torchvision.prototype.features import Label, EncodedImage
class OxfordIITPetDemux(enum.IntEnum):
@@ -34,7 +30,6 @@ class OxfordIITPet(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"oxford-iiit-pet",
type=DatasetType.IMAGE,
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
valid_options=dict(
split=("trainval", "test"),
@@ -66,18 +61,8 @@ class OxfordIITPet(Dataset):
def _filter_segmentations(self, data: Tuple[str, Any]) -> bool:
return not pathlib.Path(data[0]).name.startswith(".")
def _decode_classification_data(self, data: Dict[str, str]) -> Dict[str, Any]:
label_idx = int(data["label"]) - 1
return dict(
label=Label(label_idx, category=self.info.categories[label_idx]),
species="cat" if data["species"] == "1" else "dog",
)
def _collate_and_decode_sample(
self,
data: Tuple[Tuple[Dict[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
def _prepare_sample(
self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]
) -> Dict[str, Any]:
ann_data, image_data = data
classification_data, segmentation_data = ann_data
@@ -85,19 +70,16 @@ class OxfordIITPet(Dataset):
image_path, image_buffer = image_data
return dict(
self._decode_classification_data(classification_data),
label=Label(int(classification_data["label"]) - 1, categories=self.categories),
species="cat" if classification_data["species"] == "1" else "dog",
segmentation_path=segmentation_path,
segmentation=decoder(segmentation_buffer) if decoder else segmentation_buffer,
segmentation=EncodedImage.from_file(segmentation_buffer),
image_path=image_path,
image=decoder(image_buffer) if decoder else image_buffer,
image=EncodedImage.from_file(image_buffer),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
@@ -137,7 +119,7 @@ class OxfordIITPet(Dataset):
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
...
import io
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator
from typing import Any, Dict, List, Optional, Tuple, Iterator
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.prototype import features
from torchvision.prototype.datasets.utils import (
@@ -10,7 +9,6 @@ from torchvision.prototype.datasets.utils import (
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
GDriveResource,
)
from torchvision.prototype.datasets.utils._internal import (
@@ -46,7 +44,6 @@ class PCAM(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"pcam",
type=DatasetType.RAW,
homepage="https://github.com/basveeling/pcam",
categories=2,
valid_options=dict(split=("train", "test", "val")),
@@ -98,7 +95,7 @@ class PCAM(Dataset):
for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
]
def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
image, target = data # They're both numpy arrays at this point
return {
@@ -107,11 +104,7 @@ class PCAM(Dataset):
}
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
images_dp, targets_dp = resource_dps
@@ -122,4 +115,4 @@ class PCAM(Dataset):
dp = Zipper(images_dp, targets_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode)
return Mapper(dp, self._prepare_sample)
import functools
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
import numpy as np
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
@@ -20,7 +17,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
@@ -31,20 +27,17 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Feature
from torchvision.prototype.features import _Feature, EncodedImage
class SBD(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"sbd",
type=DatasetType.IMAGE,
dependencies=("scipy",),
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
valid_options=dict(
split=("train", "val", "train_noval"),
boundaries=(True, False),
segmentation=(False, True),
),
)
@@ -75,50 +68,21 @@ class SBD(Dataset):
else:
return None
def _decode_ann(
self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
raw_anns = data["GTcls"][0]
raw_boundaries = raw_anns["Boundaries"][0]
raw_segmentation = raw_anns["Segmentation"][0]
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
boundaries = (
Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
if decode_boundaries
else None
)
segmentation = Feature(raw_segmentation) if decode_segmentation else None
return boundaries, segmentation
def _collate_and_decode_sample(
self,
data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
image = decoder(image_buffer) if decoder else image_buffer
if config.boundaries or config.segmentation:
boundaries, segmentation = self._decode_ann(
read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation
)
else:
boundaries = segmentation = None
anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"]
return dict(
image_path=image_path,
image=image,
image=EncodedImage.from_file(image_buffer),
ann_path=ann_path,
boundaries=boundaries,
segmentation=segmentation,
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
segmentation=_Feature(anns["Segmentation"].item()),
)
def _make_datapipe(
@@ -126,7 +90,6 @@ class SBD(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp, extra_split_dp = resource_dps
@@ -138,10 +101,10 @@ class SBD(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
if config.split == "train_noval":
split_dp = extra_split_dp
split_dp = Filter(split_dp, path_comparator("stem", config.split))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
split_dp = hint_shuffling(split_dp)
@@ -155,7 +118,7 @@ class SBD(Dataset):
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
resources = self.resources(self.default_config)
...
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple
import torch
from torchdata.datapipes.iter import (
@@ -8,24 +6,21 @@ from torchdata.datapipes.iter import (
Mapper,
CSVParser,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, OneHotLabel
class SEMEION(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"semeion",
type=DatasetType.RAW,
categories=10,
homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
)
@@ -37,34 +32,22 @@ class SEMEION(Dataset):
)
return [data]
def _collate_and_decode_sample(
self,
data: Tuple[str, ...],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16)
label_data = [int(label) for label in data[256:] if label]
def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
image_data, label_data = data[:256], data[256:-1]
if decoder is raw:
image = Image(image_data.unsqueeze(0))
else:
image_buffer = image_buffer_from_array(image_data.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
return dict(image=image, label=Label(label_idx, category=self.info.categories[label_idx]))
return dict(
image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)),
label=OneHotLabel([int(label) for label in label_data], categories=self.categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVParser(dp, delimiter=" ")
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return dp
return Mapper(dp, self._prepare_sample)
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple, BinaryIO
import numpy as np
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
UnBatcher,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
read_mat,
hint_sharding,
hint_shuffling,
image_buffer_from_array,
)
from torchvision.prototype.features import Label, Image
@@ -31,7 +25,6 @@ class SVHN(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"svhn",
type=DatasetType.RAW,
dependencies=("scipy",),
categories=10,
homepage="http://ufldl.stanford.edu/housenumbers/",
@@ -52,7 +45,7 @@ class SVHN(Dataset):
return [data]
def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np.ndarray, np.ndarray]]:
def _read_images_and_labels(self, data: Tuple[str, BinaryIO]) -> List[Tuple[np.ndarray, np.ndarray]]:
_, buffer = data
content = read_mat(buffer)
return list(
@@ -62,23 +55,12 @@ class SVHN(Dataset):
)
)
def _collate_and_decode_sample(
self,
data: Tuple[np.ndarray, np.ndarray],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]:
image_array, label_array = data
if decoder is raw:
image = Image(image_array.transpose((2, 0, 1)))
else:
image_buffer = image_buffer_from_array(image_array)
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
return dict(
image=image,
label=Label(int(label_array) % 10),
image=Image(image_array.transpose((2, 0, 1))),
label=Label(int(label_array) % 10, categories=self.categories),
)
def _make_datapipe(
@@ -86,11 +68,10 @@ class SVHN(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Mapper(dp, self._read_images_and_labels)
dp = UnBatcher(dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable
from xml.etree import ElementTree
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
@@ -20,7 +18,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
path_accessor,
@@ -30,7 +27,7 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import BoundingBox
from torchvision.prototype.features import BoundingBox, Label, EncodedImage
class VOCDatasetInfo(DatasetInfo):
@@ -50,7 +47,6 @@ class VOC(Dataset):
def _make_info(self) -> DatasetInfo:
return VOCDatasetInfo(
"voc",
type=DatasetType.IMAGE,
homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
valid_options=dict(
split=("train", "val", "trainval", "test"),
@@ -99,40 +95,52 @@ class VOC(Dataset):
else:
return None
def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor:
result = VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot()) # type: ignore[arg-type]
objects = result["annotation"]["object"]
bboxes = [obj["bndbox"] for obj in objects]
bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes]
return BoundingBox(bboxes)
def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
anns = self._parse_detection_ann(buffer)
instances = anns["object"]
return dict(
bounding_boxes=BoundingBox(
[
[int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for instance in instances
],
format="xyxy",
image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
),
labels=Label(
[self.categories.index(instance["name"]) for instance in instances], categories=self.categories
),
)
def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return dict(segmentation=EncodedImage.from_file(buffer))
def _collate_and_decode_sample(
def _prepare_sample(
self,
data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
image = decoder(image_buffer) if decoder else image_buffer
if config.task == "detection":
ann = self._decode_detection_ann(ann_buffer)
else: # config.task == "segmentation":
ann = decoder(ann_buffer) if decoder else ann_buffer # type: ignore[assignment]
return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann)
return dict(
prepare_ann_fn(ann_buffer),
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
ann_path=ann_path,
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
split_dp, images_dp, anns_dp = Demultiplexer(
@@ -158,4 +166,25 @@
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
return Mapper(
dp,
functools.partial(
self._prepare_sample,
prepare_ann_fn=self._prepare_detection_ann
if config.task == "detection"
else self._prepare_segmentation_ann,
),
)
def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
return self._classify_archive(data, config=config) == 2
def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.info.make_config(task="detection")
resource = self.resources(config)[0]
dp = resource.load(pathlib.Path(root) / self.name)
dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config))
dp = Mapper(dp, self._parse_detection_ann, input_col=1)
return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
import functools
import io
import os
import os.path
import pathlib
from typing import Callable, Optional, Collection
from typing import Union, Tuple, List, Dict, Any
from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any
import torch
from torchdata.datapipes.iter import IterDataPipe, FileLister, FileOpener, Mapper, Shuffler, Filter
from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding
from torchdata.datapipes.iter import IterDataPipe, FileLister, Mapper, Filter, FileOpener
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Label, EncodedImage, EncodedData
__all__ = ["from_data_folder", "from_image_folder"]
@@ -20,29 +17,24 @@ def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")
def _collate_and_decode_data(
data: Tuple[str, io.IOBase],
def _prepare_sample(
data: Tuple[str, BinaryIO],
*,
root: pathlib.Path,
categories: List[str],
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
path, buffer = data
data = decoder(buffer) if decoder else buffer
category = pathlib.Path(path).relative_to(root).parts[0]
label = torch.tensor(categories.index(category))
return dict(
path=path,
data=data,
label=label,
category=category,
data=EncodedData.from_file(buffer),
label=Label.from_category(category, categories=categories),
)
def from_data_folder(
root: Union[str, pathlib.Path],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
valid_extensions: Optional[Collection[str]] = None,
recursive: bool = True,
) -> Tuple[IterDataPipe, List[str]]:
@@ -52,26 +44,22 @@ def from_data_folder(
dp = FileLister(str(root), recursive=recursive, masks=masks)
dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_shuffling(dp)
dp = FileOpener(dp, mode="rb")
return (
Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
categories,
)
return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories
def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["image"] = sample.pop("data")
sample["image"] = EncodedImage(sample.pop("data").data)
return sample
def from_image_folder(
root: Union[str, pathlib.Path],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
**kwargs: Any,
) -> Tuple[IterDataPipe, List[str]]:
valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())]
dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs)
dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs)
return Mapper(dp, _data_to_image_key), categories
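Hypothetical usage of the reworked folder helpers; the root path is a placeholder for a directory laid out as `root/<category>/<files>`, and the import path is an assumption based on this module's `__all__`:

```python
from torchvision.prototype.datasets import from_image_folder  # assumed import path

dp, categories = from_image_folder("path/to/root")  # placeholder root directory
for sample in dp:
    # `image` is now an EncodedImage holding the raw bytes; decoding is deferred.
    print(sample["path"], sample["category"], sample["label"], sample["image"].image_size)
    break
```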
import io
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms.functional import pil_to_tensor
__all__ = ["raw", "pil"]
def raw(buffer: io.IOBase) -> torch.Tensor:
raise RuntimeError("This is just a sentinel and should never be called.")
def pil(buffer: io.IOBase) -> features.Image:
return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from . import _internal # usort: skip
from ._dataset import DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
import abc
import csv
import enum
import importlib
import io
import itertools
import os
import pathlib
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple, Collection
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection
import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr
from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str
from .._home import use_sharded_dataset
from ._internal import BUILTIN_DIR, _make_sharded_datapipe
from ._resource import OnlineResource
class DatasetType(enum.Enum):
RAW = enum.auto()
IMAGE = enum.auto()
class DatasetConfig(FrozenBunch):
# This needs to be Frozen because we often pass configs as partial(func, config=config)
# and partial() requires the parameters to be hashable.
@@ -34,7 +25,6 @@ class DatasetInfo:
self,
name: str,
*,
type: Union[str, DatasetType],
dependencies: Collection[str] = (),
categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
citation: Optional[str] = None,
@@ -44,7 +34,6 @@
extra: Optional[Dict[str, Any]] = None,
) -> None:
self.name = name.lower()
self.type = DatasetType[type.upper()] if isinstance(type, str) else type
self.dependencies = dependencies
@@ -163,7 +152,6 @@ class Dataset(abc.ABC):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
pass
@@ -175,7 +163,6 @@ class Dataset(abc.ABC):
root: Union[str, pathlib.Path],
*,
config: Optional[DatasetConfig] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
skip_integrity_check: bool = False,
) -> IterDataPipe[Dict[str, Any]]:
if not config:
@@ -190,7 +177,7 @@ class Dataset(abc.ABC):
resource_dps = [
resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
]
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
return self._make_datapipe(resource_dps, config=config)
def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
raise NotImplementedError
import enum
import functools
import gzip
import io
import lzma
import mmap
import os
import os.path
import pathlib
import pickle
import platform
from typing import BinaryIO
from typing import (
Sequence,
@@ -25,27 +22,24 @@ from typing import (
)
from typing import cast
import numpy as np
import PIL.Image
import torch
import torch.distributed as dist
import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision.prototype.utils._internal import fromfile
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"read_mat",
"image_buffer_from_array",
"MappingIterator",
"Enumerator",
"getitem",
"path_accessor",
"path_comparator",
"Decompressor",
"fromfile",
"read_flo",
"hint_sharding",
]
@@ -59,7 +53,7 @@ INFINITE_BUFFER_SIZE = 1_000_000_000
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
try:
import scipy.io as sio
except ImportError as error:
@@ -71,14 +65,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
return sio.loadmat(buffer, **kwargs)
def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
image = PIL.Image.fromarray(array)
buffer = io.BytesIO()
image.save(buffer, format=format)
buffer.seek(0)
return buffer
class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
self.datapipe = datapipe
@@ -142,17 +128,17 @@ class CompressionType(enum.Enum):
LZMA = "lzma"
class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]):
types = CompressionType
_DECOMPRESSORS = {
types.GZIP: lambda file: gzip.GzipFile(fileobj=file),
types.LZMA: lambda file: lzma.LZMAFile(file),
_DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], BinaryIO]] = {
types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)),
types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)),
}
def __init__(
self,
datapipe: IterDataPipe[Tuple[str, io.IOBase]],
datapipe: IterDataPipe[Tuple[str, BinaryIO]],
*,
type: Optional[Union[str, CompressionType]] = None,
) -> None:
@@ -174,7 +160,7 @@ class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
else:
raise RuntimeError("FIXME")
def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]:
for path, file in self.datapipe:
type = self._detect_compression_type(path)
decompressor = self._DECOMPRESSORS[type]
@@ -257,69 +243,6 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[st
return dp
def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
return bytearray(file.read(-1 if count == -1 else count * item_size))
def fromfile(
file: BinaryIO,
*,
dtype: torch.dtype,
byte_order: str,
count: int = -1,
) -> torch.Tensor:
"""Construct a tensor from a binary file.
.. note::
This function is similar to :func:`numpy.fromfile` with two notable differences:
1. This function only accepts an open binary file, but not a path to it.
2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``s do not support that
concept.
.. note::
If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
Args:
file (IO): Open binary file.
dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
byte_order (str): Byte order of the data. Can be "little" or "big" endian.
count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
"""
byte_order = "<" if byte_order == "little" else ">"
char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
np_dtype = byte_order + char + str(item_size)
buffer: Union[memoryview, bytearray]
if platform.system() != "Windows":
# PyTorch does not support tensors with underlying read-only memory. In case
# - the file has a .fileno(),
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
# - the file is seekable
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
# to a mutable location afterwards.
try:
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
except (PermissionError, io.UnsupportedOperation):
buffer = _read_mutable_buffer_fallback(file, count, item_size)
else:
# On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
# so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
buffer = _read_mutable_buffer_fallback(file, count, item_size)
# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
# successive .astype() call.
return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
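A minimal usage sketch for `fromfile`, whose new home per this diff is `torchvision.prototype.utils._internal`; the filename is a placeholder for an MNIST-style IDX file with big-endian ``uint8`` pixel data after a 16-byte header:

```python
import torch
from torchvision.prototype.utils._internal import fromfile  # new location per this diff

with open("train-images-idx3-ubyte", "rb") as file:  # placeholder path
    file.seek(16)  # skip the 16-byte IDX header
    pixels = fromfile(file, dtype=torch.uint8, byte_order="big")
```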
def read_flo(file: BinaryIO) -> torch.Tensor:
if file.read(4) != b"PIEH":
raise ValueError("Magic number incorrect. Invalid .flo file")
@@ -329,9 +252,9 @@ def read_flo(file: BinaryIO) -> torch.Tensor:
return flow.reshape((height, width, 2)).permute((2, 0, 1))
def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
return ShardingFilter(datapipe)
def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE)
from ._bounding_box import BoundingBoxFormat, BoundingBox
from ._feature import Feature, DEFAULT
from ._image import Image, ColorSpace
from ._label import Label
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature
from ._image import ColorSpace, Image
from ._label import Label, OneHotLabel
from ._segmentation_mask import SegmentationMask
import enum
import functools
from typing import Callable, Union, Tuple, Dict, Any, Optional, cast
from __future__ import annotations
from typing import Any, Tuple, Union, Optional
import torch
from torchvision.prototype.utils._internal import StrEnum
from ._feature import Feature, DEFAULT
from ._feature import _Feature
class BoundingBoxFormat(StrEnum):
# this is just for test purposes
_SENTINEL = -1
XYXY = enum.auto()
XYWH = enum.auto()
CXCYWH = enum.auto()
def to_parts(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return input.unbind(dim=-1) # type: ignore[return-value]
def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
return torch.stack((a, b, c, d), dim=-1)
def format_converter_wrapper(
part_converter: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]
):
def wrapper(input: torch.Tensor) -> torch.Tensor:
return from_parts(*part_converter(*to_parts(input)))
return wrapper
@format_converter_wrapper
def xywh_to_xyxy(
x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x1 = x
y1 = y
x2 = x + w
y2 = y + h
return x1, y1, x2, y2
@format_converter_wrapper
def xyxy_to_xywh(
x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x = x1
y = y1
w = x2 - x1
h = y2 - y1
return x, y, w, h
XYXY = StrEnum.auto()
XYWH = StrEnum.auto()
CXCYWH = StrEnum.auto()
@format_converter_wrapper
def cxcywh_to_xyxy(
cx: torch.Tensor, cy: torch.Tensor, w: torch.Tensor, h: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x1 = cx - 0.5 * w
y1 = cy - 0.5 * h
x2 = cx + 0.5 * w
y2 = cy + 0.5 * h
return x1, y1, x2, y2
@format_converter_wrapper
def xyxy_to_cxcywh(
x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
return cx, cy, w, h
class BoundingBox(Feature):
formats = BoundingBoxFormat
class BoundingBox(_Feature):
format: BoundingBoxFormat
image_size: Tuple[int, int]
@classmethod
def _parse_meta_data(
def __new__(
cls,
format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment]
image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
format: Union[BoundingBoxFormat, str],
image_size: Tuple[int, int],
) -> BoundingBox:
bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
if isinstance(format, str):
format = BoundingBoxFormat[format]
format_fallback = BoundingBoxFormat.XYXY
return dict(
format=(format, format_fallback),
image_size=(image_size, functools.partial(cls.guess_image_size, format=format_fallback)),
)
_TO_XYXY_MAP = {
BoundingBoxFormat.XYWH: xywh_to_xyxy,
BoundingBoxFormat.CXCYWH: cxcywh_to_xyxy,
}
_FROM_XYXY_MAP = {
BoundingBoxFormat.XYWH: xyxy_to_xywh,
BoundingBoxFormat.CXCYWH: xyxy_to_cxcywh,
}
bounding_box._metadata.update(dict(format=format, image_size=image_size))
@classmethod
def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]:
if format not in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH):
if format != BoundingBoxFormat.XYXY:
data = cls._TO_XYXY_MAP[format](data)
data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data)
*_, w, h = to_parts(data)
if data.dtype.is_floating_point:
w = w.ceil()
h = h.ceil()
return int(h.max()), int(w.max())
return bounding_box
@classmethod
def from_parts(
cls,
a,
b,
c,
d,
*,
like: Optional["BoundingBox"] = None,
format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment]
image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment]
) -> "BoundingBox":
return cls(from_parts(a, b, c, d), like=like, image_size=image_size, format=format)
def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return to_parts(self)
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.kernels import convert_bounding_box_format
def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
if isinstance(format, str):
format = BoundingBoxFormat[format]
if format == self.format:
return cast(BoundingBox, self.clone())
data = self
if self.format != BoundingBoxFormat.XYXY:
data = self._TO_XYXY_MAP[self.format](data)
if format != BoundingBoxFormat.XYXY:
data = self._FROM_XYXY_MAP[format](data)
return BoundingBox(data, like=self, format=format)
return BoundingBox.new_like(
self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
)
import os
import sys
from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any
import PIL.Image
import torch
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from ._feature import _Feature
from ._image import Image
D = TypeVar("D", bound="EncodedData")
class EncodedData(_Feature):
@classmethod
def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
# TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
return super()._to_tensor(data, dtype=dtype, device=device)
@classmethod
def from_file(cls: Type[D], file: BinaryIO) -> D:
return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
@classmethod
def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D:
with open(path, "rb") as file:
return cls.from_file(file)
class EncodedImage(EncodedData):
# TODO: Use @functools.cached_property if we can depend on Python 3.8
@property
def image_size(self) -> Tuple[int, int]:
if not hasattr(self, "_image_size"):
with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
self._image_size = image.height, image.width
return self._image_size
def decode(self) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.kernels import decode_image_with_pil
return Image(decode_image_with_pil(self))
class EncodedVideo(EncodedData):
pass
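A hedged usage sketch for the encoded features; `"sample.jpg"` is a placeholder path:

```python
from torchvision.prototype.features import EncodedImage

img = EncodedImage.from_path("sample.jpg")  # placeholder path
img.dtype, img.shape    # torch.uint8, (num_bytes,): the raw, undecoded file contents
img.image_size          # (height, width), read lazily from the header via PIL
decoded = img.decode()  # a prototype Image tensor of shape (C, H, W)
```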
from typing import Tuple, cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence
from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
import torch
from torch._C import _TensorBase, DisableTorchFunction
from torchvision.prototype.utils._internal import add_suggestion
F = TypeVar("F", bound="Feature")
F = TypeVar("F", bound="_Feature")
DEFAULT = object()
class Feature(torch.Tensor):
class _Feature(torch.Tensor):
_META_ATTRS: Set[str] = set()
_meta_data: Dict[str, Any]
def __init_subclass__(cls):
# In order to help static type checkers, we require subclasses of `Feature` add the meta data attributes
# as static class annotations:
#
# >>> class Foo(Feature):
# ... bar: str
# ... baz: Optional[str]
#
# Internally, this information is used twofold:
#
# 1. A class annotation is contained in `cls.__annotations__` but not in `cls.__dict__`. We use this difference
# to automatically detect the meta data attributes and expose them as `@property`'s for convenient runtime
# access. This happens in this method.
# 2. The information extracted in 1. is also used at creation (`__new__`) to perform an input parsing for
# unknown arguments.
_metadata: Dict[str, Any]
def __init_subclass__(cls) -> None:
"""
For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
By adding the metadata attributes as class annotations on subclasses of :class:`_Feature`, this method adds
properties so that they can be accessed as conveniently as regular attributes.
>>> class Foo(_Feature):
... bar: str
... baz: Optional[str]
>>> foo = Foo()
>>> foo.bar
>>> foo.baz
This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
"""
meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
for super_cls in cls.__mro__[1:]:
if super_cls is Feature:
if super_cls is _Feature:
break
meta_attrs.update(super_cls._META_ATTRS)
meta_attrs.update(cast(Type[_Feature], super_cls)._META_ATTRS)
cls._META_ATTRS = meta_attrs
for attr in meta_attrs:
setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr]))
def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs):
unknown_meta_attrs = kwargs.keys() - cls._META_ATTRS
if unknown_meta_attrs:
unknown_meta_attr = sorted(unknown_meta_attrs)[0]
raise TypeError(
add_suggestion(
f"{cls.__name__}() got unexpected keyword '{unknown_meta_attr}'.",
word=unknown_meta_attr,
possibilities=sorted(cls._META_ATTRS),
)
)
if like is not None:
dtype = dtype or like.dtype
device = device or like.device
data = cls._to_tensor(data, dtype=dtype, device=device)
requires_grad = False
self = torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)
meta_data = dict()
for attr, (explicit, fallback) in cls._parse_meta_data(**kwargs).items():
if explicit is not DEFAULT:
value = explicit
elif like is not None:
value = getattr(like, attr)
else:
value = fallback(data) if callable(fallback) else fallback
meta_data[attr] = value
self._meta_data = meta_data
return self
for name in meta_attrs:
setattr(cls, name, property(cast(Callable[[F], Any], lambda self, name=name: self._metadata[name])))
def __new__(
cls: Type[F],
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
) -> F:
if isinstance(device, str):
device = torch.device(device)
feature = cast(
F,
torch.Tensor._make_subclass(
cast(_TensorBase, cls),
cls._to_tensor(data, dtype=dtype, device=device),
# requires_grad
False,
),
)
feature._metadata = dict()
return feature
@classmethod
def _to_tensor(cls, data, *, dtype, device):
def _to_tensor(self, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
return torch.as_tensor(data, dtype=dtype, device=device)
@classmethod
def _parse_meta_data(cls) -> Dict[str, Tuple[Any, Any]]:
return dict()
def new_like(
cls: Type[F],
other: F,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
**metadata: Any,
) -> F:
_metadata = other._metadata.copy()
_metadata.update(metadata)
return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)
@classmethod
def __torch_function__(
@@ -89,12 +84,37 @@ class Feature(torch.Tensor):
args: Sequence[Any] = (),
kwargs: Optional[Mapping[str, Any]] = None,
) -> torch.Tensor:
"""For general information about how the __torch_function__ protocol works,
see https://pytorch.org/docs/stable/notes/extending.html#extending-torch
TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call.
The default behavior of :class:`~torch.Tensor` subclasses is to retain the custom tensor type. For the
:class:`_Feature` use case, this has two downsides:
1. Since some :class:`_Feature`'s require metadata to be constructed, the default wrapping, i.e.
``return cls(func(*args, **kwargs))``, will fail for them.
2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators.
Exceptions to this are:
- :func:`torch.clone`
- :meth:`torch.Tensor.to`
"""
kwargs = kwargs or dict()
with DisableTorchFunction():
output = func(*args, **(kwargs or dict()))
if func is not torch.Tensor.clone:
return output
output = func(*args, **kwargs)
return cls(output, like=args[0])
if func is torch.Tensor.clone:
return cls.new_like(args[0], output)
elif func is torch.Tensor.to:
return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
else:
return output
def __repr__(self) -> str:
return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__)
return cast(str, torch.Tensor.__repr__(self)).replace("tensor", type(self).__name__)
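A sketch of the wrapping rules described in the docstring above, assuming the `Label` feature defined later in this diff:

```python
import torch
from torchvision.prototype.features import Label

label = Label(torch.tensor([0, 1]), categories=["cat", "dog"])
label.clone()          # still a Label; the metadata is copied via new_like()
label.to(torch.int32)  # still a Label; dtype is updated alongside the copied metadata
label + 1              # any other operator: no automatic re-wrapping of the output
```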
from typing import Dict, Any, Union, Tuple
from __future__ import annotations
import warnings
from typing import Any, Optional, Union, Tuple, cast
import torch
from torchvision.prototype.utils._internal import StrEnum
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import make_grid
from ._feature import Feature, DEFAULT
from ._bounding_box import BoundingBox
from ._feature import _Feature
class ColorSpace(StrEnum):
# this is just for test purposes
_SENTINEL = -1
OTHER = 0
GRAYSCALE = 1
RGB = 3
class Image(Feature):
color_spaces = ColorSpace
class Image(_Feature):
color_space: ColorSpace
def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
color_space: Optional[Union[ColorSpace, str]] = None,
) -> Image:
image = super().__new__(cls, data, dtype=dtype, device=device)
if color_space is None:
color_space = cls.guess_color_space(image)
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace[color_space]
image._metadata.update(dict(color_space=color_space))
return image
@classmethod
def _to_tensor(cls, data, *, dtype, device):
tensor = torch.as_tensor(data, dtype=dtype, device=device)
if tensor.ndim == 2:
def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
tensor = super()._to_tensor(data, dtype=dtype, device=device)
if tensor.ndim < 2:
raise ValueError
elif tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 3:
raise ValueError("Only single images with 2 or 3 dimensions are allowed.")
return tensor
@classmethod
def _parse_meta_data(
cls,
color_space: Union[str, ColorSpace] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
if isinstance(color_space, str):
color_space = ColorSpace[color_space]
return dict(color_space=(color_space, cls.guess_color_space))
@property
def image_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], self.shape[-2:])
@property
def num_channels(self) -> int:
return self.shape[-3]
@staticmethod
def guess_color_space(data: torch.Tensor) -> ColorSpace:
@@ -50,3 +74,13 @@ class Image(Feature):
return ColorSpace.RGB
else:
return ColorSpace.OTHER
def show(self) -> None:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
from typing import Dict, Any, Optional, Tuple
from __future__ import annotations
from ._feature import Feature, DEFAULT
from typing import Any, Optional, Sequence, cast
import torch
from torchvision.prototype.utils._internal import apply_recursively
class Label(Feature):
category: Optional[str]
from ._feature import _Feature
class Label(_Feature):
categories: Optional[Sequence[str]]
def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
like: Optional[Label] = None,
categories: Optional[Sequence[str]] = None,
) -> Label:
label = super().__new__(cls, data, dtype=dtype, device=device)
label._metadata.update(dict(categories=categories))
return label
@classmethod
def _parse_meta_data(
def from_category(cls, category: str, *, categories: Sequence[str]) -> Label:
return cls(categories.index(category), categories=categories)
def to_categories(self) -> Any:
if not self.categories:
raise RuntimeError()
return apply_recursively(lambda idx: cast(Sequence[str], self.categories)[idx], self.tolist())
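A small sketch of the category helpers defined above:

```python
from torchvision.prototype.features import Label

label = Label.from_category("dog", categories=["cat", "dog"])  # data == 1
label.to_categories()  # "dog"; applied recursively for batched labels
```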
class OneHotLabel(_Feature):
categories: Optional[Sequence[str]]
def __new__(
cls,
category: Optional[str] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
return dict(category=(category, None))
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
like: Optional[Label] = None,
categories: Optional[Sequence[str]] = None,
) -> OneHotLabel:
one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)
if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError()
one_hot_label._metadata.update(dict(categories=categories))
return one_hot_label
from ._feature import _Feature
class SegmentationMask(_Feature):
pass