Commit c4f4c73a authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

add range to make_grid (#99)

parent 989d52a0
...@@ -349,15 +349,30 @@ For example: ...@@ -349,15 +349,30 @@ For example:
Utils Utils
===== =====
make\_grid(tensor, nrow=8, padding=2) make\_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale\_each=False)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Given a 4D mini-batch Tensor of shape (B x C x H x W), makes a grid of Given a 4D mini-batch Tensor of shape (B x C x H x W),
images or a list of images all of the same size,
makes a grid of images
save\_image(tensor, filename, nrow=8, padding=2) normalize=True will shift the image to the range (0, 1),
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ by subtracting the minimum and dividing by the maximum pixel value.
if range=(min, max) where min and max are numbers, then these numbers are used to
normalize the image.
scale_each=True will scale each image in the batch of images separately rather than
computing the (min, max) over all images.
[Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
save\_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale\_each=False)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Saves a given Tensor into an image file. Saves a given Tensor into an image file.
If given a mini-batch tensor, will save the tensor as a grid of images. If given a mini-batch tensor, will save the tensor as a grid of images.
All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
more details
import torch import torch
import math import math
irange = range
def make_grid(tensor, nrow=8, padding=2): def make_grid(tensor, nrow=8, padding=2,
normalize=False, range=None, scale_each=False):
""" """
Given a 4D mini-batch Tensor of shape (B x C x H x W), Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size, or a list of images all of the same size,
makes a grid of images makes a grid of images
normalize=True will shift the image to the range (0, 1),
by subtracting the minimum and dividing by the maximum pixel value.
if range=(min, max) where min and max are numbers, then these numbers are used to
normalize the image.
scale_each=True will scale each image in the batch of images separately rather than
computing the (min, max) over all images.
[Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
""" """
tensorlist = None # if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list): if isinstance(tensor, list):
tensorlist = tensor tensorlist = tensor
numImages = len(tensorlist) numImages = len(tensorlist)
size = torch.Size(torch.Size([long(numImages)]) + tensorlist[0].size()) size = torch.Size(torch.Size([long(numImages)]) + tensorlist[0].size())
tensor = tensorlist[0].new(size) tensor = tensorlist[0].new(size)
for i in range(numImages): for i in irange(numImages):
tensor[i].copy_(tensorlist[i]) tensor[i].copy_(tensorlist[i])
if tensor.dim() == 2: # single image H x W if tensor.dim() == 2: # single image H x W
tensor = tensor.view(1, tensor.size(0), tensor.size(1)) tensor = tensor.view(1, tensor.size(0), tensor.size(1))
if tensor.dim() == 3: # single image if tensor.dim() == 3: # single image
if tensor.size(0) == 1: if tensor.size(0) == 1: # if single-channel, convert to 3-channel
tensor = torch.cat((tensor, tensor, tensor), 0) tensor = torch.cat((tensor, tensor, tensor), 0)
return tensor return tensor
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1) tensor = torch.cat((tensor, tensor, tensor), 1)
if normalize is True:
if range is not None:
assert isinstance(range, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"
def norm_ip(img, min, max):
img.clamp_(min=min, max=max)
img.add_(-min).div_(max - min)
def norm_range(t, range):
if range is not None:
norm_ip(t, range[0], range[1])
else:
norm_ip(t, t.min(), t.max())
if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, range)
else:
norm_range(tensor, range)
# make the mini-batch of images into a grid # make the mini-batch of images into a grid
nmaps = tensor.size(0) nmaps = tensor.size(0)
xmaps = min(nrow, nmaps) xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps)) ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
grid = tensor.new(3, height * ymaps, width * xmaps).fill_(tensor.max()) grid = tensor.new(3, height * ymaps, width * xmaps).fill_(0)
k = 0 k = 0
for y in range(ymaps): for y in irange(ymaps):
for x in range(xmaps): for x in irange(xmaps):
if k >= nmaps: if k >= nmaps:
break break
grid.narrow(1, y * height + 1 + padding // 2, height - padding)\ grid.narrow(1, y * height + 1 + padding // 2, height - padding)\
...@@ -42,14 +78,18 @@ def make_grid(tensor, nrow=8, padding=2): ...@@ -42,14 +78,18 @@ def make_grid(tensor, nrow=8, padding=2):
return grid return grid
def save_image(tensor, filename, nrow=8, padding=2): def save_image(tensor, filename, nrow=8, padding=2,
normalize=False, range=None, scale_each=False):
""" """
Saves a given Tensor into an image file. Saves a given Tensor into an image file.
If given a mini-batch tensor, will save the tensor as a grid of images. If given a mini-batch tensor, will save the tensor as a grid of images by calling `make_grid`.
All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
more details
""" """
from PIL import Image from PIL import Image
tensor = tensor.cpu() tensor = tensor.cpu()
grid = make_grid(tensor, nrow=nrow, padding=padding) grid = make_grid(tensor, nrow=nrow, padding=padding,
normalize=normalize, range=range, scale_each=scale_each)
ndarr = grid.mul(255).byte().transpose(0, 2).transpose(0, 1).numpy() ndarr = grid.mul(255).byte().transpose(0, 2).transpose(0, 1).numpy()
im = Image.fromarray(ndarr) im = Image.fromarray(ndarr)
im.save(filename) im.save(filename)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment