# utils.py: helpers for tiling image tensors into a grid and saving them to disk.
import torch
import math


def make_grid(tensor, nrow=8, padding=2):
    """
    Given a 4D mini-batch Tensor of shape (B x C x H x W),
    or a list of images all of the same size,
    makes a grid of images.

    Args:
        tensor: a 4D batch (B x C x H x W), a single 3D image (C x H x W),
            a single 2D image (H x W), or a list of equally-sized image
            tensors, which are stacked along a new leading dimension.
        nrow: maximum number of images per grid row.
        padding: number of pixels separating neighbouring images. Each grid
            cell is (H + padding) x (W + padding); ``padding // 2`` pixels go
            before the image and the remainder after it.

    Returns:
        For 2D/3D input: the single image as a (C x H x W) tensor. No grid is
        built, and 1-channel input is replicated to 3 channels.
        For 4D input: a (C x ymaps*(H+padding) x xmaps*(W+padding)) tensor,
        where 1-channel batches are first replicated to 3 channels. The
        background is filled with the batch's maximum value.
    """
    if isinstance(tensor, list):
        # Stack the equally-sized images along a new batch dimension.
        tensor = torch.stack(tensor, 0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.view(1, tensor.size(0), tensor.size(1))
    if tensor.dim() == 3:  # single image C x H x W
        if tensor.size(0) == 1:
            tensor = torch.cat((tensor, tensor, tensor), 0)
        return tensor
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    # Use the batch's own channel count instead of a hard-coded 3, so batches
    # with e.g. 4 channels can be tiled too.
    grid = tensor.new(tensor.size(1), height * ymaps, width * xmaps).fill_(tensor.max())
    # 0-based offset of each image inside its cell. The former "+ 1" was a
    # 1-based-indexing leftover that overran the grid when padding == 0.
    offset = padding // 2
    for k in range(nmaps):
        y, x = divmod(k, xmaps)  # row-major placement
        grid.narrow(1, y * height + offset, height - padding) \
            .narrow(2, x * width + offset, width - padding) \
            .copy_(tensor[k])
    return grid


def save_image(tensor, filename, nrow=8, padding=2):
    """
    Saves a given Tensor into an image file.
    If given a mini-batch tensor, will save the tensor as a grid of images.

    Args:
        tensor: image tensor(s) accepted by ``make_grid``. Values are scaled
            by 255, so they are presumably in [0, 1]; anything outside that
            range is clamped.
        filename: destination path passed to PIL's ``Image.save``.
        nrow: forwarded to ``make_grid``.
        padding: forwarded to ``make_grid``.
    """
    from PIL import Image
    tensor = tensor.cpu()
    grid = make_grid(tensor, nrow=nrow, padding=padding)
    # Clamp before the uint8 cast: out-of-range values would otherwise
    # overflow it (typically wrapping, e.g. 1.01 * 255 -> 257 -> 1).
    # permute(1, 2, 0) turns C x H x W into the H x W x C layout PIL expects.
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
    im = Image.fromarray(ndarr)
    im.save(filename)