Unverified Commit a9d25721 authored by Philip Meier, committed by GitHub

[PoC] compatibility layer between stable datasets and prototype transforms (#6663)

parent 17088a68
...@@ -25,6 +25,7 @@ import torch
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torchvision.transforms.functional import get_dimensions
...@@ -581,6 +582,28 @@ class DatasetTestCase(unittest.TestCase):
mock.assert_called()
@test_all_configs
def test_transforms_v2_wrapper(self, config):
# Although this is a test for the stable API, we unconditionally import from `torchvision.prototype` here. The
# wrapper needs to ship with the next release together with v2. Thus, if this import somehow fails on the release
# branch, the roll-out went wrong.
from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
from torchvision.prototype.datapoints._datapoint import Datapoint
try:
with self.create_dataset(config) as (dataset, _):
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
return
raise error
except RuntimeError as error:
if "currently not supported by this wrapper" in str(error):
return
raise error
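# Note (illustrative, not part of the test suite): `tree_any` flattens arbitrarily nested containers and checks the
# predicate against every leaf, e.g. something along the lines of
#
#   tree_any(lambda leaf: isinstance(leaf, int), ("a", {"key": [1.0, 2]}))  # True, since the leaf `2` matches
#
# This is why it is a convenient way to assert that at least one item of the wrapped sample is a `Datapoint` or a
# PIL image.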
class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
...@@ -662,6 +685,15 @@ class VideoDatasetTestCase(DatasetTestCase):
return wrapper
@test_all_configs
def test_transforms_v2_wrapper(self, config):
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the test if the config sets it
# explicitly and otherwise default to the supported `"TCHW"`.
if config.setdefault("output_format", "TCHW") == "THWC":
return
super().test_transforms_v2_wrapper.__wrapped__(self, config)
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
...
...@@ -763,11 +763,19 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
return info
def _create_annotations(self, image_ids, num_annotations_per_image):
    annotations = []
    annotation_id = 0
    for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
        annotations.append(
            dict(
                image_id=image_id,
                id=annotation_id,
                bbox=torch.rand(4).tolist(),
                segmentation=[torch.rand(8).tolist()],
                category_id=int(torch.randint(91, ())),
            )
        )
        annotation_id += 1
    return annotations, dict()
def _create_json(self, root, name, content):
...
...@@ -4,3 +4,5 @@ from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
# type: ignore
from __future__ import annotations
import contextlib
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F
__all__ = ["wrap_dataset_for_transforms_v2"]
# TODO: naming!
def wrap_dataset_for_transforms_v2(dataset):
return VisionDatasetDatapointWrapper(dataset)
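# Usage sketch (illustrative only; the image and annotation paths below are placeholders):
#
#   from torchvision import datasets
#   from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
#
#   dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
#   dataset = wrap_dataset_for_transforms_v2(dataset)
#   image, target = dataset[0]
#
# For `CocoDetection`, `target["boxes"]` is a `datapoints.BoundingBox`, `target["masks"]` is a `datapoints.Mask`,
# and `target["labels"]` is a plain tensor, so the sample can be consumed directly by the v2 transforms.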
class WrapperFactories(dict):
def register(self, dataset_cls):
def decorator(wrapper_factory):
self[dataset_cls] = wrapper_factory
return wrapper_factory
return decorator
# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on
# the dataset instance rather than just the class, because they require user-defined instance attributes. Thus, we
# can register a mapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime
# when we have access to the dataset instance. A sketch of the registration pattern follows the registry definition
# below.
WRAPPER_FACTORIES = WrapperFactories()
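# Registration sketch (illustrative only; `MyDataset` is a hypothetical user-defined dataset, not part of
# torchvision): a factory receives the dataset instance and returns a callable that wraps a single raw sample, e.g.
#
#   @WRAPPER_FACTORIES.register(MyDataset)
#   def my_dataset_wrapper_factory(dataset):
#       def wrapper(sample):
#           image, mask = sample
#           return image, datapoints.Mask(F.to_image_tensor(mask).squeeze(0))
#
#       return wrapper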
class VisionDatasetDatapointWrapper(Dataset):
def __init__(self, dataset):
dataset_cls = type(dataset)
wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls)
if wrapper_factory is None:
# TODO: If we have documentation on how to do that, put a link in the error message.
msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
if dataset_cls in datasets.__dict__.values():
msg = (
f"{msg} If an automated wrapper for this dataset would be useful for you, "
f"please open an issue at https://github.com/pytorch/vision/issues."
)
raise TypeError(msg)
self._dataset = dataset
self._wrapper = wrapper_factory(dataset)
# We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
# Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
# `transforms`
# https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
# some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
# disable all three here to be able to extract the untransformed sample to wrap.
self.transform, dataset.transform = dataset.transform, None
self.target_transform, dataset.target_transform = dataset.target_transform, None
self.transforms, dataset.transforms = dataset.transforms, None
def __getattr__(self, item):
with contextlib.suppress(AttributeError):
return object.__getattribute__(self, item)
return getattr(self._dataset, item)
def __getitem__(self, idx):
# This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
# of this class
sample = self._dataset[idx]
sample = self._wrapper(sample)
# Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
# or jointly (`transforms`), we can access the full functionality through `self.transforms`.
if self.transforms is not None:
sample = self.transforms(*sample)
return sample
def __len__(self):
return len(self._dataset)
def raise_not_supported(description):
raise RuntimeError(
f"{description} is currently not supported by this wrapper. "
f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
)
def identity(item):
return item
def pil_image_to_mask(pil_image):
return datapoints.Mask(F.to_image_tensor(pil_image).squeeze(0))
def list_of_dicts_to_dict_of_lists(list_of_dicts):
dict_of_lists = defaultdict(list)
for dct in list_of_dicts:
for key, value in dct.items():
dict_of_lists[key].append(value)
return dict(dict_of_lists)
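# For illustration: list_of_dicts_to_dict_of_lists([{"id": 0, "area": 1.0}, {"id": 1, "area": 2.0}]) evaluates to
# {"id": [0, 1], "area": [1.0, 2.0]}.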
def wrap_target_by_type(target, *, target_types, type_wrappers):
if not isinstance(target, (tuple, list)):
target = [target]
wrapped_target = tuple(
type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
)
if len(wrapped_target) == 1:
wrapped_target = wrapped_target[0]
return wrapped_target
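# For illustration (with hypothetical `bbox_item` / `attr_item` values): given `target_types=["bbox", "attr"]` and
# `type_wrappers={"bbox": some_wrapper}`, the target `(bbox_item, attr_item)` becomes
# `(some_wrapper(bbox_item), attr_item)`, since target types without a registered wrapper fall through `identity`.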
def classification_wrapper_factory(dataset):
return identity
for dataset_cls in [
datasets.Caltech256,
datasets.CIFAR10,
datasets.CIFAR100,
datasets.ImageNet,
datasets.MNIST,
datasets.FashionMNIST,
datasets.GTSRB,
datasets.DatasetFolder,
datasets.ImageFolder,
]:
WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
def segmentation_wrapper_factory(dataset):
def wrapper(sample):
image, mask = sample
return image, pil_image_to_mask(mask)
return wrapper
for dataset_cls in [
datasets.VOCSegmentation,
]:
WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)
def video_classification_wrapper_factory(dataset):
if dataset.video_clips.output_format == "THWC":
raise RuntimeError(
f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
)
def wrapper(sample):
video, audio, label = sample
video = datapoints.Video(video)
return video, audio, label
return wrapper
for dataset_cls in [
datasets.HMDB51,
datasets.Kinetics,
datasets.UCF101,
]:
WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)
@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset):
if "annotation" in dataset.target_type:
raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
return classification_wrapper_factory(dataset)
@WRAPPER_FACTORIES.register(datasets.CocoDetection)
def coco_detection_wrapper_factory(dataset):
def segmentation_to_mask(segmentation, *, spatial_size):
from pycocotools import mask
segmentation = (
mask.frPyObjects(segmentation, *spatial_size)
if isinstance(segmentation, dict)
else mask.merge(mask.frPyObjects(segmentation, *spatial_size))
)
return torch.from_numpy(mask.decode(segmentation))
def wrapper(sample):
image, target = sample
batched_target = list_of_dicts_to_dict_of_lists(target)
spatial_size = tuple(F.get_spatial_size(image))
batched_target["boxes"] = datapoints.BoundingBox(
batched_target["bbox"],
# COCO annotations store boxes in XYWH format
format=datapoints.BoundingBoxFormat.XYWH,
spatial_size=spatial_size,
)
batched_target["masks"] = datapoints.Mask(
torch.stack(
[
segmentation_to_mask(segmentation, spatial_size=spatial_size)
for segmentation in batched_target["segmentation"]
]
),
)
batched_target["labels"] = torch.tensor(batched_target["category_id"])
return image, batched_target
return wrapper
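# For illustration: after wrapping, a `CocoDetection` sample is `(image, batched_target)`, where `batched_target`
# still contains the per-annotation lists from the original dicts (e.g. "image_id", "bbox", "segmentation",
# "category_id") plus the newly added "boxes" (`datapoints.BoundingBox`), "masks" (`datapoints.Mask`), and
# "labels" (tensor) entries.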
VOC_DETECTION_CATEGORIES = [
"__background__",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))
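# For illustration: VOC_DETECTION_CATEGORY_TO_IDX["__background__"] == 0 and
# VOC_DETECTION_CATEGORY_TO_IDX["aeroplane"] == 1.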
@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset):
def wrapper(sample):
image, target = sample
batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
target["boxes"] = datapoints.BoundingBox(
[
[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for bndbox in batched_instances["bndbox"]
],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(image.height, image.width),
)
target["labels"] = torch.tensor(
[VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
)
return image, target
return wrapper
@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset):
if dataset.mode == "boundaries":
raise_not_supported("SBDataset with mode='boundaries'")
return segmentation_wrapper_factory(dataset)
@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset):
if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
def wrapper(sample):
image, target = sample
target = wrap_target_by_type(
target,
target_types=dataset.target_type,
type_wrappers={
"bbox": lambda item: datapoints.BoundingBox(
item, format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
),
},
)
return image, target
return wrapper
KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))
@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset):
def wrapper(sample):
image, target = sample
if target is not None:
target = list_of_dicts_to_dict_of_lists(target)
target["boxes"] = datapoints.BoundingBox(
target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width)
)
target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]])
return image, target
return wrapper
@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factory(dataset):
def wrapper(sample):
image, target = sample
if target is not None:
target = wrap_target_by_type(
target,
target_types=dataset._target_types,
type_wrappers={
"segmentation": pil_image_to_mask,
},
)
return image, target
return wrapper
@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset):
if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")
def instance_segmentation_wrapper(mask):
# See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
data = pil_image_to_mask(mask)
masks = []
labels = []
for value in data.unique():
    masks.append(data == value)
    label = value
    # Instance pixels encode `category_id * 1000 + instance_id`, so recover the category id here; smaller values
    # already are plain category ids.
    if label >= 1_000:
        label //= 1_000
    labels.append(label)
return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
def wrapper(sample):
image, target = sample
target = wrap_target_by_type(
target,
target_types=dataset.target_type,
type_wrappers={
"instance": instance_segmentation_wrapper,
"semantic": pil_image_to_mask,
},
)
return image, target
return wrapper
@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset):
def wrapper(sample):
image, target = sample
if target is not None:
target["bbox"] = datapoints.BoundingBox(
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
)
return image, target
return wrapper