from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib
import torch
import math
import warnings
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageColor

__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0,
    **kwargs
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A 2D tensor is treated as a
            single (H x W) image and a 3D tensor as a single (C x H x W) image;
            single-channel images are replicated to 3 channels.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.
        **kwargs: Only ``range`` is read, as a deprecated alias of ``value_range``
            (a warning is emitted). Other keys are ignored.

    Returns:
        grid (Tensor): the tensor containing grid of images. If the batch holds a
        single image, that image is returned as (C x H x W) without padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if "range" in kwargs:
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # Clamp, then map [low, high] onto [0, 1]; the 1e-5 floor avoids a
            # division by zero for constant images.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
            Values are multiplied by 255, rounded and clamped to [0, 255].
        fp (string or file object): A filename or a file object
        format(Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10
) -> torch.Tensor:
    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, Resulting Tensor should be saved as PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): List containing the labels of bounding boxes.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
            be represented as `str` or `Tuple[int, int, int]`.
        fill (bool): If `True` fills the bounding box with specified color.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. The result lives on the CPU;
        a single-channel input is expanded to three channels.

    Raises:
        TypeError: if ``image`` is not a tensor, or if ``fill`` is True and a color is not None, a str or a tuple.
        ValueError: if ``image`` is not uint8 or is not 3-dimensional.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    # PIL draws on host memory and Tensor.numpy() only works for CPU tensors,
    # so move the image off any accelerator first (a no-op for CPU tensors).
    ndarr = image.permute(1, 2, 0).cpu().numpy()
    # allow single-channel-images
    # shape: (H, W, 1) with C = 1 -> tile to (H, W, 3)
    if ndarr.shape[-1] == 1:
        ndarr = np.tile(ndarr, (1, 1, 3))

    img_to_draw = Image.fromarray(ndarr)

    img_boxes = boxes.to(torch.int64).tolist()

    if fill:
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    for i, bbox in enumerate(img_boxes):
        color = None if colors is None else colors[i]

        if fill:
            # The fill uses the outline color with a fixed alpha of 100.
            if color is None:
                fill_color = (255, 255, 255, 100)
            elif isinstance(color, str):
                # This will automatically raise Error if rgb cannot be parsed.
                fill_color = ImageColor.getrgb(color) + (100,)
            elif isinstance(color, tuple):
                fill_color = color + (100,)
            else:
                # Any other type would leave fill_color unbound (or stale from
                # the previous box), so reject it explicitly.
                raise TypeError(
                    f"colors must be given as str or Tuple[int, int, int] when fill=True, got {type(color)}"
                )
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if labels is not None:
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:
    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (list or None): List containing the colors of the masks. The colors can
            be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
            with one element. By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8, with segmentation masks drawn on top.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if the image or masks have the wrong dtype/shape, if the colors are malformed,
            or if there are fewer colors than masks.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]

    if colors is None:
        colors = _generate_color_palette(num_masks)

    # A single color is shorthand for a one-element list. Wrap it *before*
    # counting: len() of a bare str / RGB tuple counts characters / components,
    # which would let too few colors through and leave masks silently undrawn.
    if not isinstance(colors, list):
        colors = [colors]

    # Guard the element checks so zero masks (empty palette / empty list) do not
    # raise an IndexError.
    if colors:
        if not isinstance(colors[0], (tuple, str)):
            raise ValueError("colors must be a tuple or a string, or a list thereof")
        if isinstance(colors[0], tuple) and len(colors[0]) != 3:
            raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    if num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        # Create the color on the image's device so the masked assignment below
        # does not mix devices.
        colors_.append(torch.tensor(color, dtype=out_dtype, device=image.device))

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)


def _generate_color_palette(num_masks):
    """Return ``num_masks`` deterministic RGB tuples of plain ints in [0, 254]."""
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    # .tolist() yields Python ints rather than 0-d tensors.
    return [tuple(((i * palette) % 255).tolist()) for i in range(num_masks)]