# type: ignore

from __future__ import annotations

import collections.abc
import contextlib
from collections import defaultdict

import torch

from torchvision import datapoints, datasets
from torchvision.transforms.v2 import functional as F

__all__ = ["wrap_dataset_for_transforms_v2"]


def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
    """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

    .. v2betastatus:: wrap_dataset_for_transforms_v2 function

    Example:
        >>> dataset = torchvision.datasets.CocoDetection(...)
        >>> dataset = wrap_dataset_for_transforms_v2(dataset)

    .. note::

       For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
       configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
       to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.

    The dataset samples are wrapped according to the description below.

    Special cases:

        * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as a list of dicts, the wrapper
          returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
          ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
          The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the
          ``"image_id"``, ``"boxes"``, and ``"labels"``.
        * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
          the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
          preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
        * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper returns a
          dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
          in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
          omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
          :class:`~torchvision.datapoints.Mask` datapoint.
        * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
          :class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
          a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
          ``"labels"``.
        * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.

    Image classification datasets

        This wrapper is a no-op for image classification datasets, since they were already fully supported by
        :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.

    Segmentation datasets

        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
        :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
        segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).

    Video classification datasets

        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
        :class:`torch.Tensor` for the video and audio and an :class:`int` as label. This wrapper wraps the video into a
        :class:`~torchvision.datapoints.Video` while leaving the other items as is.

        .. note::

            Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
            ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.

    Args:
        dataset: the dataset instance to wrap for compatibility with transforms v2.
        target_keys: Target keys to return in case the target is a dictionary. If ``None`` (default), selected keys are
            specific to the dataset. If ``"all"``, returns the full target. Can also be a collection of strings for
            fine-grained access. Currently only supported for :class:`~torchvision.datasets.CocoDetection`,
            :class:`~torchvision.datasets.VOCDetection`, :class:`~torchvision.datasets.Kitti`, and
            :class:`~torchvision.datasets.WIDERFace`. See above for details.
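
            For example (illustrative), passing ``target_keys={"image_id", "boxes", "labels", "masks"}`` to a wrapped
            :class:`~torchvision.datasets.CocoDetection` additionally returns the instance masks.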
    """
    if not (
        target_keys is None
        or target_keys == "all"
        or (isinstance(target_keys, collections.abc.Collection) and all(isinstance(key, str) for key in target_keys))
    ):
        raise ValueError(
            f"`target_keys` can be None, 'all', or a collection of strings denoting the keys to be returned, "
            f"but got {target_keys}"
        )

    # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
    # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
    # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
    # while we can still inject everything that we need.
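    #
    # Illustratively, the dynamically created class behaves like the hand-written
    #
    #     class WrappedImageNet(VisionDatasetDatapointWrapper, datasets.ImageNet):
    #         pass
    #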
    wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
    # Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
    # VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
    # ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
    # have the existing instance as attribute on the new object.
    return wrapped_dataset_cls(dataset, target_keys)


class WrapperFactories(dict):
    def register(self, dataset_cls):
        def decorator(wrapper_factory):
            self[dataset_cls] = wrapper_factory
            return wrapper_factory

        return decorator


# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
# dataset instance rather than just the class, because they require user-defined instance attributes. Thus, we can
# provide a mapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
# we have access to the dataset instance.
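#
# Sketched with a hypothetical `SomeDataset`, a registration looks like
#
#     @WRAPPER_FACTORIES.register(datasets.SomeDataset)
#     def some_dataset_wrapper_factory(dataset, target_keys):
#         def wrapper(idx, sample):
#             image, target = sample
#             return image, target  # wrap the individual items into datapoints here
#
#         return wrapper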
WRAPPER_FACTORIES = WrapperFactories()


class VisionDatasetDatapointWrapper:
    def __init__(self, dataset, target_keys):
        dataset_cls = type(dataset)

        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead.\n"
                f"For an example of how to perform the wrapping for custom datasets, see\n\n"
                "https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
            )

        for cls in dataset_cls.mro():
            if cls in WRAPPER_FACTORIES:
                wrapper_factory = WRAPPER_FACTORIES[cls]
                if target_keys is not None and cls not in {
                    datasets.CocoDetection,
                    datasets.VOCDetection,
                    datasets.Kitti,
                    datasets.WIDERFace,
                }:
                    raise ValueError(
                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
                        f"and `WIDERFace`, but got {cls.__name__}."
                    )
                break
            elif cls is datasets.VisionDataset:
                # TODO: If we have documentation on how to do that, put a link in the error message.
                msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
                if dataset_cls in datasets.__dict__.values():
                    msg = (
                        f"{msg} If an automated wrapper for this dataset would be useful for you, "
                        f"please open an issue at https://github.com/pytorch/vision/issues."
                    )
                raise TypeError(msg)

        self._dataset = dataset
        self._target_keys = target_keys
        self._wrapper = wrapper_factory(dataset, target_keys)

        # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
        # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
        # `transforms`
        # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
        # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
        # disable all three here to be able to extract the untransformed sample to wrap.
        self.transform, dataset.transform = dataset.transform, None
        self.target_transform, dataset.target_transform = dataset.target_transform, None
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item):
        with contextlib.suppress(AttributeError):
            return object.__getattribute__(self, item)

        return getattr(self._dataset, item)

    def __getitem__(self, idx):
        # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
        # of this class
        sample = self._dataset[idx]

        sample = self._wrapper(idx, sample)

        # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
        # or joint (`transforms`), we can access the full functionality through `transforms`
        if self.transforms is not None:
            sample = self.transforms(*sample)

        return sample

    def __len__(self):
        return len(self._dataset)

    def __reduce__(self):
        return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)


def raise_not_supported(description):
    raise RuntimeError(
        f"{description} is currently not supported by this wrapper. "
        f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
    )


def identity(item):
    return item


def identity_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        return sample

    return wrapper


def pil_image_to_mask(pil_image):
    return datapoints.Mask(pil_image)


def parse_target_keys(target_keys, *, available, default):
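    # e.g. parse_target_keys(None, available={"bbox", "boxes", "labels"}, default={"boxes", "labels"}) returns
    # {"boxes", "labels"}, while target_keys="all" selects the full `available` set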
    if target_keys is None:
        target_keys = default
    if target_keys == "all":
        target_keys = available
    else:
        target_keys = set(target_keys)
        extra = target_keys - available
        if extra:
            raise ValueError(f"Target keys {sorted(extra)} are not available")

    return target_keys


def list_of_dicts_to_dict_of_lists(list_of_dicts):
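    # e.g. [{"a": 1}, {"a": 2, "b": 3}] -> {"a": [1, 2], "b": [3]}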
    dict_of_lists = defaultdict(list)
    for dct in list_of_dicts:
        for key, value in dct.items():
            dict_of_lists[key].append(value)
    return dict(dict_of_lists)


def wrap_target_by_type(target, *, target_types, type_wrappers):
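    # e.g. with target_types=["category", "segmentation"] and type_wrappers={"segmentation": pil_image_to_mask},
    # the "category" item is returned unchanged while the "segmentation" item is wrapped into a Mask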
    if not isinstance(target, (tuple, list)):
        target = [target]

    wrapped_target = tuple(
        type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
    )

    if len(wrapped_target) == 1:
        wrapped_target = wrapped_target[0]

    return wrapped_target


def classification_wrapper_factory(dataset, target_keys):
    return identity_wrapper_factory(dataset, target_keys)


for dataset_cls in [
    datasets.Caltech256,
    datasets.CIFAR10,
    datasets.CIFAR100,
    datasets.ImageNet,
    datasets.MNIST,
    datasets.FashionMNIST,
    datasets.GTSRB,
    datasets.DatasetFolder,
    datasets.ImageFolder,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)


def segmentation_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        image, mask = sample
        return image, pil_image_to_mask(mask)

    return wrapper


for dataset_cls in [
    datasets.VOCSegmentation,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)


def video_classification_wrapper_factory(dataset, target_keys):
    if dataset.video_clips.output_format == "THWC":
        raise RuntimeError(
            f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
            f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
        )

    def wrapper(idx, sample):
        video, audio, label = sample

        video = datapoints.Video(video)

        return video, audio, label

    return wrapper


for dataset_cls in [
    datasets.HMDB51,
    datasets.Kinetics,
    datasets.UCF101,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)


@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset, target_keys):
    if "annotation" in dataset.target_type:
        raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")

    return classification_wrapper_factory(dataset, target_keys)


@WRAPPER_FACTORIES.register(datasets.CocoDetection)
def coco_detection_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "segmentation",
            "area",
            "iscrowd",
            "image_id",
            "bbox",
            "category_id",
            # added by the wrapper
            "boxes",
            "masks",
            "labels",
        },
        default={"image_id", "boxes", "labels"},
    )

    def segmentation_to_mask(segmentation, *, canvas_size):
        from pycocotools import mask
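        # a COCO segmentation is either an RLE dict or a list of polygons; either way, it is decoded
        # into a single binary mask with spatial size `canvas_size`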

        segmentation = (
            mask.frPyObjects(segmentation, *canvas_size)
            if isinstance(segmentation, dict)
            else mask.merge(mask.frPyObjects(segmentation, *canvas_size))
        )
        return torch.from_numpy(mask.decode(segmentation))

    def wrapper(idx, sample):
        image_id = dataset.ids[idx]

        image, target = sample

        if not target:
            return image, dict(image_id=image_id)

        canvas_size = tuple(F.get_size(image))

        batched_target = list_of_dicts_to_dict_of_lists(target)
        target = {}

        if "image_id" in target_keys:
            target["image_id"] = image_id

        if "boxes" in target_keys:
            target["boxes"] = F.convert_bounding_box_format(
                datapoints.BoundingBoxes(
                    batched_target["bbox"],
                    format=datapoints.BoundingBoxFormat.XYWH,
                    canvas_size=canvas_size,
                ),
                new_format=datapoints.BoundingBoxFormat.XYXY,
            )

        if "masks" in target_keys:
            target["masks"] = datapoints.Mask(
                torch.stack(
                    [
                        segmentation_to_mask(segmentation, canvas_size=canvas_size)
                        for segmentation in batched_target["segmentation"]
                    ]
                ),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(batched_target["category_id"])

        for target_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper


WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)


VOC_DETECTION_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))


@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "annotation",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

        if "annotation" not in target_keys:
            target = {}

        if "boxes" in target_keys:
            target["boxes"] = datapoints.BoundingBoxes(
                [
                    [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                    for bndbox in batched_instances["bndbox"]
                ],
                format=datapoints.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(
                [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset, target_keys):
    if dataset.mode == "boundaries":
        raise_not_supported("SBDataset with mode='boundaries'")

    return segmentation_wrapper_factory(dataset, target_keys)


@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset, target_keys):
    if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
        raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "bbox": lambda item: F.convert_bounding_box_format(
                    datapoints.BoundingBoxes(
                        item,
                        format=datapoints.BoundingBoxFormat.XYWH,
                        canvas_size=(image.height, image.width),
                    ),
                    new_format=datapoints.BoundingBoxFormat.XYXY,
                ),
            },
        )

        return image, target

    return wrapper


KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))


@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "type",
            "truncated",
            "occluded",
            "alpha",
            "bbox",
            "dimensions",
            "location",
            "rotation_y",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        if target is None:
            return image, target

        batched_target = list_of_dicts_to_dict_of_lists(target)
        target = {}

        if "boxes" in target_keys:
            target["boxes"] = datapoints.BoundingBoxes(
                batched_target["bbox"],
                format=datapoints.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in batched_target["type"]])

        for target_key in target_keys - {"boxes", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        image, target = sample

        if target is not None:
            target = wrap_target_by_type(
                target,
                target_types=dataset._target_types,
                type_wrappers={
                    "segmentation": pil_image_to_mask,
                },
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset, target_keys):
    if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
        raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")

    def instance_segmentation_wrapper(mask):
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
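        # ids >= 1000 encode `class_id * 1000 + instance_number` for individual instances, while smaller ids
        # are plain class ids for annotations that cover multiple instances; the division below recovers the class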
        data = pil_image_to_mask(mask)
        masks = []
        labels = []
        for id in data.unique():
            masks.append(data == id)
            label = id
            if label >= 1_000:
                label //= 1_000
            labels.append(label)
        return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
                "semantic": pil_image_to_mask,
            },
        )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            "bbox",
            "blur",
            "expression",
            "illumination",
            "occlusion",
            "pose",
            "invalid",
        },
        default="all",
    )

    def wrapper(idx, sample):
        image, target = sample

        if target is None:
            return image, target

        target = {key: target[key] for key in target_keys}

        if "bbox" in target_keys:
            target["bbox"] = F.convert_bounding_box_format(
                datapoints.BoundingBoxes(
                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
                ),
                new_format=datapoints.BoundingBoxFormat.XYXY,
            )

        return image, target

    return wrapper