_misc.py 17.1 KB
Newer Older
1
import warnings
2
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
3
4
5
6
7
8

import PIL.Image

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

9
from torchvision import transforms as _transforms, tv_tensors
10
11
from torchvision.transforms.v2 import functional as F, Transform

12
from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
13
14


Nicolas Hug's avatar
Nicolas Hug committed
15
# TODO: do we want/need to expose this?
class Identity(Transform):
    """No-op transform: every input is returned exactly as it was received."""

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Intentionally a pass-through; ``params`` is ignored.
        return inpt


class Lambda(Transform):
    """[BETA] Apply a user-defined function as a transform.

    .. v2betastatus:: Lambda transform

    This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
        *types (type, optional): Types of input the function is applied to. Inputs that are not
            instances of any of these types are returned unchanged. When no types are given,
            ``(object,)`` is used, i.e. any type matches.
    """

    _transformed_types = (object,)

    def __init__(self, lambd: Callable[[Any], Any], *types: Type):
        super().__init__()
        self.lambd = lambd
        # An empty ``types`` tuple falls back to the class-level default.
        self.types = types if types else self._transformed_types

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Guard clause: inputs of non-matching types pass through untouched.
        if not isinstance(inpt, self.types):
            return inpt
        return self.lambd(inpt)

    def extra_repr(self) -> str:
        parts = []
        # Builtin functions and lambdas expose ``__name__``; arbitrary callables may not.
        fn_name = getattr(self.lambd, "__name__", None)
        if fn_name:
            parts.append(fn_name)
        type_names = [cls.__name__ for cls in self.types]
        parts.append(f"types={type_names}")
        return ", ".join(parts)


class LinearTransformation(Transform):
    """[BETA] Transform a tensor image or video with a square transformation matrix and a mean_vector computed offline.

    .. v2betastatus:: LinearTransformation transform

    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, the input is flattened into rows of
    ``C x H x W`` elements, mean_vector is subtracted from every row, the result is
    multiplied (``torch.mm``) by the transformation matrix, and the product is reshaped
    back to the original input shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W

    Raises:
        ValueError: at construction if the matrix is not square, if mean_vector's length does not
            match the matrix, or if the two tensors differ in device or dtype; at call time if the
            input's ``C x H x W`` size does not match the matrix or the input is on another device type.
        TypeError: if the sample contains a PIL image.
    """

    _v1_transform_cls = _transforms.LinearTransformation

    _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)

    def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
        super().__init__()
        num_rows = transformation_matrix.size(0)
        if num_rows != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        mean_len = mean_vector.size(0)
        if mean_len != num_rows:
            raise ValueError(
                f"mean_vector should have the same length {mean_len}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        matrix_device, mean_device = transformation_matrix.device, mean_vector.device
        if matrix_device != mean_device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {matrix_device} and {mean_device}"
            )

        matrix_dtype, mean_dtype = transformation_matrix.dtype, mean_vector.dtype
        if matrix_dtype != mean_dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {matrix_dtype} and {mean_dtype}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def _check_inputs(self, sample: Any) -> Any:
        # Only tensor inputs are supported; refuse the whole sample if a PIL image is present.
        if not has_any(sample, PIL.Image.Image):
            return
        raise TypeError(f"{type(self).__name__}() does not support PIL images.")

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        original_shape = inpt.shape
        channels, height, width = original_shape[-3], original_shape[-2], original_shape[-1]
        row_len = channels * height * width
        expected_len = self.transformation_matrix.shape[0]
        if row_len != expected_len:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                + f"[{channels} x {height} x {width}] != "
                + f"{expected_len}"
            )

        # Only the device *type* is compared here (e.g. "cuda" vs "cpu").
        if inpt.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {inpt.device} vs {self.mean_vector.device}"
            )

        # Each leading-dim slice (e.g. each frame of a video) becomes one row.
        centered = inpt.reshape(-1, row_len) - self.mean_vector
        # Match the matrix dtype to the (possibly type-promoted) centered data before multiplying.
        matrix = self.transformation_matrix.to(centered.dtype)
        result = torch.mm(centered, matrix).reshape(original_shape)

        # Re-wrap Image/Video outputs as the input's tv_tensor type; plain tensors are returned as is.
        if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
            return tv_tensors.wrap(result, like=inpt)
        return result


class Normalize(Transform):
    """[BETA] Normalize a tensor image or video with mean and standard deviation.

    .. v2betastatus:: Normalize transform

    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        By default (``inplace=False``) this transform acts out of place, i.e., it does not mutate
        the input tensor. Passing ``inplace=True`` requests an in-place operation instead.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place. Default: ``False``.

    Raises:
        TypeError: if the sample contains a PIL image.
    """

    _v1_transform_cls = _transforms.Normalize

    def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
        super().__init__()
        # Copy into lists so later mutation of the caller's sequences cannot affect this transform.
        self.mean = list(mean)
        self.std = list(std)
        self.inplace = inplace

    def _check_inputs(self, sample: Any) -> Any:
        # Refuse the whole sample up front if any PIL image is found in it.
        if has_any(sample, PIL.Image.Image):
            raise TypeError(f"{type(self).__name__}() does not support PIL images.")

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # ``inplace`` is forwarded unchanged to the functional kernel.
        return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)


class GaussianBlur(Transform):
    """[BETA] Blurs image with randomly chosen Gaussian blur.

    .. v2betastatus:: GaussianBlur transform

    If the input is a Tensor, it is expected
    to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel. Every value must be a
            positive odd number.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range. Default: ``(0.1, 2.0)``.

    Raises:
        ValueError: if a kernel size value is not positive and odd, or if sigma is not of
            the form ``0 < min <= max``.
    """

    _v1_transform_cls = _transforms.GaussianBlur

    def __init__(
        self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
    ) -> None:
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        self.sigma = _setup_number_or_seq(sigma, "sigma")

        # self.sigma is indexed as a (min, max) range below and in _get_params.
        if not 0.0 < self.sigma[0] <= self.sigma[1]:
            raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}")

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Draw one sigma uniformly from [min, max] and use it for both kernel dimensions.
        sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
        return dict(sigma=[sigma, sigma])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)


class ToDtype(Transform):
    """[BETA] Converts the input to a specific dtype, optionally scaling the values for images or videos.

    .. v2betastatus:: ToDtype transform

    .. note::
        ``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.

    Args:
        dtype (``torch.dtype`` or dict of ``TVTensor`` -> ``torch.dtype``): The dtype to convert to.
            If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos (and plain
            tensors, which are treated like images) will be converted to that dtype: this is for
            compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
            A dict can be passed to specify per-tv_tensor conversions, e.g.
            ``dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}``. The "others"
            key can be used as a catch-all for any other tv_tensor type, and ``None`` means no conversion.
            Dict keys are matched against the input's exact type.
        scale (bool, optional): Whether to scale the values for images or videos. See :ref:`range_and_dtype`.
            Default: ``False``.

    Raises:
        ValueError: if ``dtype`` is neither a dict nor a ``torch.dtype``, or, at call time, if ``dtype``
            is a dict with neither an entry for the input's type nor an ``"others"`` key.
    """

    _transformed_types = (torch.Tensor,)

    def __init__(
        self, dtype: Union[torch.dtype, Dict[Union[Type, str], Optional[torch.dtype]]], scale: bool = False
    ) -> None:
        super().__init__()

        if not isinstance(dtype, (dict, torch.dtype)):
            raise ValueError(f"dtype must be a dict or a torch.dtype, got {type(dtype)} instead")

        # Warn about an ambiguous mapping: plain tensors are skipped when an Image/Video is present.
        if isinstance(dtype, dict):
            if torch.Tensor in dtype and (tv_tensors.Image in dtype or tv_tensors.Video in dtype):
                warnings.warn(
                    "Got `dtype` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
                    "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                    "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
                )
        self.dtype = dtype
        self.scale = scale

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Plain tensors, images and videos are the only inputs that can be value-scaled.
        image_like = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))

        if isinstance(self.dtype, torch.dtype):
            # BC with ConvertImageDtype: a bare torch.dtype only applies to image-like inputs.
            if not image_like:
                return inpt
            target_dtype: Optional[torch.dtype] = self.dtype
        else:
            input_type = type(inpt)
            if input_type in self.dtype:
                target_dtype = self.dtype[input_type]
            elif "others" in self.dtype:
                target_dtype = self.dtype["others"]
            else:
                raise ValueError(
                    f"No dtype was specified for type {input_type}. "
                    "If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
                    "If you're passing a dict as dtype, "
                    'you can use "others" as a catch-all key '
                    'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
                )

        # ``None`` means "leave this input alone".
        if target_dtype is None:
            if self.scale and image_like:
                warnings.warn(
                    "scale was set to True but no dtype was specified for images or videos: no scaling will be done."
                )
            return inpt

        return self._call_kernel(F.to_dtype, inpt, dtype=target_dtype, scale=self.scale)


