Unverified Commit 9c51928b authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Add tensor2imgs in misc (#374)

* add tensor2imgs in misc

* rename

* fixed import

* raise error
parent 1ebd7ea6
...@@ -6,6 +6,7 @@ from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple, ...@@ -6,6 +6,7 @@ from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple,
imrescale, imresize, imresize_like, imrotate, imrescale, imresize, imresize_like, imrotate,
rescale_size) rescale_size)
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs
from .photometric import (imdenormalize, iminvert, imnormalize, imnormalize_, from .photometric import (imdenormalize, iminvert, imnormalize, imnormalize_,
posterize, solarize) posterize, solarize)
...@@ -16,5 +17,5 @@ __all__ = [ ...@@ -16,5 +17,5 @@ __all__ = [
'impad', 'impad_to_multiple', 'imrotate', 'imfrombytes', 'imread', 'impad', 'impad_to_multiple', 'imrotate', 'imfrombytes', 'imread',
'imwrite', 'supported_backends', 'use_backend', 'imdenormalize', 'imwrite', 'supported_backends', 'use_backend', 'imdenormalize',
'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize',
'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr' 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', 'tensor2imgs'
] ]
import numpy as np
import mmcv
try:
import torch
except ImportError:
torch = None
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
"""Convert tensor to 3-channel images
Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W).
mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
std (tuple[float], optional): Standard deviation of images.
Defaults to (1, 1, 1).
to_rgb (bool, optional): Whether the tensor was converted to RGB
format in the first place. If so, convert it back to BGR.
Defaults to True.
Returns:
list[np.ndarray]: A list that contains multiple images.
"""
if torch is None:
raise RuntimeError('pytorch is not installed')
assert torch.is_tensor(tensor) and tensor.ndim == 4
assert len(mean) == 3
assert len(std) == 3
num_imgs = tensor.size(0)
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
imgs = []
for img_id in range(num_imgs):
img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
img = mmcv.imdenormalize(
img, mean, std, to_bgr=to_rgb).astype(np.uint8)
imgs.append(np.ascontiguousarray(img))
return imgs
# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import pytest
import torch
from numpy.testing import assert_array_equal
import mmcv
def test_tensor2imgs():
# test tensor obj
with pytest.raises(AssertionError):
tensor = np.random.rand(2, 3, 3)
mmcv.tensor2imgs(tensor)
# test tensor ndim
with pytest.raises(AssertionError):
tensor = torch.randn(2, 3, 3)
mmcv.tensor2imgs(tensor)
# test mean length
with pytest.raises(AssertionError):
tensor = torch.randn(2, 3, 5, 5)
mmcv.tensor2imgs(tensor, mean=(1, ))
# test std length
with pytest.raises(AssertionError):
tensor = torch.randn(2, 3, 5, 5)
mmcv.tensor2imgs(tensor, std=(1, ))
# test rgb=True
tensor = torch.randn(2, 3, 5, 5)
gts = [
t.cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
for t in tensor.flip(1)
]
outputs = mmcv.tensor2imgs(tensor, to_rgb=True)
for gt, output in zip(gts, outputs):
assert_array_equal(gt, output)
# test rgb=False
tensor = torch.randn(2, 3, 5, 5)
gts = [t.cpu().numpy().transpose(1, 2, 0).astype(np.uint8) for t in tensor]
outputs = mmcv.tensor2imgs(tensor, to_rgb=False)
for gt, output in zip(gts, outputs):
assert_array_equal(gt, output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment