# utils.py: image visualization helpers (grid assembly, saving, and drawing boxes / segmentation masks).
1
from typing import Union, Optional, List, Tuple, Text, BinaryIO
2
import pathlib
3
4
import torch
import math
5
import warnings
6
import numpy as np
7
from PIL import Image, ImageDraw, ImageFont, ImageColor
# Public API: the names exported by ``from <this module> import *``.
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[float, float]] = None,
    scale_each: bool = False,
    pad_value: float = 0,
    **kwargs
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A 2D (H x W) or 3D (C x H x W)
            tensor is treated as a single image. Single-channel images are
            replicated to 3 channels.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The grid has ``min(nrow, B)`` columns and ``ceil(B / nrow)`` rows. Default: ``8``.
        padding (int, optional): Width in pixels of the padding between images and
            around the grid border. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.
        **kwargs: Only ``range`` is recognised, as a deprecated alias of ``value_range``
            (it emits a warning and overrides ``value_range``); other keys are ignored.

    Returns:
        grid (Tensor): the tensor containing grid of images. If the input holds a
        single image, that image is returned without any padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors.

    Example:
        See this notebook
        `here <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    # ``range`` is the deprecated spelling of ``value_range``; when given it takes precedence.
    if "range" in kwargs:
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # Clamp to [low, high], then map linearly onto [0, 1]; the 1e-5 floor
            # avoids a division by zero for constant images.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    # A single image is returned as-is (no padding, no grid).
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell is one image plus a leading padding strip; the extra ``padding``
    # added to the grid size closes the bottom/right border.
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
            Pixel values are expected in ``[0, 1]``: they are scaled by 255,
            rounded, and clamped to ``[0, 255]`` before conversion to uint8.
        fp (string, pathlib.Path or file object): A filename or a file object.
        format (str, optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10
) -> torch.Tensor:

    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, the resulting tensor should be saved as a PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8. A single-channel
            (grayscale) image is replicated to 3 channels before drawing.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`. Coordinates are converted to int64 (truncated) before drawing.
        labels (List[str]): List containing the labels of bounding boxes, drawn at each box's (xmin, ymin) corner.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
            be represented as `str` or `Tuple[int, int, int]` (a list of ints is accepted too). When omitted,
            ``None`` is passed to PIL as the outline/text color and filled boxes use translucent white.
        fill (bool): If `True` fills the bounding box with its color at alpha 100 (out of 255).
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted, on the CPU.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` is not uint8 or is not a single 3D image.

    Example:
        See this notebook
        `linked <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
    """

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    # PIL cannot build an image from an (H, W, 1) array, so replicate a single
    # grayscale channel to RGB first.
    if image.size(0) == 1:
        image = torch.cat((image, image, image), 0)

    # .cpu() so tensors on an accelerator can be handed to numpy/PIL.
    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)

    img_boxes = boxes.to(torch.int64).tolist()

    # "RGBA" drawing mode lets the semi-transparent fill blend with the image.
    if fill:
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    for i, bbox in enumerate(img_boxes):
        color = None if colors is None else colors[i]
        if isinstance(color, list):
            # Normalise list colors to tuples so PIL and the alpha concatenation accept them.
            color = tuple(color)

        if fill:
            # fill_color is recomputed for every box so a previous box's value is never reused.
            if color is None:
                fill_color = (255, 255, 255, 100)
            elif isinstance(color, str):
                # This will automatically raise Error if rgb cannot be parsed.
                fill_color = ImageColor.getrgb(color) + (100,)
            else:
                fill_color = tuple(color) + (100,)
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if labels is not None:
            draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.2,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:

    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Every pixel is assigned to the mask with the highest score at that location
    (``masks.argmax(0)``) and painted with that mask's color.

    Args:
        image (Tensor): Tensor of shape (3 x H x W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.
            The per-pixel mask index map is resized to the image's (H, W).
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks. The output is
            ``image * alpha + masks * (1 - alpha)``, so larger values let more of the original image show through.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks, one per mask. The
            colors can be represented as `str` or `Tuple[int, int, int]` (a list of ints is accepted too). When
            omitted, a color per mask is generated from a fixed palette.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted, on the same device
        as ``image``.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` is not a uint8, 3D, 3-channel image.

    Example:
        See this notebook
        `attached <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
    """

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    num_masks = masks.size()[0]
    # Per-pixel index of the highest-scoring mask.
    masks = masks.argmax(0)

    if colors is None:
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors_t = torch.arange(num_masks)[:, None] * palette
        color_arr = (colors_t % 255).numpy().astype("uint8")
    else:
        color_list = []
        for color in colors:
            if isinstance(color, str):
                # This will automatically raise Error if rgb cannot be parsed.
                color_list.append(ImageColor.getrgb(color))
            else:
                # Never drop an entry: skipping one would pair every later mask
                # with the wrong palette color.
                color_list.append(tuple(color))

        color_arr = np.array(color_list).astype("uint8")

    _, h, w = image.size()
    # The mask index map becomes a paletted image whose palette holds the mask colors.
    img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h))
    img_to_draw.putpalette(color_arr)

    img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB')))
    # Move the CPU-built overlay to the image's device so the blend below works for CUDA inputs.
    img_to_draw = img_to_draw.permute((2, 0, 1)).to(image.device)

    return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)