288
class ConvertImageDtype(Transform):
    """[BETA] [DEPRECATED] Use ``v2.ToDtype(dtype, scale=True)`` instead.

    Convert input image to the given ``dtype`` and scale the values accordingly.

    .. v2betastatus:: ConvertImageDtype transform

    .. warning::
        Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.

    This transform does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output. Default: ``torch.float32``.

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    _v1_transform_cls = _transforms.ConvertImageDtype

    def __init__(self, dtype: torch.dtype = torch.float32) -> None:
        super().__init__()
        self.dtype = dtype

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Value scaling is always on for this (legacy) transform.
        return self._call_kernel(F.to_dtype, inpt, scale=True, dtype=self.dtype)


325
class SanitizeBoundingBoxes(Transform):
    """[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.

    .. v2betastatus:: SanitizeBoundingBoxes transform

    This transform removes bounding boxes and their associated labels/masks that:

    - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
    - have any coordinate outside of their corresponding image. You may want to
      call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.

    It is recommended to call it at the end of a pipeline, before passing the
    input to the models. It is critical to call this transform if
    :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
    If you want to be extra careful, you may call it after all transforms that
    may modify bounding boxes but once at the end should be enough in most
    cases.

    Args:
        min_size (float, optional): The size below which bounding boxes are removed. Must be >= 1.
            Default is 1.
        labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
            By default, this will try to find a "labels" key in the input (case-insensitive), if
            the input is a dict or it is a tuple whose second element is a dict.
            This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
            It can also be a callable that takes the same input
            as the transform, and returns the labels.

    Raises:
        ValueError: at construction if ``min_size < 1``; at call time if the labels are neither a
            tensor nor ``None``, or if the number of labels differs from the number of boxes.
    """

    def __init__(
        self,
        min_size: float = 1.0,
        labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
    ) -> None:
        super().__init__()

        if min_size < 1:
            raise ValueError(f"min_size must be >= 1, got {min_size}.")
        self.min_size = min_size

        # Keep the user-facing value for introspection; the parsed callable is what forward() uses.
        self.labels_getter = labels_getter
        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs: Any) -> Any:
        # A single positional argument is unpacked; several are treated as one tuple sample.
        inputs = inputs if len(inputs) > 1 else inputs[0]

        labels = self._labels_getter(inputs)
        if labels is not None and not isinstance(labels, torch.Tensor):
            raise ValueError(
                f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead."
            )

        flat_inputs, spec = tree_flatten(inputs)
        boxes = get_bounding_boxes(flat_inputs)

        if labels is not None and boxes.shape[0] != labels.shape[0]:
            raise ValueError(
                f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
            )

        # Work in XYXY so widths/heights and bounds checks are simple column arithmetic.
        boxes = cast(
            tv_tensors.BoundingBoxes,
            F.convert_bounding_box_format(
                boxes,
                new_format=tv_tensors.BoundingBoxFormat.XYXY,
            ),
        )
        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
        valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
        # TODO: Do we really need to check for out of bounds here? All
        # transforms should be clamping anyway, so this should never happen?
        image_h, image_w = boxes.canvas_size
        valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
        valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

        # Strip the tv_tensor subclass so the mask indexes every input as a plain boolean tensor.
        params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
        flat_outputs = [
            # Even though it may look like we're transforming all inputs, we don't:
            # _transform() will only care about BoundingBoxes, Masks and the labels
            self._transform(inpt, params)
            for inpt in flat_inputs
        ]

        return tree_unflatten(flat_outputs, spec)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # NOTE: labels are recognized by identity with the tensor returned by labels_getter, so a
        # getter that returns a copy (rather than the tensor stored in the input) leaves the
        # stored labels unfiltered.
        is_label = inpt is not None and inpt is params["labels"]
        is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))

        if not (is_label or is_bounding_boxes_or_mask):
            return inpt

        # Keep only the entries (along the first dimension) whose box passed every check.
        output = inpt[params["valid"]]

        if is_label:
            return output

        return tv_tensors.wrap(output, like=inpt)