# _misc.py (10.5 KB)
# Newer | Older
1
import math
2
from typing import List, Optional
3

4
import PIL.Image
5
import torch
6
from torch.nn.functional import conv2d, pad as torch_pad
7

8
from torchvision import datapoints
9
from torchvision.transforms._functional_tensor import _max_value
10
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
11

12
13
from torchvision.utils import _log_api_usage_once

14
from ._utils import _get_kernel, _register_kernel_internal
15

16

17
def normalize(
    inpt: torch.Tensor,
    mean: List[float],
    std: List[float],
    inplace: bool = False,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.Normalize` for details."""
    # Scripted calls bypass the type-based kernel lookup and go straight to the image kernel.
    if torch.jit.is_scripting():
        return normalize_image(inpt, mean=mean, std=std, inplace=inplace)

    _log_api_usage_once(normalize)

    # Pick the kernel registered for the concrete input type via `_register_kernel_internal(normalize, ...)`.
    kernel = _get_kernel(normalize, type(inpt))
    return kernel(inpt, mean=mean, std=std, inplace=inplace)
31
32


33
@_register_kernel_internal(normalize, torch.Tensor)
@_register_kernel_internal(normalize, datapoints.Image)
def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Normalize a floating-point image tensor channel-wise as ``(image - mean) / std``.

    Args:
        image: Floating-point tensor with at least 3 dims, interpreted as ``(..., C, H, W)``.
        mean: Per-channel means. A 1-D sequence is reshaped to ``(-1, 1, 1)`` so that it
            broadcasts over the trailing ``H, W`` dims.
        std: Per-channel standard deviations, broadcast the same way as ``mean``.
        inplace: If ``True``, ``image`` itself is modified and returned.

    Returns:
        The normalized tensor. With ``inplace=False`` a new tensor is returned and
        ``image`` is left untouched.

    Raises:
        TypeError: If ``image`` is not floating point.
        ValueError: If ``image`` has fewer than 3 dims, or if ``std`` (given as a number
            or a list/tuple) contains a zero.
    """
    if not image.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")

    if image.ndim < 3:
        raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.")

    # Reject zero std up front with a clear message instead of silently producing inf/nan.
    # Tensor-valued std (only possible from eager callers) is not inspected.
    if isinstance(std, (tuple, list)):
        divzero = not all(std)
    elif isinstance(std, (int, float)):
        divzero = std == 0
    else:
        divzero = False
    if divzero:
        raise ValueError("std evaluated to zero, leading to division by zero.")

    dtype = image.dtype
    device = image.device
    mean = torch.as_tensor(mean, dtype=dtype, device=device)
    std = torch.as_tensor(std, dtype=dtype, device=device)
    # Per-channel (1-D) statistics become (C, 1, 1) so they broadcast over H and W.
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)

    if inplace:
        image = image.sub_(mean)
    else:
        # Out-of-place subtraction allocates a fresh tensor, so the in-place division below
        # only touches that copy.
        image = image.sub(mean)

    return image.div_(std)
66

67

68
@_register_kernel_internal(normalize, datapoints.Video)
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Normalize a video tensor by delegating to :func:`normalize_image`, which accepts any ``(..., C, H, W)`` tensor."""
    return normalize_image(video, mean, std, inplace=inplace)
71
72


73
def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.GaussianBlur` for details."""
    # Scripted calls bypass the type-based kernel lookup and go straight to the image kernel.
    if torch.jit.is_scripting():
        return gaussian_blur_image(inpt, kernel_size=kernel_size, sigma=sigma)

    _log_api_usage_once(gaussian_blur)

    # Pick the kernel registered for the concrete input type via `_register_kernel_internal(gaussian_blur, ...)`.
    kernel = _get_kernel(gaussian_blur, type(inpt))
    return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
82
83


84
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Return a normalized 1-D Gaussian kernel with ``kernel_size`` taps.

    The taps are ``exp(-t**2 / (2 * sigma**2))`` for evenly spaced offsets ``t`` (unit step)
    spanning ``[-(kernel_size - 1) / 2, (kernel_size - 1) / 2]``, scaled so they sum to 1.
    """
    # Sample in units of t / (sqrt(2) * sigma) so that the Gaussian reduces to exp(-x**2).
    lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
    x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
    # softmax(-x**2) == exp(-x**2) / sum(exp(-x**2)): evaluates and normalizes in a single op.
    kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Return a separable 2-D Gaussian kernel of shape ``(kernel_size[1], kernel_size[0])``.

    ``kernel_size`` and ``sigma`` are ``[x, y]`` pairs: index 0 controls the column (width)
    axis and index 1 the row (height) axis. Each 1-D factor from
    :func:`_get_gaussian_kernel1d` sums to 1, so their outer product does as well.
    """
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
    # Outer product via broadcasting: rows follow y, columns follow x.
    kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
    return kernel2d


100
@_register_kernel_internal(gaussian_blur, torch.Tensor)
@_register_kernel_internal(gaussian_blur, datapoints.Image)
def gaussian_blur_image(
    image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Blur an image tensor with a Gaussian kernel.

    Args:
        image: Tensor of shape ``(..., C, H, W)``. Integer tensors are blurred in float32
            and rounded back to their original dtype.
        kernel_size: Odd, positive kernel size ``[kx, ky]``; a single ``int`` is used for both axes.
        sigma: Standard deviation ``[sx, sy]``; a number or a one-element sequence is used for
            both axes. Defaults to ``ksize * 0.15 + 0.35`` per axis.

    Returns:
        The blurred tensor, with the same shape and dtype as ``image``. Empty tensors are
        returned unchanged.

    Raises:
        ValueError: If ``kernel_size`` or ``sigma`` is invalid, or if a non-empty ``image``
            has fewer than 3 dims.
        TypeError: If ``sigma`` is neither a number nor a sequence.
    """
    # TODO: consider deprecating integers from sigma in the future
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    elif len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
    else:
        if isinstance(sigma, (list, tuple)):
            length = len(sigma)
            if length == 1:
                s = float(sigma[0])
                sigma = [s, s]
            elif length != 2:
                raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
        elif isinstance(sigma, (int, float)):
            s = float(sigma)
            sigma = [s, s]
        else:
            raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    if image.numel() == 0:
        return image

    # The grouped convolution below needs an explicit channel dim (shape[-3]). Without this
    # check a 2-D input fails further down with an opaque IndexError.
    if image.ndim < 3:
        raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.")

    dtype = image.dtype
    shape = image.shape
    ndim = image.ndim
    # conv2d wants a 4-D (N, C, H, W) batch: add a batch dim to a single image, or fold
    # all extra leading dims into one.
    if ndim == 3:
        image = image.unsqueeze(dim=0)
    elif ndim > 4:
        image = image.reshape((-1,) + shape[-3:])

    fp = torch.is_floating_point(image)
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
    # One copy of the kernel per channel; combined with groups=C this blurs every channel independently.
    kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])

    output = image if fp else image.to(dtype=torch.float32)

    # padding = (left, right, top, bottom); with odd kernel sizes this keeps H and W unchanged.
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    output = torch_pad(output, padding, mode="reflect")
    output = conv2d(output, kernel, groups=shape[-3])

    # Undo the batching applied above.
    if ndim == 3:
        output = output.squeeze(dim=0)
    elif ndim > 4:
        output = output.reshape(shape)

    if not fp:
        output = output.round_().to(dtype=dtype)

    return output
164
165


166
@_register_kernel_internal(gaussian_blur, PIL.Image.Image)
def _gaussian_blur_image_pil(
    image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
    """Gaussian-blur a PIL image by round-tripping through a tensor.

    The image is converted with ``pil_to_tensor``, blurred by :func:`gaussian_blur_image`,
    and converted back with ``to_pil_image`` using the input's original ``mode``.
    """
    t_img = pil_to_tensor(image)
    output = gaussian_blur_image(t_img, kernel_size=kernel_size, sigma=sigma)
    return to_pil_image(output, mode=image.mode)
173
174


175
@_register_kernel_internal(gaussian_blur, datapoints.Video)
def gaussian_blur_video(
    video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Gaussian-blur a video tensor by delegating to :func:`gaussian_blur_image`.

    Extra leading dims beyond ``(C, H, W)`` are handled by the image kernel itself.
    """
    return gaussian_blur_image(video, kernel_size, sigma)
180
181


182
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """[BETA] See :func:`~torchvision.transforms.v2.ToDtype` for details."""
    # Scripted calls bypass the type-based kernel lookup and go straight to the image kernel.
    if torch.jit.is_scripting():
        return to_dtype_image(inpt, dtype=dtype, scale=scale)

    _log_api_usage_once(to_dtype)

    # Pick the kernel registered for the concrete input type via `_register_kernel_internal(to_dtype, ...)`.
    kernel = _get_kernel(to_dtype, type(inpt))
    return kernel(inpt, dtype=dtype, scale=scale)
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207


def _num_value_bits(dtype: torch.dtype) -> int:
    """Return how many bits of an integer ``dtype`` carry the magnitude of a non-negative value.

    ``uint8`` uses all 8 of its bits; each signed type gives up one bit to the sign.

    Raises:
        TypeError: If ``dtype`` is not one of uint8, int8, int16, int32 or int64.
    """
    # Signed types, widest first: storage width minus the sign bit.
    if dtype == torch.int64:
        return 63
    if dtype == torch.int32:
        return 31
    if dtype == torch.int16:
        return 15
    if dtype == torch.int8:
        return 7
    # The only supported unsigned type.
    if dtype == torch.uint8:
        return 8
    raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


208
@_register_kernel_internal(to_dtype, torch.Tensor)
@_register_kernel_internal(to_dtype, datapoints.Image)
def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """Convert an image tensor to ``dtype``, optionally rescaling its values.

    Args:
        image: Input tensor.
        dtype: Target dtype.
        scale: If ``False``, values are cast unchanged. If ``True``:

            * float -> float: plain cast.
            * float -> int: values are multiplied by ``_max_value(dtype) + 1 - 1e-3`` and
              truncated, mapping ``[0.0, 1.0]`` onto ``[0, _max_value(dtype)]``.
            * int -> float: values are multiplied by ``1 / _max_value(image.dtype)``.
            * int -> int: values are bit-shifted by the difference in value bits.

    Returns:
        ``image`` itself if it already has ``dtype``; otherwise the converted tensor.

    Raises:
        RuntimeError: For float -> int conversions deemed unsafe (float32 -> int32/int64,
            float64 -> int64).
    """
    if image.dtype == dtype:
        return image
    elif not scale:
        return image.to(dtype)

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
        eps = 1e-3
        max_value = float(_max_value(dtype))
        # Half-precision floats cannot represent `max_value + 1.0 - eps`. For uint8 output, `255.999` rounds to `256.0`
        # in both float16 and bfloat16, so an input of `1.0` would overflow the target range, and other values would
        # round differently than float32 input does. Scale in float32 instead.
        if image.dtype in (torch.float16, torch.bfloat16):
            image = image.to(torch.float32)
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
        if float_output:
            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))

        # int to int: narrowing drops the low-order bits (shift, then cast); widening casts first and then
        # shifts left so the values span the wider range.
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)

        if num_value_bits_input > num_value_bits_output:
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)


# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """[BETA] [DEPRECATED] Use to_dtype() instead."""
    # Always rescales values: equivalent to `to_dtype_image(image, dtype=dtype, scale=True)`.
    return to_dtype_image(image, dtype=dtype, scale=True)
266
267


268
@_register_kernel_internal(to_dtype, datapoints.Video)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """Convert a video tensor to ``dtype`` by delegating to :func:`to_dtype_image`."""
    return to_dtype_image(video, dtype, scale=scale)
271
272


273
274
@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
    """Cast bounding boxes or masks to ``dtype`` without rescaling.

    ``scale`` is accepted to match the other ``to_dtype`` kernels' signature but is ignored.
    """
    # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
    return inpt.to(dtype)