Commit 7b18556c authored by Francisco Massa's avatar Francisco Massa Committed by Soumith Chintala
Browse files

Add asserts to make_grid and avoid inplace modification (#241)

parent 8e375670
import torch
import torchvision.utils as utils
import unittest
class Tester(unittest.TestCase):
def test_make_grid_not_inplace(self):
t = torch.rand(5, 3, 10, 10)
t_clone = t.clone()
utils.make_grid(t, normalize=False)
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
utils.make_grid(t, normalize=True, scale_each=False)
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
utils.make_grid(t, normalize=True, scale_each=True)
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
def test_make_grid_raises_with_variable(self):
t = torch.autograd.Variable(torch.rand(3, 10, 10))
with self.assertRaises(TypeError):
utils.make_grid(t)
with self.assertRaises(TypeError):
utils.make_grid([t, t, t, t])
if __name__ == '__main__':
unittest.main()
......@@ -26,14 +26,13 @@ def make_grid(tensor, nrow=8, padding=2,
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
# if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list):
tensorlist = tensor
numImages = len(tensorlist)
size = torch.Size(torch.Size([numImages]) + tensorlist[0].size())
tensor = tensorlist[0].new(size)
for i in irange(numImages):
tensor[i].copy_(tensorlist[i])
tensor = torch.stack(tensor, dim=0)
if tensor.dim() == 2: # single image H x W
tensor = tensor.view(1, tensor.size(0), tensor.size(1))
......@@ -45,6 +44,7 @@ def make_grid(tensor, nrow=8, padding=2,
tensor = torch.cat((tensor, tensor, tensor), 1)
if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if range is not None:
assert isinstance(range, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment