import math
from typing import List, Optional, Union

import PIL.Image
import torch
from torch.nn.functional import conv2d, pad as torch_pad

from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

from torchvision.utils import _log_api_usage_once

from ._utils import (
    _get_kernel,
    _register_explicit_noop,
    _register_kernel_internal,
    _register_unsupported_type,
    is_simple_tensor,
)


@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
@_register_unsupported_type(PIL.Image.Image)
def normalize(
    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
    mean: List[float],
    std: List[float],
    inplace: bool = False,
) -> torch.Tensor:
    if not torch.jit.is_scripting():
        _log_api_usage_once(normalize)
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(normalize, type(inpt))
        return kernel(inpt, mean=mean, std=std, inplace=inplace)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
        )
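
# Usage sketch (illustrative; the mean/std below are the commonly used ImageNet
# statistics, not values defined by this module):
#
#   img = torch.rand(3, 224, 224)  # float image in [0.0, 1.0]
#   out = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])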


@_register_kernel_internal(normalize, datapoints.Image)
def normalize_image_tensor(
    image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
    if not image.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")

    if image.ndim < 3:
        raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.")

    if isinstance(std, (tuple, list)):
        divzero = not all(std)
    elif isinstance(std, (int, float)):
        divzero = std == 0
    else:
        divzero = False
    if divzero:
        raise ValueError("std evaluated to zero, leading to division by zero.")

    dtype = image.dtype
    device = image.device
    mean = torch.as_tensor(mean, dtype=dtype, device=device)
    std = torch.as_tensor(std, dtype=dtype, device=device)
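    # Reshape 1-D mean/std to (C, 1, 1) so they broadcast against (..., C, H, W)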
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)

    if inplace:
        image = image.sub_(mean)
    else:
        image = image.sub(mean)

    return image.div_(std)


@_register_kernel_internal(normalize, datapoints.Video)
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    return normalize_image_tensor(video, mean, std, inplace=inplace)


@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def gaussian_blur(
    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(gaussian_blur)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(gaussian_blur, type(inpt))
        return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
    elif isinstance(inpt, PIL.Image.Image):
        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
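
# Usage sketch (illustrative): kernel_size entries must be odd and positive;
# if sigma is omitted it is derived from kernel_size.
#
#   img = torch.rand(3, 224, 224)
#   blurred = gaussian_blur(img, kernel_size=[5, 5], sigma=[1.5, 1.5])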


def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
    x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
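    # softmax over -x**2 computes exp(-x**2) / sum(exp(-x**2)), i.e. a sampled
    # Gaussian that is already normalized to sum to 1, in a single call.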
    kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
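    # The 2-D Gaussian is separable: build it as the outer product of the two
    # 1-D kernels (kernel1d_y as a column vector times kernel1d_x as a row).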
    kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
    return kernel2d


@_register_kernel_internal(gaussian_blur, datapoints.Image)
def gaussian_blur_image_tensor(
    image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    # TODO: consider deprecating integers from sigma in the future
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    elif len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
    else:
        if isinstance(sigma, (list, tuple)):
            length = len(sigma)
            if length == 1:
                s = float(sigma[0])
                sigma = [s, s]
            elif length != 2:
                raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
        elif isinstance(sigma, (int, float)):
            s = float(sigma)
            sigma = [s, s]
        else:
            raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    if image.numel() == 0:
        return image

    dtype = image.dtype
    shape = image.shape
    ndim = image.ndim
    if ndim == 3:
        image = image.unsqueeze(dim=0)
    elif ndim > 4:
        image = image.reshape((-1,) + shape[-3:])

    fp = torch.is_floating_point(image)
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
    kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])

    output = image if fp else image.to(dtype=torch.float32)

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
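    # Blur with a depthwise convolution: `kernel` was expanded to one copy per
    # channel above, and groups=C applies each copy to its channel independently.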
    output = torch_pad(output, padding, mode="reflect")
    output = conv2d(output, kernel, groups=shape[-3])

    if ndim == 3:
        output = output.squeeze(dim=0)
    elif ndim > 4:
        output = output.reshape(shape)

    if not fp:
        output = output.round_().to(dtype=dtype)

    return output


@torch.jit.unused
def gaussian_blur_image_pil(
    image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
    t_img = pil_to_tensor(image)
    output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
    return to_pil_image(output, mode=image.mode)


@_register_kernel_internal(gaussian_blur, datapoints.Video)
def gaussian_blur_video(
    video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    return gaussian_blur_image_tensor(video, kernel_size, sigma)


def to_dtype(
    inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(to_dtype)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return to_dtype_image_tensor(inpt, dtype, scale=scale)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(to_dtype, type(inpt))
        return kernel(inpt, dtype, scale=scale)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
        )
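
# Usage sketch (illustrative): with scale=True, values are rescaled to the
# value range of the target dtype instead of merely cast.
#
#   img_u8 = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#   img_f32 = to_dtype(img_u8, torch.float32, scale=True)  # values in [0.0, 1.0]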


def _num_value_bits(dtype: torch.dtype) -> int:
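    # Bits that carry magnitude: signed integer dtypes reserve one bit for the
    # sign, so int8 has 7 value bits while uint8 has all 8.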
    if dtype == torch.uint8:
        return 8
    elif dtype == torch.int8:
        return 7
    elif dtype == torch.int16:
        return 15
    elif dtype == torch.int32:
        return 31
    elif dtype == torch.int64:
        return 63
    else:
        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


@_register_kernel_internal(to_dtype, datapoints.Image)
def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    if image.dtype == dtype:
        return image
    elif not scale:
        return image.to(dtype)

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
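        # Worked example for uint8 (illustrative): the factor is
        # 255 + 1.0 - 1e-3 = 255.999, so 1.0 -> int(255.999) = 255 and
        # 0.5 -> int(127.9995) = 127.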
        eps = 1e-3
        max_value = float(_max_value(dtype))
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
        if float_output:
            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))

        # int to int
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)

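        # Worked example (illustrative): uint8 (8 value bits) -> int16 (15
        # value bits) left-shifts by 7, so 255 becomes 32640; the reverse
        # conversion right-shifts by 7, mapping 32640 back to 255.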
        if num_value_bits_input > num_value_bits_output:
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)


# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    return to_dtype_image_tensor(image, dtype=dtype, scale=True)


@_register_kernel_internal(to_dtype, datapoints.Video)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    return to_dtype_image_tensor(video, dtype, scale=scale)


@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
def _to_dtype_tensor_dispatch(
    inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False
) -> datapoints._InputTypeJIT:
    # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
    return inpt.to(dtype)