Commit ff3f738e authored by Yashas Annadani's avatar Yashas Annadani Committed by Alykhan Tejani
Browse files

Normalize single images to make_grid (#360)

parent aafaa2a1
......@@ -39,7 +39,8 @@ def make_grid(tensor, nrow=8, padding=2,
if tensor.dim() == 3: # single image
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
tensor = torch.cat((tensor, tensor, tensor), 0)
return tensor
tensor = tensor.view(1, tensor.size(0), tensor.size(1), tensor.size(2))
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1)
......@@ -65,6 +66,9 @@ def make_grid(tensor, nrow=8, padding=2,
else:
norm_range(tensor, range)
if tensor.size(0) == 1:
return tensor.squeeze()
# make the mini-batch of images into a grid
nmaps = tensor.size(0)
xmaps = min(nrow, nmaps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment