# type: ignore

from __future__ import annotations

import contextlib
from collections import defaultdict

import torch
from torch.utils.data import Dataset

from torchvision import datapoints, datasets
from torchvision.transforms.v2 import functional as F
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

__all__ = ["wrap_dataset_for_transforms_v2"]


# TODO: naming!
def wrap_dataset_for_transforms_v2(dataset):
    """Wrap a ``torchvision.datasets.VisionDataset`` so its samples are returned as v2 datapoints.

    The returned wrapper resolves a dataset-specific wrapping from ``WRAPPER_FACTORIES`` and applies
    the user supplied transforms *after* the raw sample has been wrapped.
    """
    return VisionDatasetDatapointWrapper(dataset)


class WrapperFactories(dict):
    """Registry mapping a dataset class to the factory producing its sample wrapper."""

    def register(self, dataset_cls):
        """Return a decorator that stores the decorated factory under ``dataset_cls``."""

        def decorator(wrapper_factory):
            self[dataset_cls] = wrapper_factory
            return wrapper_factory

        return decorator


# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
# dataset instance rather than just the class, since they require the user defined instance attributes. Thus, we can
# provide a wrapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
# we have access to the dataset instance.
WRAPPER_FACTORIES = WrapperFactories()


class VisionDatasetDatapointWrapper(Dataset):
    """Dataset wrapper that converts raw samples into v2 datapoints before transforms run.

    The wrapper looks up a wrapper factory registered for the dataset class (or one of its
    bases), detaches the user supplied transforms from the wrapped dataset, and re-applies
    them after the raw sample has been wrapped.
    """

    def __init__(self, dataset):
        dataset_cls = type(dataset)

        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead."
            )

        # Walk the MRO so that subclasses of a registered dataset pick up the parent's factory.
        for candidate_cls in dataset_cls.mro():
            if candidate_cls in WRAPPER_FACTORIES:
                wrapper_factory = WRAPPER_FACTORIES[candidate_cls]
                break
            elif candidate_cls is datasets.VisionDataset:
                # TODO: If we have documentation on how to do that, put a link in the error message.
                msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
                if dataset_cls in datasets.__dict__.values():
                    msg = (
                        f"{msg} If an automated wrapper for this dataset would be useful for you, "
                        f"please open an issue at https://github.com/pytorch/vision/issues."
                    )
                raise TypeError(msg)

        self._dataset = dataset
        self._wrapper = wrapper_factory(dataset)

        # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply
        # them. Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the
        # joint `transforms`
        # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
        # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
        # disable all three here to be able to extract the untransformed sample to wrap.
        self.transform, dataset.transform = dataset.transform, None
        self.target_transform, dataset.target_transform = dataset.target_transform, None
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item):
        # Only invoked when regular attribute lookup fails; fall back to the wrapped dataset.
        try:
            return object.__getattribute__(self, item)
        except AttributeError:
            pass
        return getattr(self._dataset, item)

    def __getitem__(self, idx):
        # Transforms were disabled on the underlying dataset in the constructor, so this
        # yields the raw, untransformed sample.
        raw_sample = self._dataset[idx]

        wrapped_sample = self._wrapper(idx, raw_sample)

        # Regardless of whether the user has supplied the transforms individually (`transform` and
        # `target_transform`) or joint (`transforms`), we can access the full functionality through `transforms`.
        if self.transforms is not None:
            wrapped_sample = self.transforms(*wrapped_sample)

        return wrapped_sample

    def __len__(self):
        return len(self._dataset)


def raise_not_supported(description):
    """Raise a ``RuntimeError`` stating that ``description`` is not supported by this wrapper."""
    message = (
        f"{description} is currently not supported by this wrapper. "
        f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
    )
    raise RuntimeError(message)


def identity(item):
    """Return ``item`` unchanged; used as the default per-type target wrapper."""
    return item


def identity_wrapper_factory(dataset):
    """Produce a wrapper that passes every sample through untouched."""

    def wrapper(idx, sample):
        return sample

    return wrapper


def pil_image_to_mask(pil_image):
    """Convert a PIL segmentation image into a ``datapoints.Mask``."""
    return datapoints.Mask(pil_image)


def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Transpose a list of dicts into a dict of lists, preserving encounter order of keys."""
    transposed = defaultdict(list)
    for record in list_of_dicts:
        for key, value in record.items():
            transposed[key].append(value)
    return dict(transposed)


def wrap_target_by_type(target, *, target_types, type_wrappers):
    """Apply the per-type wrapper from ``type_wrappers`` to each target item.

    Items whose type has no registered wrapper are passed through via ``identity``.
    A single target type yields the bare wrapped item rather than a 1-tuple.
    """
    items = target if isinstance(target, (tuple, list)) else [target]

    wrapped = tuple(
        type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, items)
    )

    return wrapped[0] if len(wrapped) == 1 else wrapped


def classification_wrapper_factory(dataset):
    # Classification samples (image, integer label) need no conversion, so reuse the identity wrapper.
    return identity_wrapper_factory(dataset)


# Classification datasets whose samples can pass through unchanged.
for dataset_cls in [
    datasets.Caltech256,
    datasets.CIFAR10,
    datasets.CIFAR100,
    datasets.ImageNet,
    datasets.MNIST,
    datasets.FashionMNIST,
    datasets.GTSRB,
    datasets.DatasetFolder,
    datasets.ImageFolder,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)


def segmentation_wrapper_factory(dataset):
    """Wrapper factory for segmentation datasets returning ``(image, PIL mask)`` samples."""

    def wrapper(idx, sample):
        image, mask = sample
        return image, pil_image_to_mask(mask)

    return wrapper


# Segmentation datasets whose target is a single PIL mask image.
for dataset_cls in [
    datasets.VOCSegmentation,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)


def video_classification_wrapper_factory(dataset):
    """Wrapper factory for video classification datasets returning ``(video, audio, label)``.

    Raises:
        RuntimeError: If the dataset emits clips in ``THWC`` layout, which the v2
            transforms cannot handle.
    """
    if dataset.video_clips.output_format == "THWC":
        raise RuntimeError(
            f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
            f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
        )

    def wrapper(idx, sample):
        video, audio, label = sample
        # Only the video stream is wrapped; audio and label pass through unchanged.
        return datapoints.Video(video), audio, label

    return wrapper


# Video classification datasets that share the (video, audio, label) sample layout.
for dataset_cls in [
    datasets.HMDB51,
    datasets.Kinetics,
    datasets.UCF101,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)


@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset):
    """Wrapper factory for ``datasets.Caltech101``; only plain category targets are supported."""
    if "annotation" in dataset.target_type:
        raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")

    return classification_wrapper_factory(dataset)


@WRAPPER_FACTORIES.register(datasets.CocoDetection)
def coco_dectection_wrapper_factory(dataset):
    """Wrapper factory for ``datasets.CocoDetection``.

    Converts the list of per-object annotation dicts into one batched dict with
    ``boxes`` (XYXY), ``masks``, and ``labels`` entries ready for the v2 transforms.
    """

    def segmentation_to_mask(segmentation, *, spatial_size):
        # pycocotools is an optional dependency of CocoDetection, so import lazily.
        from pycocotools import mask

        # dict inputs are handled directly; polygon lists need to be merged into one RLE first.
        segmentation = (
            mask.frPyObjects(segmentation, *spatial_size)
            if isinstance(segmentation, dict)
            else mask.merge(mask.frPyObjects(segmentation, *spatial_size))
        )
        return torch.from_numpy(mask.decode(segmentation))

    def wrapper(idx, sample):
        image_id = dataset.ids[idx]
        image, target = sample

        # Images without annotations still need to expose their id.
        if not target:
            return image, dict(image_id=image_id)

        batched_target = list_of_dicts_to_dict_of_lists(target)
        batched_target["image_id"] = image_id

        spatial_size = tuple(F.get_spatial_size(image))
        batched_target["boxes"] = F.convert_format_bounding_box(
            datapoints.BoundingBox(
                batched_target["bbox"],
                format=datapoints.BoundingBoxFormat.XYWH,
                spatial_size=spatial_size,
            ),
            new_format=datapoints.BoundingBoxFormat.XYXY,
        )
        batched_target["masks"] = datapoints.Mask(
            torch.stack(
                [
                    segmentation_to_mask(segmentation, spatial_size=spatial_size)
                    for segmentation in batched_target["segmentation"]
                ]
            ),
        )
        batched_target["labels"] = torch.tensor(batched_target["category_id"])

        return image, batched_target

    return wrapper


# CocoCaptions targets are plain strings, so the sample can pass through untouched.
WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)


# VOC detection category names in canonical order; index 0 is reserved for the background
# class and each category's position doubles as its integer label.
VOC_DETECTION_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
VOC_DETECTION_CATEGORY_TO_IDX = {category: index for index, category in enumerate(VOC_DETECTION_CATEGORIES)}


@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset):
    """Wrapper factory for ``datasets.VOCDetection``.

    Adds ``boxes`` (XYXY) and ``labels`` entries to the XML-derived target dict.
    """

    def wrapper(idx, sample):
        image, target = sample

        # Group the per-object annotation dicts into a single dict of lists.
        batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

        target["boxes"] = datapoints.BoundingBox(
            [
                # Coordinates arrive as strings in the parsed XML, hence the int() conversion.
                [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                for bndbox in batched_instances["bndbox"]
            ],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=(image.height, image.width),
        )
        target["labels"] = torch.tensor(
            [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
        )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset):
    """Wrapper factory for ``datasets.SBDataset``; the 'boundaries' mode has no datapoint mapping."""
    if dataset.mode == "boundaries":
        raise_not_supported("SBDataset with mode='boundaries'")

    return segmentation_wrapper_factory(dataset)


@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset):
    """Wrapper factory for ``datasets.CelebA``; only 'identity' and 'bbox' targets are supported."""
    if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
        raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                # CelebA boxes are stored XYWH; convert to the XYXY layout the v2 transforms use.
                "bbox": lambda item: F.convert_format_bounding_box(
                    datapoints.BoundingBox(
                        item,
                        format=datapoints.BoundingBoxFormat.XYWH,
                        spatial_size=(image.height, image.width),
                    ),
                    new_format=datapoints.BoundingBoxFormat.XYXY,
                ),
            },
        )

        return image, target

    return wrapper


# Kitti annotation categories in canonical order; the list index doubles as the integer label.
KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
KITTI_CATEGORY_TO_IDX = {category: index for index, category in enumerate(KITTI_CATEGORIES)}


@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset):
    """Wrapper factory for ``datasets.Kitti``; adds ``boxes`` and ``labels`` to the target dict."""

    def wrapper(idx, sample):
        image, target = sample

        # target is None for the test split; leave it untouched in that case.
        if target is not None:
            target = list_of_dicts_to_dict_of_lists(target)

            # Kitti boxes are already XYXY, so no format conversion is needed.
            target["boxes"] = datapoints.BoundingBox(
                target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width)
            )
            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]])

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factor(dataset):
    """Wrapper factory for ``datasets.OxfordIIITPet``; wraps segmentation targets as masks."""

    def wrapper(idx, sample):
        image, target = sample

        if target is not None:
            target = wrap_target_by_type(
                target,
                target_types=dataset._target_types,
                type_wrappers={
                    "segmentation": pil_image_to_mask,
                },
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset):
    """Wrapper factory for ``datasets.Cityscapes``.

    Supports the "instance" and "semantic" target types; "polygon" and "color" are rejected.
    """
    if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
        raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")

    def instance_segmentation_wrapper(mask):
        # Pixel values >= 1_000 encode `class_id * 1_000 + instance_index`; smaller values are
        # plain class ids, so dividing recovers the class label in both cases.
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
        data = pil_image_to_mask(mask)
        masks = []
        labels = []
        # NOTE: renamed loop variable from `id` to avoid shadowing the builtin.
        for instance_id in data.unique():
            masks.append(data == instance_id)
            label = instance_id
            if label >= 1_000:
                label //= 1_000
            labels.append(label)
        return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
                "semantic": pil_image_to_mask,
            },
        )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset):
    """Wrapper factory for ``datasets.WIDERFace``; converts the XYWH boxes to XYXY datapoints."""

    def wrapper(idx, sample):
        image, target = sample

        # target is None for the test split; leave it untouched in that case.
        if target is not None:
            target["bbox"] = F.convert_format_bounding_box(
                datapoints.BoundingBox(
                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
                ),
                new_format=datapoints.BoundingBoxFormat.XYXY,
            )

        return image, target

    return wrapper