utils.py 5.23 KB
Newer Older
1
from typing import Union, Optional, List, Tuple, Text, BinaryIO
2
3
import io
import pathlib
4
5
import torch
import math
6
irange = range
7

8

9
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
) -> torch.Tensor:
    """Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A 2D (H x W) or 3D
            (C x H x W) tensor is treated as a single image.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(ceil(B / nrow), nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`range`. Default: ``False``.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor. Only checked when ``normalize`` is True.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.
            (The annotation says ``int``; the value is passed straight to
            ``Tensor.new_full``.)

    Returns:
        Tensor: a ``C x H' x W'`` image holding the grid. Single-channel input is
        replicated to 3 channels. When the batch holds exactly one image, that
        image is returned as-is (``C x H x W``), without any padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors, or if
            ``normalize`` is True and ``range`` is given but is not a tuple.

    Example:
        See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_

    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``,
        # and TypeError matches the argument check at the top of this function.
        if range is not None and not isinstance(range, tuple):
            raise TypeError(
                "range has to be a tuple (min, max) if specified. min and max are numbers"
            )

        def norm_ip(img, low, high):
            # In-place: clamp to [low, high], then rescale to [0, 1].
            # The 1e-5 floor avoids dividing by zero for constant images.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, bounds):
            if bounds is not None:
                norm_ip(t, bounds[0], bounds[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell is one image plus `padding` pixels above and to the left; the
    # extra `+ padding` on the grid size closes the bottom and right borders.
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    # Images fill the grid row by row; divmod maps the flat batch index to its
    # (row, column) cell, so trailing cells of a partial last row keep pad_value.
    for k, image in enumerate(tensor):
        y, x = divmod(k, xmaps)
        # Tensor.copy_() is a valid method but seems to be missing from the stubs
        # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
        grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
            2, x * width + padding, width - padding
        ).copy_(image)
    return grid


103
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    format: Optional[str] = None,
) -> None:
    """Save a given Tensor into an image file.

    The input is first passed through ``make_grid``, then scaled by 255,
    converted to ``uint8`` on the CPU and written out with PIL.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        nrow, padding, normalize, range, scale_each, pad_value: Forwarded
            unchanged to ``make_grid``; see its documentation.
        format(Optional):  If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
    """
    from PIL import Image

    grid_options = {
        'nrow': nrow,
        'padding': padding,
        'pad_value': pad_value,
        'normalize': normalize,
        'range': range,
        'scale_each': scale_each,
    }
    grid = make_grid(tensor, **grid_options)

    # ``mul`` allocates a fresh tensor, so the in-place steps below never
    # touch ``grid`` itself.
    pixels = grid.mul(255)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    pixels.add_(0.5)
    pixels.clamp_(0, 255)
    # PIL expects H x W x C, while the grid is C x H x W.
    channels_last = pixels.permute(1, 2, 0)
    ndarr = channels_last.to('cpu', torch.uint8).numpy()

    image = Image.fromarray(ndarr)
    image.save(fp, format=format)