utils.py 20.3 KB
Newer Older
1
import math
2
import pathlib
3
import warnings
Kai Zhang's avatar
Kai Zhang committed
4
from types import FunctionType
5
from typing import Any, BinaryIO, List, Optional, Tuple, Union
6

7
import numpy as np
8
import torch
9
from PIL import Image, ImageColor, ImageDraw, ImageFont
10

11
12
13
14
15
16
17
18
# Names exported by ``from <this module> import *``; the underscore-prefixed
# helpers defined further down are intentionally not listed here.
__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]
19

20

21
@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A 2D (H x W) or 3D (C x H x W)
            tensor is treated as a single image, and single-channel images are
            repeated to 3 channels.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid has ``min(nrow, B)`` columns and ``ceil(B / min(nrow, B))`` rows.
            Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.
        **kwargs: Only the deprecated ``range`` key is read (it overrides ``value_range``
            and emits a warning); any other key is ignored.

    Returns:
        grid (Tensor): the tensor containing grid of images. When the batch holds a
        single image, that image is returned (without the batch dimension) instead of a grid.

    Raises:
        TypeError: If ``tensor`` is neither a tensor nor a list of tensors, or if
            ``normalize`` is True and ``value_range`` is given but is not a tuple.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    # Backward compatibility: ``range`` was the former name of ``value_range``.
    if "range" in kwargs:
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if value_range is not None and not isinstance(value_range, tuple):
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            # In-place clamp to [low, high] followed by rescaling to [0, 1];
            # the 1e-5 floor avoids a division by zero when high == low.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    # Narrows the Union type for static type checkers; always true at this point.
    assert isinstance(tensor, torch.Tensor)
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell is one image plus its leading padding; one extra padding strip
    # is added at the far edge of the grid below.
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid


125
@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Write a Tensor to disk (or to a file object) as an image.

    Args:
        tensor (Tensor or list): Image to be saved. The input is first passed through
            ``make_grid``, so a mini-batch is saved as a grid of images.
        fp (string or file object): A filename or a file object
        format(Optional): Image format handed to PIL's ``Image.save``. If omitted, the format
            is determined from the filename extension; it should always be given when ``fp``
            is a file object rather than a filename.
        **kwargs: Forwarded unchanged to ``make_grid``.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(save_image)

    grid = make_grid(tensor, **kwargs)

    # Map to the [0, 255] range. The +0.5 offset makes the truncating uint8
    # cast below round to the nearest integer.
    scaled = grid.mul(255)
    scaled.add_(0.5)
    scaled.clamp_(0, 255)

    # Channels-last layout on the CPU, as consumed by Image.fromarray.
    pixels = scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    Image.fromarray(pixels).save(fp, format=format)
151
152
153
154
155
156
157


@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10,
) -> torch.Tensor:

    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, Resulting Tensor should be saved as PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8. A single-channel image
            is repeated to 3 channels before drawing.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`. Coordinates are converted to integers before drawing.
        labels (List[str]): List containing the labels of bounding boxes. Must hold at least one
            entry per box.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            A list must hold at least one entry per box.
        fill (bool): If `True` fills the bounding box with specified color.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.

    Raises:
        TypeError: If ``image`` is not a tensor.
        ValueError: If ``image`` is not a uint8, 3-dimensional, 1- or 3-channel tensor; if
            ``labels`` or a ``colors`` list has fewer entries than there are boxes; or if
            ``fill`` is True and a color is neither a string nor a tuple.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_bounding_boxes)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")

    img_boxes = boxes.to(torch.int64).tolist()
    num_boxes = len(img_boxes)

    # Fail early with a descriptive error instead of an IndexError raised
    # half-way through the drawing loop.
    if labels is not None and len(labels) < num_boxes:
        raise ValueError(
            f"Number of labels ({len(labels)}) is less than number of boxes ({num_boxes}). "
            "Please specify a label for each box."
        )
    if isinstance(colors, list) and len(colors) < num_boxes:
        raise ValueError(
            f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). "
            "Please specify a color for each box."
        )

    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)

    if fill:
        # The "RGBA" drawing mode lets the semi-transparent fill blend with the image.
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    for i, bbox in enumerate(img_boxes):
        if colors is None:
            color = None
        elif isinstance(colors, list):
            color = colors[i]
        else:
            color = colors

        if fill:
            # The fill is the outline color with an alpha of 100 appended.
            if color is None:
                fill_color = (255, 255, 255, 100)
            elif isinstance(color, str):
                # This will automatically raise Error if rgb cannot be parsed.
                fill_color = ImageColor.getrgb(color) + (100,)
            elif isinstance(color, tuple):
                fill_color = color + (100,)
            else:
                raise ValueError(
                    f"Unsupported color {color!r}: expected a PIL color string or an RGB tuple"
                )
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if labels is not None:
            # Offset the text so it does not overlap the box outline.
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
242
243
244
245
246
247


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:

    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            A single color is applied to every mask; a list must hold at least one color per mask.
            By default, a distinct color is generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
        If ``masks`` contains no mask, a warning is emitted and a copy of ``image`` is returned.

    Raises:
        TypeError: If ``image`` is not a tensor.
        ValueError: If ``image`` is not a uint8 (3, H, W) tensor, if ``masks`` has the wrong
            rank, dtype or spatial size, if a ``colors`` list is shorter than the number of
            masks, or if the colors are not strings / RGB tuples.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_segmentation_masks)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]

    # Nothing to draw: return an untouched copy rather than failing on an
    # empty color list further down.
    if num_masks == 0:
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image.clone()

    if colors is None:
        colors = _generate_color_palette(num_masks)
    elif isinstance(colors, list):
        if num_masks > len(colors):
            raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
    else:
        # A single color (string or RGB tuple) applies to every mask. It must
        # not be measured with len(): that would count characters or channels.
        colors = [colors] * num_masks

    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        # Create the color on the image's device so the masked assignment
        # below does not mix devices.
        colors_.append(torch.tensor(color, dtype=out_dtype, device=image.device))

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    # Alpha-blend the painted copy over the original image.
    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)
318
319


320
321
322
323
@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:

    """
    Draws keypoints, and optionally the segments joining them, on an RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2) holding, for each instance,
            the location of its K keypoints in the format [x, y]. Coordinates are converted
            to integers before drawing.
        connectivity (List[Tuple[int, int]]]): A List of tuples, each holding the indices of
            two keypoints of an instance that should be joined by a line.
        colors (str, Tuple): Fill color of the keypoints. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
        radius (int): Integer denoting radius of keypoint.
        width (int): Integer denoting width of line connecting keypoints.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
    """

    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(draw_keypoints)

    # Input validation; every failing check raises, so the order below is the
    # order in which problems are reported.
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    if image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    if image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    if image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    # PIL works on channels-last arrays living on the CPU.
    canvas = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())
    painter = ImageDraw.Draw(canvas)

    for instance in keypoints.to(torch.int64).tolist():
        # One filled disc per keypoint.
        for point in instance:
            left, right = point[0] - radius, point[0] + radius
            top, bottom = point[1] - radius, point[1] + radius
            painter.ellipse([left, top, right, bottom], fill=colors, outline=None, width=0)

        if not connectivity:
            continue

        # One line segment per requested pair of keypoint indices.
        for connection in connectivity:
            start = instance[connection[0]]
            end = instance[connection[1]]
            painter.line(
                ((start[0], start[1]), (end[0], end[1])),
                width=width,
            )

    return torch.from_numpy(np.array(canvas)).permute(2, 0, 1).to(dtype=torch.uint8)


392
393
394
395
396
397
398
399
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:

    """
    Renders an optical flow field as an RGB image.

    The flow is divided by its largest per-pixel magnitude (plus the dtype's machine
    epsilon, to avoid dividing by zero) before being mapped to colors.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.

    Raises:
        ValueError: If ``flow`` is not of dtype torch.float or does not have one of the
            shapes listed above.
    """

    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    input_shape = flow.shape
    was_unbatched = flow.ndim == 3
    # Work on a batch internally; a lone flow becomes a batch of one.
    batched = flow[None] if was_unbatched else flow

    if not (batched.ndim == 4 and batched.shape[1] == 2):
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {input_shape}.")

    # Largest flow-vector length over the whole batch.
    largest_norm = torch.sum(batched ** 2, dim=1).sqrt().max()
    scale = largest_norm + torch.finfo(batched.dtype).eps

    rgb = _normalized_flow_to_image(batched / scale)

    # Give back the same rank the caller passed in.
    return rgb[0] if was_unbatched else rgb
425
426
427
428
429
430


@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:

    """
    Converts a batch of normalized flow to an RGB image.

    The flow direction selects a hue by linear interpolation between neighbouring
    colorwheel entries, and the flow magnitude controls how far the color moves
    away from white.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
    Returns:
       img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8, allocated on the
           same device as ``normalized_flow``.
    """

    N, _, H, W = normalized_flow.shape
    # Keep every intermediate on the input's device; indexing a CPU colorwheel
    # with indices that live on another device would fail.
    device = normalized_flow.device
    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
    num_cols = colorwheel.shape[0]
    norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
    # Flow angle scaled to [-1, 1].
    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
    # Fractional position on the colorwheel and its two neighbouring entries.
    fk = (a + 1) / 2 * (num_cols - 1)
    k0 = torch.floor(fk).to(torch.long)
    k1 = k0 + 1
    k1[k1 == num_cols] = 0  # wrap around to the first entry
    f = fk - k0

    for c in range(colorwheel.shape[1]):
        tmp = colorwheel[:, c]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        # Zero magnitude gives white; larger magnitude moves towards the full hue.
        col = 1 - norm * (1 - col)
        flow_image[:, c, :, :] = torch.floor(255 * col)
    return flow_image


def _make_colorwheel() -> torch.Tensor:
    """
    Builds the color wheel used for optical flow visualization, as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    The wheel is made of six consecutive hue transitions. Within each one, one RGB
    channel is held at 255 while another ramps up from 0 or down from 255.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """

    # (length, channel held at 255, ramping channel, ramp goes up?)
    segments = (
        (15, 0, 1, True),  # RY: red -> yellow
        (6, 1, 0, False),  # YG: yellow -> green
        (4, 1, 2, True),  # GC: green -> cyan
        (11, 2, 1, False),  # CB: cyan -> blue
        (13, 2, 0, True),  # BM: blue -> magenta
        (6, 0, 2, False),  # MR: magenta -> red
    )

    total = sum(length for length, _, _, _ in segments)
    wheel = torch.zeros((total, 3))

    start = 0
    for length, held, ramped, rising in segments:
        stop = start + length
        ramp = torch.floor(255 * torch.arange(0, length) / length)
        wheel[start:stop, held] = 255
        wheel[start:stop, ramped] = ramp if rising else 255 - ramp
        start = stop
    return wheel


508
def _generate_color_palette(num_masks: int):
    """
    Produce ``num_masks`` deterministic colors.

    Color ``i`` is ``i`` times a fixed triple of large constants, reduced modulo 255.

    Args:
        num_masks (int): Number of colors to generate.

    Returns:
        list: ``num_masks`` tuples of three 0-dimensional integer tensors, one per channel.
    """
    base = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = []
    for index in range(num_masks):
        channels = (index * base) % 255
        # Iterating the 1-D tensor yields one 0-d tensor per channel.
        colors.append(tuple(channels))
    return colors
511
512


Kai Zhang's avatar
Kai Zhang committed
513
def _log_api_usage_once(obj: Any) -> None:
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530

    """
    Logs API usage(module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides the similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purpose
    to log which methods are used and by default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
Kai Zhang's avatar
Kai Zhang committed
531
    if not obj.__module__.startswith("torchvision"):
532
        return
Kai Zhang's avatar
Kai Zhang committed
533
534
535
536
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{obj.__module__}.{name}")