Unverified Commit 690a77fa authored by Hongbin Sun's avatar Hongbin Sun Committed by GitHub
Browse files

[Feature]: Support tensor2grayimgs (#1595)

* support tensor2grayimgs

* give default mean and std according to the input channel

* update docstring

* update

* fix bug
parent ac92a111
...@@ -9,18 +9,21 @@ except ImportError: ...@@ -9,18 +9,21 @@ except ImportError:
torch = None torch = None
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): def tensor2imgs(tensor, mean=None, std=None, to_rgb=True):
"""Convert tensor to 3-channel images. """Convert tensor to 3-channel images or 1-channel gray images.
Args: Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape ( tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W). N, C, H, W). :math:`C` can be either 3 or 1.
mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0). mean (tuple[float], optional): Mean of images. If None,
std (tuple[float], optional): Standard deviation of images. (0, 0, 0) will be used for tensor with 3-channel,
Defaults to (1, 1, 1). while (0, ) for tensor with 1-channel. Defaults to None.
std (tuple[float], optional): Standard deviation of images. If None,
(1, 1, 1) will be used for tensor with 3-channel,
while (1, ) for tensor with 1-channel. Defaults to None.
to_rgb (bool, optional): Whether the tensor was converted to RGB to_rgb (bool, optional): Whether the tensor was converted to RGB
format in the first place. If so, convert it back to BGR. format in the first place. If so, convert it back to BGR.
Defaults to True. For the tensor with 1 channel, it must be False. Defaults to True.
Returns: Returns:
list[np.ndarray]: A list that contains multiple images. list[np.ndarray]: A list that contains multiple images.
...@@ -29,8 +32,14 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): ...@@ -29,8 +32,14 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
if torch is None: if torch is None:
raise RuntimeError('pytorch is not installed') raise RuntimeError('pytorch is not installed')
assert torch.is_tensor(tensor) and tensor.ndim == 4 assert torch.is_tensor(tensor) and tensor.ndim == 4
assert len(mean) == 3 channels = tensor.size(1)
assert len(std) == 3 assert channels in [1, 3]
if mean is None:
mean = (0, ) * channels
if std is None:
std = (1, ) * channels
assert (channels == len(mean) == len(std) == 3) or \
(channels == len(mean) == len(std) == 1 and not to_rgb)
num_imgs = tensor.size(0) num_imgs = tensor.size(0)
mean = np.array(mean, dtype=np.float32) mean = np.array(mean, dtype=np.float32)
......
...@@ -24,15 +24,29 @@ def test_tensor2imgs(): ...@@ -24,15 +24,29 @@ def test_tensor2imgs():
tensor = torch.randn(2, 3, 3) tensor = torch.randn(2, 3, 3)
mmcv.tensor2imgs(tensor) mmcv.tensor2imgs(tensor)
# test tensor dim-1
with pytest.raises(AssertionError):
tensor = torch.randn(2, 4, 3, 3)
mmcv.tensor2imgs(tensor)
# test mean length # test mean length
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
tensor = torch.randn(2, 3, 5, 5) tensor = torch.randn(2, 3, 5, 5)
mmcv.tensor2imgs(tensor, mean=(1, )) mmcv.tensor2imgs(tensor, mean=(1, ))
tensor = torch.randn(2, 1, 5, 5)
mmcv.tensor2imgs(tensor, mean=(0, 0, 0))
# test std length # test std length
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
tensor = torch.randn(2, 3, 5, 5) tensor = torch.randn(2, 3, 5, 5)
mmcv.tensor2imgs(tensor, std=(1, )) mmcv.tensor2imgs(tensor, std=(1, ))
tensor = torch.randn(2, 1, 5, 5)
mmcv.tensor2imgs(tensor, std=(1, 1, 1))
# test to_rgb
with pytest.raises(AssertionError):
tensor = torch.randn(2, 1, 5, 5)
mmcv.tensor2imgs(tensor, mean=(0, ), std=(1, ), to_rgb=True)
# test rgb=True # test rgb=True
tensor = torch.randn(2, 3, 5, 5) tensor = torch.randn(2, 3, 5, 5)
...@@ -50,3 +64,10 @@ def test_tensor2imgs(): ...@@ -50,3 +64,10 @@ def test_tensor2imgs():
outputs = mmcv.tensor2imgs(tensor, to_rgb=False) outputs = mmcv.tensor2imgs(tensor, to_rgb=False)
for gt, output in zip(gts, outputs): for gt, output in zip(gts, outputs):
assert_array_equal(gt, output) assert_array_equal(gt, output)
# test tensor channel 1 and rgb=False
tensor = torch.randn(2, 1, 5, 5)
gts = [t.squeeze(0).cpu().numpy().astype(np.uint8) for t in tensor]
outputs = mmcv.tensor2imgs(tensor, to_rgb=False)
for gt, output in zip(gts, outputs):
assert_array_equal(gt, output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment