# utils.py
1
import math
2
import pathlib
3
import warnings
4
from typing import Union, Optional, List, Tuple, BinaryIO
5

6
import numpy as np
7
import torch
8
from PIL import Image, ImageDraw, ImageFont, ImageColor
9

# Public API of this module; the underscore-prefixed helpers are private.
__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
]


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A single image of shape
            (H x W) or (C x H x W) is also accepted; single-channel images are
            replicated to 3 channels.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(ceil(B / nrow), nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.
        **kwargs: Only the deprecated ``range`` keyword is read (it overrides
            ``value_range``); other keywords are ignored.

    Returns:
        grid (Tensor): the tensor containing grid of images. If the batch holds a
        single image, that image is returned as (C x H x W) without padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors, or if
            ``normalize`` is set and ``value_range`` is given but is not a tuple.
    """
    _log_api_usage_once("utils", "make_grid")
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    if "range" in kwargs:
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    # Truthiness (not ``is True``) so flags such as numpy.bool_(True) also work.
    if normalize:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None and not isinstance(value_range, tuple):
            # An explicit raise, unlike ``assert``, still fires under ``python -O``.
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            # Clamp to [low, high] then map it to [0, 1]; 1e-5 guards a zero span.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    # Image k goes to row k // xmaps, column k % xmaps (row-major fill).
    for k in range(nmaps):
        y, x = divmod(k, xmaps)
        # Tensor.copy_() is a valid method but seems to be missing from the stubs
        # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
        grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
            2, x * width + padding, width - padding
        ).copy_(tensor[k])
    return grid


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Write a tensor (or a batch of tensors, tiled with ``make_grid``) to an image file.

    Args:
        tensor (Tensor or list): Image to be saved. A mini-batch tensor or a list
            of tensors is first arranged into a single grid via ``make_grid``.
        fp (string or file object): Destination filename or writable file object.
        format(Optional): Image format handed to PIL. If omitted, it is inferred
            from the filename extension, so pass it whenever ``fp`` is a file object.
        **kwargs: Forwarded unchanged to ``make_grid`` (see its documentation).
    """
    _log_api_usage_once("utils", "save_image")
    grid = make_grid(tensor, **kwargs)

    # Values are expected in [0, 1]: scale to [0, 255], then add 0.5 so the
    # truncating uint8 cast below rounds to the nearest integer.
    scaled = grid.mul(255)
    scaled.add_(0.5)
    scaled.clamp_(0, 255)

    channels_last = scaled.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for PIL
    pixels = channels_last.to("cpu", torch.uint8).numpy()
    Image.fromarray(pixels).save(fp, format=format)


@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10,
) -> torch.Tensor:
    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If ``fill`` is True, the resulting tensor should be saved as a PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8, with C = 1 (grayscale) or 3 (RGB).
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`. Coordinates are truncated to integers before drawing.
        labels (List[str]): List containing the labels of bounding boxes. It is indexed per box,
            so it must have at least N entries.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``;
            RGB lists are accepted too and converted to tuples. A list of colors is
            indexed per box, so it must have at least N entries.
        fill (bool): If `True` fills the bounding box with the box color at alpha 100/255
            (white when no color is given).
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. A grayscale
            input comes back as 3-channel RGB, and the result is always on the CPU.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` is not uint8, is batched, or has neither 1 nor 3 channels.
    """
    _log_api_usage_once("utils", "draw_bounding_boxes")
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")

    if image.size(0) == 1:
        # Replicate the single channel so drawing happens on a 3-channel image.
        image = torch.tile(image, (3, 1, 1))

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    img_boxes = boxes.to(torch.int64).tolist()

    # The "RGBA" draw mode is used only when filling, so the 4-tuple fill colors
    # (alpha 100) built below are drawn translucently.
    draw = ImageDraw.Draw(img_to_draw, "RGBA") if fill else ImageDraw.Draw(img_to_draw)
    txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    for i, bbox in enumerate(img_boxes):
        if colors is None:
            color = None
        elif isinstance(colors, list):
            color = colors[i]
        else:
            color = colors
        if isinstance(color, list):
            # Normalize RGB lists to the tuple form the type hints advertise; this
            # also makes the ``+ (100,)`` alpha concatenation below valid.
            color = tuple(color)

        if fill:
            if color is None:
                fill_color = (255, 255, 255, 100)
            elif isinstance(color, str):
                # This will automatically raise Error if rgb cannot be parsed.
                fill_color = ImageColor.getrgb(color) + (100,)
            else:
                # Previously only ``tuple`` was handled here, leaving fill_color
                # unbound (UnboundLocalError) for any other RGB sequence.
                fill_color = tuple(color) + (100,)
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if labels is not None:
            # Offset the label inside the box so it does not overlap the outline.
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            A list must hold at least one color per mask. By default, a deterministic
            palette with one color per mask is generated.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
        If ``masks`` contains no mask, a warning is emitted and a copy of ``image`` is returned.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: on a wrong image dtype/shape, wrong mask dtype/shape, mismatched
            spatial size, too few colors, or malformed colors.
    """
    _log_api_usage_once("utils", "draw_segmentation_masks")
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    if num_masks == 0:
        # Nothing to draw. Without this, the empty default palette made the
        # ``colors[0]`` check below raise IndexError.
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image.clone()

    if colors is None:
        colors = _generate_color_palette(num_masks)
    elif not isinstance(colors, list):
        # A single color (str or RGB tuple) applies to every mask. Its len() is
        # the length of the string/tuple itself, not a color count, so it must
        # not be compared against num_masks.
        colors = [colors] * num_masks
    elif num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        # Create the color on the image's device so the indexed assignment below
        # does not mix devices.
        colors_.append(torch.tensor(color, dtype=out_dtype, device=image.device))

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)


@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:
    """
    Draw keypoints (and optional skeleton lines) onto a uint8 RGB image.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2) holding, for each
            instance, K keypoint locations in [x, y] format. Coordinates are
            truncated to integers before drawing.
        connectivity (List[Tuple[int, int]]): Pairs of keypoint indices, within one
            instance, to join with a line. Applied to every instance.
        colors (str, Tuple): Fill color of the keypoint disks, as a PIL string such
            as "red" or "#FF00FF", or an RGB tuple such as ``(240, 10, 157)``. No
            color is passed when drawing the connecting lines.
        radius (int): Radius of each keypoint disk.
        width (int): Width of the lines connecting keypoints.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn,
        returned on the CPU.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` is not a single uint8 RGB image or ``keypoints`` is not 3-D.
    """
    _log_api_usage_once("utils", "draw_keypoints")

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    if image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    if image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    if image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    canvas = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())
    pen = ImageDraw.Draw(canvas)
    instances = keypoints.to(torch.int64).tolist()

    for points in instances:
        # Disks first, then this instance's skeleton lines, as in the original order.
        for point in points:
            cx, cy = point[0], point[1]
            pen.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=colors, outline=None, width=0)

        if connectivity:
            for connection in connectivity:
                start = points[connection[0]]
                end = points[connection[1]]
                pen.line(((start[0], start[1]), (end[0], end[1])), width=width)

    return torch.from_numpy(np.array(canvas)).permute(2, 0, 1).to(dtype=torch.uint8)


def _generate_color_palette(num_masks: int) -> List[Tuple[int, ...]]:
    """
    Return ``num_masks`` deterministic RGB colors, one per mask.

    Color ``i`` is ``(i * m) % 255`` for each channel multiplier
    ``m`` in ``(2**25 - 1, 2**15 - 1, 2**21 - 1)``. Index 0 is therefore black,
    and every channel lies in ``[0, 254]``. Channels are plain Python ints (the
    previous tensor-based version yielded tuples of 0-dim tensors), so the
    colors can be used anywhere an RGB tuple is accepted.
    """
    multipliers = (2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1)
    return [tuple((i * m) % 255 for m in multipliers) for i in range(num_masks)]


def _log_api_usage_once(module: str, name: str) -> None:
    """
    Record one use of ``torchvision.<module>.<name>`` via ``torch._C._log_api_usage_once``.

    Nothing is logged while ``torch.jit.is_scripting()`` or
    ``torch.jit.is_tracing()`` reports True.
    """
    under_jit = torch.jit.is_scripting() or torch.jit.is_tracing()
    if not under_jit:
        torch._C._log_api_usage_once(f"torchvision.{module}.{name}")