utils.py 3.7 KB
Newer Older
1
2
import torch
import math
3
irange = range
4

5

6
def make_grid(tensor, nrow=8, padding=2,
              normalize=False, range=None, scale_each=False, pad_value=0):
    """
    Arrange a mini-batch of images into a single grid image.

    Given a 4D mini-batch Tensor of shape (B x C x H x W),
    or a list of images all of the same size,
    makes a grid of images with at most ``nrow`` images per row
    (so the grid has ceil(B / nrow) rows).

    A single image, given as a 2D (H x W) or 3D (C x H x W) Tensor, is
    returned without any grid or padding around it. Single-channel input
    is always expanded to 3 channels.

    padding=<int> is the width in pixels of the border drawn around every
    image (between neighbours and along the outer edges of the grid).

    normalize=True will shift the image to the range (0, 1),
    by subtracting the minimum and dividing by the maximum pixel value.
    Normalization works on a copy, so the tensor passed in is never modified.

    if range=(min, max) where min and max are numbers, then these numbers are used to
    normalize the image.

    scale_each=True will scale each image in the batch of images separately rather than
    computing the (min, max) over all images.

    pad_value=<float> sets the value for the padded pixels.

    Returns the grid as a (C x grid_height x grid_width) Tensor, or the
    (possibly channel-expanded / normalized) image itself for single images.

    [Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
    """
    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        images = tensor
        first = images[0]
        tensor = first.new(len(images), *first.size())
        for index, image in enumerate(images):
            tensor[index].copy_(image)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.view(1, tensor.size(0), tensor.size(1))

    single_image = tensor.dim() == 3
    if single_image:
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
    elif tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        # The helpers below work in place; clone so the caller's data survives.
        tensor = tensor.clone()

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)
            # Floor the divisor so a constant image maps to 0 instead of NaN.
            img.add_(-low).div_(max(high - low, 1e-5))

        def norm_range(t):
            if range is not None:
                norm_ip(t, range[0], range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True and not single_image:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t)
        else:
            # A single image is scaled as a whole, never channel by channel.
            norm_range(tensor)

    if single_image:
        return tensor

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell holds one image plus its leading (top / left) border.
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    # The extra `padding` closes the border along the bottom and right edges.
    grid = tensor.new(tensor.size(1),
                      height * ymaps + padding,
                      width * xmaps + padding).fill_(pad_value)
    for k, image in enumerate(tensor):
        y, x = divmod(k, xmaps)  # row-major placement of image k
        grid.narrow(1, y * height + padding, height - padding)\
            .narrow(2, x * width + padding, width - padding)\
            .copy_(image)
    return grid


83
def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, range=None, scale_each=False, pad_value=0):
    """
    Saves a given Tensor into an image file.

    If given a mini-batch tensor (or a list of images), will save the tensor as a grid of
    images by calling `make_grid`.
    All options after `filename` are passed through to `make_grid`. Refer to its documentation for
    more details.

    The grid values are multiplied by 255, clamped to [0, 255] and converted
    to 8-bit before being handed to PIL, which writes the file to `filename`.
    """
    from PIL import Image
    # `make_grid` accepts a list of images, but a list has no `.cpu()`;
    # move each element individually so list input does not crash here.
    if isinstance(tensor, list):
        tensor = [image.cpu() for image in tensor]
    else:
        tensor = tensor.cpu()
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # PIL expects H x W x C byte data, hence the permute from C x H x W.
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
    im = Image.fromarray(ndarr)
    im.save(filename)