# utils.py — helpers for tiling image tensors into a grid (make_grid) and
# writing them to an image file (save_image).
import torch
import math
3
# Keep a handle on the builtin ``range``: functions in this module take a
# parameter named ``range``, which shadows the builtin inside their bodies.
irange = range
4

5

6
def make_grid(tensor, nrow=8, padding=2,
              normalize=False, range=None, scale_each=False, pad_value=0):
    """Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list (or tuple) of images all of the same size. A single image
            of shape (C x H x W) or (H x W) is also accepted.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(ceil(B / nrow), nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`range`. Default: ``False``.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        Tensor: a 3D tensor of shape (C x H_grid x W_grid). Single-channel inputs
        are replicated to 3 channels. If the batch contains exactly one image,
        that image is returned directly (after optional normalization) without
        any padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list/tuple of tensors.

    Example:
        See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_

    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, (list, tuple)) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # if list/tuple of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, (list, tuple)):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image C x H x W
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    # Truthiness (not ``is True``) so values like numpy.bool_(True) are honored.
    if normalize:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # Clamp to [low, high], then map that interval onto [0, 1] in place.
            # The divisor is floored at 1e-5 (instead of offset by 1e-5) so that
            # ``high`` lands exactly on 1.0, while a constant image still avoids
            # a division by zero.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            # Use the caller-supplied (min, max) if given, else t's own extrema.
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    # Images are placed row-major: image k goes to row k // xmaps, column k % xmaps.
    for k, img in enumerate(tensor):
        y, x = divmod(k, xmaps)
        grid.narrow(1, y * height + padding, height - padding)\
            .narrow(2, x * width + padding, width - padding)\
            .copy_(img)
    return grid


91
92
def save_image(tensor, fp, nrow=8, padding=2,
               normalize=False, range=None, scale_each=False, pad_value=0, format=None):
    """Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (str or file object): A filename or a file object to write to.
        nrow, padding, normalize, range, scale_each, pad_value: Forwarded
            unchanged to ``make_grid``; see its documentation.
        format (str, optional): Format passed to ``PIL.Image.Image.save``.
            If omitted, the format to use is determined from the filename
            extension. If a file object was used instead of a filename, this
            parameter should always be used.

    Returns:
        None
    """
    # Pillow is imported here, so it is only needed when an image is actually saved.
    from PIL import Image
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # Scale to [0, 255] and add 0.5 so the truncating uint8 cast rounds to the
    # nearest integer; clamp so out-of-range values saturate instead of wrapping.
    # permute(1, 2, 0) reorders C x H x W into H x W x C for PIL.
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)