Commit 19a27a21 authored by Adam Lerer's avatar Adam Lerer
Browse files

fix bugs in save_image

parent 8560d58b
...@@ -16,9 +16,10 @@ def make_grid(tensor, nrow=8, padding=2): ...@@ -16,9 +16,10 @@ def make_grid(tensor, nrow=8, padding=2):
for i in range(numImages): for i in range(numImages):
tensor[i].copy_(tensorlist[i]) tensor[i].copy_(tensorlist[i])
if tensor.dim() == 2: # single image H x W if tensor.dim() == 2: # single image H x W
tensor = torch.view(1, tensor.size(0), tensor.size(1)) tensor = tensor.view(1, tensor.size(0), tensor.size(1))
tensor = torch.cat((tensor, tensor, tensor), 0)
if tensor.dim() == 3: # single image if tensor.dim() == 3: # single image
if tensor.size(0) == 1:
tensor = torch.cat((tensor, tensor, tensor), 0)
return tensor return tensor
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1) tensor = torch.cat((tensor, tensor, tensor), 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment