_misc.py 6.51 KB
Newer Older
1
import math
2
from typing import List, Optional, Union
3

4
import PIL.Image
5
import torch
6
from torch.nn.functional import conv2d, pad as torch_pad
7

8
from torchvision import datapoints
9
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
10

11
12
from torchvision.utils import _log_api_usage_once

13
from ._utils import is_simple_tensor
14

15
16
17
18
19
20
21
22

def normalize_image_tensor(
    image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
    """Normalize a float image tensor channel-wise as ``(image - mean) / std``.

    Args:
        image: Floating-point tensor of shape ``(..., C, H, W)``.
        mean: Per-channel means. A 1-D sequence is reshaped to ``(C, 1, 1)`` so it
            broadcasts over the spatial (and any leading) dimensions.
        std: Per-channel standard deviations, reshaped the same way. No entry may be zero.
        inplace: If ``True``, ``image`` itself is modified and returned.

    Returns:
        The normalized tensor (the very same tensor object as ``image`` when
        ``inplace=True``).

    Raises:
        TypeError: If ``image`` is not a floating-point tensor.
        ValueError: If ``image`` has fewer than 3 dimensions or ``std`` contains a zero.
    """
    if not image.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")

    if image.ndim < 3:
        raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.")

    # Validate std on the plain Python values, before any tensor is built from them.
    if isinstance(std, (tuple, list)):
        divzero = not all(std)
    elif isinstance(std, (int, float)):
        divzero = std == 0
    else:
        divzero = False
    if divzero:
        raise ValueError("std evaluated to zero, leading to division by zero.")

    dtype = image.dtype
    device = image.device
    mean = torch.as_tensor(mean, dtype=dtype, device=device)
    std = torch.as_tensor(std, dtype=dtype, device=device)
    # (C,) -> (C, 1, 1) so the statistics broadcast against (..., C, H, W).
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)

    if inplace:
        image = image.sub_(mean)
    else:
        # sub() allocates a fresh tensor, so the in-place div_ below never
        # touches the caller's input on this path.
        image = image.sub(mean)

    return image.div_(std)
49

50

51
52
53
54
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Normalize a video tensor with per-channel ``mean`` and ``std``.

    Videos need no dedicated kernel: ``normalize_image_tensor`` accepts any
    ``(..., C, H, W)`` tensor, so the frames are normalized by that routine directly.
    """
    normalized = normalize_image_tensor(video, mean=mean, std=std, inplace=inplace)
    return normalized


55
def normalize(
    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
    mean: List[float],
    std: List[float],
    inplace: bool = False,
) -> torch.Tensor:
    """Normalize an image or video with per-channel ``mean`` and ``std``.

    Dispatch rules:
      * under TorchScript, or when ``is_simple_tensor(inpt)`` holds, the input is
        handled by ``normalize_image_tensor``;
      * ``datapoints.Image`` / ``datapoints.Video`` use their own ``normalize`` method.

    Raises:
        TypeError: For any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(normalize)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
        return inpt.normalize(mean=mean, std=std, inplace=inplace)
    else:
        raise TypeError(
            "Input can either be a plain tensor or an `Image` or `Video` datapoint, "
            f"but got {type(inpt)} instead."
        )
71
72


73
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
74
    lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
75
    x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
76
    kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
77
78
79
80
81
82
    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Return a separable 2-D Gaussian kernel of shape ``(kernel_size[1], kernel_size[0])``.

    ``kernel_size[0]`` / ``sigma[0]`` describe the horizontal (column) axis and
    ``kernel_size[1]`` / ``sigma[1]`` the vertical (row) axis. The result is the outer
    product of the two 1-D kernels.
    """
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
    # (ky, 1) * (kx,) broadcasts to the (ky, kx) outer product.
    kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
    return kernel2d


89
def gaussian_blur_image_tensor(
    image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Blur an image tensor with a Gaussian kernel applied per channel.

    Args:
        image: Tensor of shape ``(..., C, H, W)``. Non-floating-point inputs are
            blurred in ``float32``, then rounded and cast back to their dtype.
        kernel_size: ``[kx, ky]`` (width, height), or a single int used for both.
            Every entry must be odd and positive.
        sigma: Standard deviation as ``[sx, sy]``, a one-element sequence, or a single
            number used for both axes. Defaults to ``ksize * 0.15 + 0.35`` per axis.

    Returns:
        A tensor with the same shape and dtype as ``image``. Empty inputs are
        returned unchanged.

    Raises:
        ValueError: On an invalid ``kernel_size`` or ``sigma``, or when a non-empty
            ``image`` has fewer than 3 dimensions.
        TypeError: If ``sigma`` is neither a number nor a sequence.
    """
    # TODO: consider deprecating integers from sigma in the future
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    elif len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        # 0 is even, so the parity test also rejects a zero-sized kernel.
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
    else:
        if isinstance(sigma, (list, tuple)):
            length = len(sigma)
            if length == 1:
                s = float(sigma[0])
                sigma = [s, s]
            elif length != 2:
                raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
        elif isinstance(sigma, (int, float)):
            s = float(sigma)
            sigma = [s, s]
        else:
            raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    if image.numel() == 0:
        return image

    # Without this check a 1-D/2-D input fails further down with an opaque
    # IndexError on shape[-3]; report it like normalize_image_tensor does.
    if image.ndim < 3:
        raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.")

    dtype = image.dtype
    shape = image.shape
    ndim = image.ndim
    # conv2d needs a 4-D (N, C, H, W) batch: add a batch dim to a single image and
    # fold every leading dimension into one for higher-rank inputs.
    if ndim == 3:
        image = image.unsqueeze(dim=0)
    elif ndim > 4:
        image = image.reshape((-1,) + shape[-3:])

    fp = torch.is_floating_point(image)
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
    # One (ky, kx) kernel per channel, used as a depthwise (groups=C) convolution below.
    kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])

    output = image if fp else image.to(dtype=torch.float32)

    # padding = (left, right, top, bottom): kernel_size[0] pads the width, kernel_size[1] the height.
    # NOTE(review): reflect padding presumably requires each pad to be smaller than the
    # matching image dimension, so kernels larger than the image would error here — confirm.
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    output = torch_pad(output, padding, mode="reflect")
    output = conv2d(output, kernel, groups=shape[-3])

    # Undo the batching applied above.
    if ndim == 3:
        output = output.squeeze(dim=0)
    elif ndim > 4:
        output = output.reshape(shape)

    if not fp:
        output = output.round_().to(dtype=dtype)

    return output
151
152


153
@torch.jit.unused
def gaussian_blur_image_pil(
    image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
    """Blur a PIL image by round-tripping it through the tensor implementation.

    The image is converted with ``pil_to_tensor``, blurred by
    ``gaussian_blur_image_tensor``, and converted back with ``to_pil_image``. The
    original image mode is preserved. ``@torch.jit.unused`` keeps this PIL-only path
    out of TorchScript compilation.
    """
    t_img = pil_to_tensor(image)
    output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
    return to_pil_image(output, mode=image.mode)
160
161


162
163
164
def gaussian_blur_video(
    video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Blur a video tensor with a Gaussian kernel.

    Delegates to ``gaussian_blur_image_tensor``, which accepts any ``(..., C, H, W)``
    tensor, so no video-specific handling is needed.
    """
    return gaussian_blur_image_tensor(video, kernel_size, sigma)
166
167


168
def gaussian_blur(
    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints._InputTypeJIT:
    """Apply a Gaussian blur, dispatching on the input type.

    Dispatch rules:
      * under TorchScript, or when ``is_simple_tensor(inpt)`` holds:
        ``gaussian_blur_image_tensor``;
      * any ``datapoints._datapoint.Datapoint``: its own ``gaussian_blur`` method;
      * ``PIL.Image.Image``: ``gaussian_blur_image_pil``.

    Raises:
        TypeError: For any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(gaussian_blur)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
    elif isinstance(inpt, PIL.Image.Image):
        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
    else:
        raise TypeError(
            "Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )