utils.py 26 KB
Newer Older
1
import collections
2
import math
3
import pathlib
4
import warnings
5
from itertools import repeat
Kai Zhang's avatar
Kai Zhang committed
6
from types import FunctionType
7
from typing import Any, BinaryIO, List, Optional, Tuple, Union
8

9
import numpy as np
10
import torch
11
from PIL import Image, ImageColor, ImageDraw, ImageFont
12

13

14
15
16
17
18
19
20
21
# Public names of this module, i.e. what ``from <module> import *`` exports.
# Helpers whose names start with an underscore are intentionally not listed.
__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]
22

23

24
@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The grid has ``min(nrow, B)`` columns and ``ceil(B / min(nrow, B))`` rows. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images. When the (possibly expanded)
        batch holds a single image, that image is returned as a 3D tensor without padding.

    Raises:
        TypeError: If ``tensor`` is neither a tensor nor a list of tensors, or if
            ``normalize`` is ``True`` and ``value_range`` is given but is not a tuple.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not torch.is_tensor(tensor):
        if isinstance(tensor, list):
            for t in tensor:
                if not torch.is_tensor(t):
                    raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
        else:
            raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    # Identity comparison: only the literal ``True`` enables normalization.
    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None and not isinstance(value_range, tuple):
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            # In-place: clamp to [low, high], then rescale to [0, 1].
            # The 1e-5 floor avoids dividing by zero when high == low.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    # Not reachable at runtime after the validation above; presumably kept so that
    # static type checkers narrow ``tensor`` to ``torch.Tensor``.
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("tensor should be of type torch.Tensor")
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell is one image plus ``padding`` pixels on its top/left side.
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid


126
@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format(Optional):  If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(save_image)
    # ``make_grid`` is always called, so its keyword arguments (e.g. ``normalize``) apply
    # to single images as well.
    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)
152
153
154
155
156
157
158


@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: Optional[int] = None,
) -> torch.Tensor:

    """
    Draws bounding boxes on given RGB image.
    The image values should be uint8 in [0, 255] or float in [0, 1].
    If fill is True, Resulting Tensor should be saved as PNG image.

    Args:
        image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): List containing the labels of bounding boxes.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for boxes.
        fill (bool): If `True` fills the bounding box with specified color.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor with bounding boxes plotted, in the same dtype as the input image.
        If ``boxes`` is empty, the input image is returned unchanged after a warning.

    Raises:
        TypeError: If ``image`` is not a tensor.
        ValueError: If the image dtype, dimensionality or channel count is unsupported, if a box has
            ``xmin > xmax`` or ``ymin > ymax``, or if the number of labels differs from the number of boxes.
    """
    import torchvision.transforms.v2.functional as F  # noqa

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_bounding_boxes)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif not (image.dtype == torch.uint8 or image.is_floating_point()):
        raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")
    elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
        raise ValueError(
            "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
        )

    num_boxes = boxes.shape[0]

    if num_boxes == 0:
        warnings.warn("boxes doesn't contain any box. No box was drawn")
        return image

    if labels is None:
        labels: Union[List[str], List[None]] = [None] * num_boxes  # type: ignore[no-redef]
    elif len(labels) != num_boxes:
        raise ValueError(
            f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
        )

    colors = _parse_colors(colors, num_objects=num_boxes)

    if font is None:
        if font_size is not None:
            warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
        txt_font = ImageFont.load_default()
    else:
        txt_font = ImageFont.truetype(font=font, size=font_size or 10)

    # Handle Grayscale images
    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    # Drawing happens on a uint8 PIL image; float inputs are converted here and
    # converted back to their original dtype before returning.
    original_dtype = image.dtype
    if original_dtype.is_floating_point:
        image = F.to_dtype(image, dtype=torch.uint8, scale=True)

    img_to_draw = F.to_pil_image(image)
    img_boxes = boxes.to(torch.int64).tolist()

    if fill:
        # "RGBA" drawing mode so that the translucent fill color is blended with the image.
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    for bbox, color, label in zip(img_boxes, colors, labels):  # type: ignore[arg-type]
        if fill:
            # Append an alpha value of 100 to the RGB color for the fill.
            fill_color = color + (100,)
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if label is not None:
            # Offset the text so it does not overlap the box outline.
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

    out = F.pil_to_tensor(img_to_draw)
    if original_dtype.is_floating_point:
        out = F.to_dtype(out, dtype=original_dtype, scale=True)
    return out
261
262
263
264
265
266


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:

    """
    Draws segmentation masks on given RGB image.
    The image values should be uint8 in [0, 255] or float in [0, 1].

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
        Pixels covered by more than one mask are painted with value 0 before blending.
        If ``masks`` is empty, the input image is returned unchanged after a warning.

    Raises:
        TypeError: If ``image`` is not a tensor.
        ValueError: If the image dtype, dimensionality or channel count is unsupported, or if
            ``masks`` has the wrong number of dimensions, dtype, or spatial size.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_segmentation_masks)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif not (image.dtype == torch.uint8 or image.is_floating_point()):
        raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        # Promote a single (H, W) mask to a batch of one.
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    # Pixels that belong to more than one mask.
    overlapping_masks = masks.sum(dim=0) > 1

    if num_masks == 0:
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image

    original_dtype = image.dtype
    # One color tensor per mask, in the image's dtype and on the image's device.
    colors = [
        torch.tensor(color, dtype=original_dtype, device=image.device)
        for color in _parse_colors(colors, num_objects=num_masks, dtype=original_dtype)
    ]

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors):
        img_to_draw[:, mask] = color[:, None]

    img_to_draw[:, overlapping_masks] = 0

    # Alpha-blend the painted copy over the original image.
    out = image * (1 - alpha) + img_to_draw * alpha
    # Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
    return out.to(original_dtype)
331
332


333
334
335
336
@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
    visibility: Optional[torch.Tensor] = None,
) -> torch.Tensor:

    """
    Draws Keypoints on given RGB image.
    The image values should be uint8 in [0, 255] or float in [0, 1].
    Keypoints can be drawn for multiple instances at a time.

    This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
            in the format [x, y].
        connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
            to be connected.
            If at least one of the two connected keypoints has a ``visibility`` of False,
            this specific connection is not drawn.
            Exclusions due to invisibility are computed per-instance.
        colors (str, Tuple): The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
        radius (int): Integer denoting radius of keypoint.
        width (int): Integer denoting width of line connecting keypoints.
        visibility (Tensor): Tensor of shape (num_instances, K) specifying the visibility of the K
            keypoints for each of the N instances.
            True means that the respective keypoint is visible and should be drawn.
            False means invisible, so neither the point nor possible connections containing it are drawn.
            The input tensor will be cast to bool.
            Default ``None`` means that all the keypoints are visible.
            For more details, see :ref:`draw_keypoints_with_visibility`.

    Returns:
        img (Tensor[C, H, W]): Image Tensor with keypoints drawn.

    Raises:
        TypeError: If ``image`` is not a tensor.
        ValueError: If the image dtype, dimensionality or channel count is unsupported, if ``keypoints``
            is not 3-dimensional, or if ``visibility`` does not match ``keypoints`` in shape.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_keypoints)
    # validate image
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif not (image.dtype == torch.uint8 or image.is_floating_point()):
        raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    # validate keypoints
    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    # validate visibility
    if visibility is None:  # set default
        visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
    if visibility.ndim == 3:
        # If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1).
        # We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice.
        visibility = visibility.squeeze(-1)
    if visibility.ndim != 2:
        raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
    if visibility.shape != keypoints.shape[:-1]:
        raise ValueError(
            "keypoints and visibility must have the same dimensionality for num_instances and K. "
            f"Got {visibility.shape = } and {keypoints.shape = }"
        )

    # Drawing happens on a uint8 PIL image; float inputs are converted here and
    # converted back to their original dtype before returning.
    original_dtype = image.dtype
    if original_dtype.is_floating_point:
        from torchvision.transforms.v2.functional import to_dtype  # noqa

        image = to_dtype(image, dtype=torch.uint8, scale=True)

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    draw = ImageDraw.Draw(img_to_draw)
    img_kpts = keypoints.to(torch.int64).tolist()
    img_vis = visibility.cpu().bool().tolist()

    for kpt_inst, vis_inst in zip(img_kpts, img_vis):
        for kpt_coord, kp_vis in zip(kpt_inst, vis_inst):
            if not kp_vis:
                continue
            # Bounding box of the circle of the given radius centered on the keypoint.
            x1 = kpt_coord[0] - radius
            x2 = kpt_coord[0] + radius
            y1 = kpt_coord[1] - radius
            y2 = kpt_coord[1] + radius
            draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)

        if connectivity:
            for connection in connectivity:
                # Skip the connection if either endpoint is invisible for this instance.
                if (not vis_inst[connection[0]]) or (not vis_inst[connection[1]]):
                    continue
                start_pt_x = kpt_inst[connection[0]][0]
                start_pt_y = kpt_inst[connection[0]][1]

                end_pt_x = kpt_inst[connection[1]][0]
                end_pt_y = kpt_inst[connection[1]][1]

                draw.line(
                    ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
                    width=width,
                )

    out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
    if original_dtype.is_floating_point:
        # ``to_dtype`` was imported above under the same condition.
        out = to_dtype(out, dtype=original_dtype, scale=True)
    return out
448
449


450
451
452
453
454
455
456
457
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:

    """
    Converts a flow to an RGB image.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.

    Raises:
        ValueError: If ``flow`` is not of dtype ``torch.float`` or does not have one of the
            supported shapes.
    """

    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    orig_shape = flow.shape
    if flow.ndim == 3:
        flow = flow[None]  # Add batch dim

    if flow.ndim != 4 or flow.shape[1] != 2:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    # Largest flow magnitude over the whole input (all batch elements and pixels).
    max_norm = torch.sum(flow**2, dim=1).sqrt().max()
    # Machine epsilon keeps the division well-defined when the flow is all zeros.
    epsilon = torch.finfo((flow).dtype).eps
    normalized_flow = flow / (max_norm + epsilon)
    img = _normalized_flow_to_image(normalized_flow)

    if len(orig_shape) == 3:
        img = img[0]  # Remove batch dim
    return img
483
484
485
486
487
488


@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:

    """
    Converts a batch of normalized flow to an RGB image.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
    Returns:
       img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """

    N, _, H, W = normalized_flow.shape
    device = normalized_flow.device
    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
    num_cols = colorwheel.shape[0]
    # Per-pixel flow magnitude.
    norm = torch.sum(normalized_flow**2, dim=1).sqrt()
    # Per-pixel flow angle, divided by pi so that it lies in [-1, 1].
    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
    # Map the angle to a fractional index into the color wheel.
    fk = (a + 1) / 2 * (num_cols - 1)
    k0 = torch.floor(fk).to(torch.long)
    k1 = k0 + 1
    k1[k1 == num_cols] = 0  # wrap around to the first wheel entry
    f = fk - k0  # interpolation weight between entries k0 and k1

    for c in range(colorwheel.shape[1]):
        tmp = colorwheel[:, c]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        # Linear interpolation between the two neighbouring wheel colors.
        col = (1 - f) * col0 + f * col1
        # Blend towards white (1.0) as the magnitude goes to zero.
        col = 1 - norm * (1 - col)
        flow_image[:, c, :, :] = torch.floor(255 * col)
    return flow_image


def _make_colorwheel() -> torch.Tensor:
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = torch.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
    colorwheel[col : col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col : col + GC, 1] = 255
    colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
    colorwheel[col : col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col : col + BM, 2] = 255
    colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
    colorwheel[col : col + MR, 0] = 255
    return colorwheel


567
def _generate_color_palette(num_objects: int):
568
    palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
569
    return [tuple((i * palette) % 255) for i in range(num_objects)]
570
571


572
573
574
575
def _parse_colors(
    colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]],
    *,
    num_objects: int,
576
    dtype: torch.dtype = torch.uint8,
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
) -> List[Tuple[int, int, int]]:
    """
    Parses a specification of colors for a set of objects.

    Args:
        colors: A specification of colors for the objects. This can be one of the following:
            - None: to generate a color palette automatically.
            - A list of colors: where each color is either a string (specifying a named color) or an RGB tuple.
            - A string or an RGB tuple: to use the same color for all objects.

            If `colors` is a tuple, it should be a 3-tuple specifying the RGB values of the color.
            If `colors` is a list, it should have at least as many elements as the number of objects to color.

        num_objects (int): The number of objects to color.

    Returns:
        A list of 3-tuples, specifying the RGB values of the colors.

    Raises:
        ValueError: If the number of colors in the list is less than the number of objects to color.
                    If `colors` is not a list, tuple, string or None.
    """
    if colors is None:
        colors = _generate_color_palette(num_objects)
    elif isinstance(colors, list):
        if len(colors) < num_objects:
            raise ValueError(
                f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
            )
    elif not isinstance(colors, (tuple, str)):
        raise ValueError("`colors` must be a tuple or a string, or a list thereof, but got {colors}.")
    elif isinstance(colors, tuple) and len(colors) != 3:
        raise ValueError("If passed as tuple, colors should be an RGB triplet, but got {colors}.")
    else:  # colors specifies a single color for all objects
        colors = [colors] * num_objects

613
614
615
616
    colors = [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors]
    if dtype.is_floating_point:  # [0, 255] -> [0, 1]
        colors = [tuple(v / 255 for v in color) for color in colors]
    return colors
617
618


Kai Zhang's avatar
Kai Zhang committed
619
def _log_api_usage_once(obj: Any) -> None:
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636

    """
    Logs API usage(module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides the similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purpose
    to log which methods are used and by default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
637
638
639
    module = obj.__module__
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
Kai Zhang's avatar
Kai Zhang committed
640
641
642
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
643
    torch._C._log_api_usage_once(f"{module}.{name}")
644
645
646
647
648


def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
649
    Otherwise, we will make a tuple of length n, all with value of x.
650
651
652
653
654
655
656
657
658
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))