You need to sign in or sign up before continuing.
Commit ef67fd92 authored by Surgan Jandial's avatar Surgan Jandial Committed by Francisco Massa
Browse files

Adding File object option to utils.save_image (#1301)

* format param added

* lint fixes

* lint fixes

* lint fixes
parent 367e8514
...@@ -4,6 +4,9 @@ import tempfile ...@@ -4,6 +4,9 @@ import tempfile
import torch import torch
import torchvision.utils as utils import torchvision.utils as utils
import unittest import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -52,6 +55,26 @@ class Tester(unittest.TestCase): ...@@ -52,6 +55,26 @@ class Tester(unittest.TestCase):
utils.save_image(t, f.name) utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save' assert os.path.exists(f.name), 'The pixel image is not present after save'
def test_save_image_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Image not stored in file object'
def test_save_image_single_pixel_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object'
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -88,13 +88,16 @@ def make_grid(tensor, nrow=8, padding=2, ...@@ -88,13 +88,16 @@ def make_grid(tensor, nrow=8, padding=2,
return grid return grid
def save_image(tensor, filename, nrow=8, padding=2, def save_image(tensor, fp, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0): normalize=False, range=None, scale_each=False, pad_value=0, format=None):
"""Save a given Tensor into an image file. """Save a given Tensor into an image file.
Args: Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling ``make_grid``. saves the tensor as a grid of images by calling ``make_grid``.
fp - A filename(string) or file object
format(Optional): If omitted, the format to use is determined from the filename extension.
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``. **kwargs: Other arguments are documented in ``make_grid``.
""" """
from PIL import Image from PIL import Image
...@@ -103,4 +106,4 @@ def save_image(tensor, filename, nrow=8, padding=2, ...@@ -103,4 +106,4 @@ def save_image(tensor, filename, nrow=8, padding=2,
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr) im = Image.fromarray(ndarr)
im.save(filename) im.save(fp, format=format)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment