Commit ef67fd92 authored by Surgan Jandial's avatar Surgan Jandial Committed by Francisco Massa
Browse files

Adding File object option to utils.save_image (#1301)

* format param added

* lint fixes

* lint fixes

* lint fixes
parent 367e8514
...@@ -4,6 +4,9 @@ import tempfile ...@@ -4,6 +4,9 @@ import tempfile
import torch import torch
import torchvision.utils as utils import torchvision.utils as utils
import unittest import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -52,6 +55,26 @@ class Tester(unittest.TestCase): ...@@ -52,6 +55,26 @@ class Tester(unittest.TestCase):
utils.save_image(t, f.name) utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save' assert os.path.exists(f.name), 'The pixel image is not present after save'
def test_save_image_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Image not stored in file object'
def test_save_image_single_pixel_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object'
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -88,13 +88,16 @@ def make_grid(tensor, nrow=8, padding=2, ...@@ -88,13 +88,16 @@ def make_grid(tensor, nrow=8, padding=2,
return grid return grid
def save_image(tensor, filename, nrow=8, padding=2, def save_image(tensor, fp, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0): normalize=False, range=None, scale_each=False, pad_value=0, format=None):
"""Save a given Tensor into an image file. """Save a given Tensor into an image file.
Args: Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling ``make_grid``. saves the tensor as a grid of images by calling ``make_grid``.
fp - A filename(string) or file object
format(Optional): If omitted, the format to use is determined from the filename extension.
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``. **kwargs: Other arguments are documented in ``make_grid``.
""" """
from PIL import Image from PIL import Image
...@@ -103,4 +106,4 @@ def save_image(tensor, filename, nrow=8, padding=2, ...@@ -103,4 +106,4 @@ def save_image(tensor, filename, nrow=8, padding=2,
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr) im = Image.fromarray(ndarr)
im.save(filename) im.save(fp, format=format)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment