# utils.py
1
2
3
from typing import Union, Optional, Sequence, Tuple, Text, BinaryIO
import io
import pathlib
4
5
import torch
import math
6
# Alias for the builtin ``range``. Functions in this module take a parameter
# named ``range`` (a (min, max) normalization tuple), which shadows the builtin
# inside their bodies; this alias keeps the builtin reachable there.
irange = range
7

8

9
10
11
12
def make_grid(tensor: Union[torch.Tensor, Sequence[torch.Tensor]], nrow: int = 8,
              padding: int = 2, normalize: bool = False,
              range: Optional[Tuple[float, float]] = None, scale_each: bool = False,
              pad_value: float = 0) -> torch.Tensor:
    """Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list (or tuple) of images all of the same size. A single
            3D (C x H x W) or 2D (H x W) image is also accepted. Single-channel
            inputs are replicated to 3 channels.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(ceil(B / nrow), min(B, nrow))``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`range`. Default: ``False``.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        Tensor: A (C x H' x W') grid. If the batch contains exactly one image,
        that image is returned as (C x H x W) without any padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list/tuple of tensors.
        AssertionError: if ``normalize`` is set and ``range`` is given but is not a tuple.

    Example:
        See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_

    """
    if torch.is_tensor(tensor):
        pass
    elif isinstance(tensor, (list, tuple)) and all(torch.is_tensor(t) for t in tensor):
        # a sequence of images becomes a 4D mini-batch tensor
        tensor = torch.stack(tensor, dim=0)
    else:
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image C x H x W
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize:
        tensor = tensor.clone()  # avoid modifying the caller's tensor in-place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            """Clamp ``img`` to [low, high] and rescale it in-place to roughly [0, 1)."""
            img.clamp_(min=low, max=high)
            # the epsilon avoids division by zero for constant images
            img.add_(-low).div_(high - low + 1e-5)

        def norm_range(t, value_range):
            """Normalize ``t`` in-place by ``value_range`` or by its own min/max."""
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each:
            for t in tensor:  # each t is a view into the mini-batch, so edits stick
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    for k, img in enumerate(tensor):
        # images are laid out row-major: k -> (row y, column x)
        y, x = divmod(k, xmaps)
        grid.narrow(1, y * height + padding, height - padding) \
            .narrow(2, x * width + padding, width - padding) \
            .copy_(img)
    return grid


96
97
98
def save_image(tensor: Union[torch.Tensor, Sequence[torch.Tensor]], fp: Union[Text, pathlib.Path, BinaryIO],
               nrow: int = 8, padding: int = 2, normalize: bool = False, range: Optional[Tuple[int, int]] = None,
               scale_each: bool = False, pad_value: int = 0, format: Optional[str] = None) -> None:
    """Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
            Pixel values are expected in [0, 1]: they are scaled by 255 and
            clamped to [0, 255] before conversion to uint8.
        fp (string or file object): A filename or a file object
        nrow, padding, normalize, range, scale_each, pad_value: Forwarded
            unchanged to ``make_grid``; see its documentation.
        format(Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
    """
    # imported here rather than at module level, presumably so PIL is only
    # required when images are actually saved
    from PIL import Image
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)