Unverified Commit a409c3fe authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixed to_image_tensor and to_image_pil (#6454)

* [proto] Fixed to_image_tensor and to_image_pil

* Fixed failing test
parent 6de7021e
...@@ -1071,7 +1071,7 @@ class TestToImagePIL: ...@@ -1071,7 +1071,7 @@ class TestToImagePIL:
if inpt_type in (features.BoundingBox, str, int): if inpt_type in (features.BoundingBox, str, int):
assert fn.call_count == 0 assert fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt, copy=transform.copy) fn.assert_called_once_with(inpt, mode=transform.mode)
class TestToPILImage: class TestToPILImage:
......
...@@ -4,6 +4,7 @@ import math ...@@ -4,6 +4,7 @@ import math
import os import os
import numpy as np import numpy as np
import PIL.Image
import pytest import pytest
import torch.testing import torch.testing
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
...@@ -1861,3 +1862,49 @@ def test_midlevel_normalize_output_type(): ...@@ -1861,3 +1862,49 @@ def test_midlevel_normalize_output_type():
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)) output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32)),
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
@pytest.mark.parametrize("copy", [True, False])
def test_to_image_tensor(inpt, copy):
    """to_image_tensor yields a Tensor with the same pixel content; `copy` controls memory sharing."""
    output = F.to_image_tensor(inpt, copy=copy)
    assert isinstance(output, torch.Tensor)
    assert output.sum().item() == np.asarray(inpt).sum()

    skip_aliasing_check = isinstance(inpt, PIL.Image.Image) and not copy
    if skip_aliasing_check:
        # we can't check this option
        # as PIL -> numpy is always copying
        return

    # Mutate one pixel of the input and see whether the output reflects it.
    if isinstance(inpt, PIL.Image.Image):
        inpt.putpixel((0, 0), 11)
    else:
        inpt[0, 0, 0] = 11

    if copy:
        # A copy must be insulated from later input mutations.
        assert output[0, 0, 0] != 11
    else:
        # Without copy the storage is shared, so the mutation shows through.
        assert output[0, 0, 0] == 11
@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_image_pil(inpt, mode):
    """to_image_pil produces a PIL image whose pixel content matches the input."""
    result = F.to_image_pil(inpt, mode=mode)
    assert isinstance(result, PIL.Image.Image)
    # Compare aggregate pixel values instead of exact layout, since
    # channel ordering differs between tensor (CHW) and PIL (HWC) views.
    assert np.asarray(result).sum() == np.asarray(inpt).sum()
from typing import Any, Dict from typing import Any, Dict, Optional
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -63,12 +63,12 @@ class ToImagePIL(Transform): ...@@ -63,12 +63,12 @@ class ToImagePIL(Transform):
# Updated transformed types for ToImagePIL # Updated transformed types for ToImagePIL
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) _transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None: def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__() super().__init__()
self.copy = copy self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt): if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
return F.to_image_pil(inpt, copy=self.copy) return F.to_image_pil(inpt, mode=self.mode)
else: else:
return inpt return inpt
import unittest.mock import unittest.mock
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -27,20 +27,25 @@ def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tenso ...@@ -27,20 +27,25 @@ def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tenso
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor: def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
if isinstance(image, np.ndarray):
image = torch.from_numpy(image)
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if copy: if copy:
return image.clone() return image.clone()
else: else:
return image return image
return _F.to_tensor(image) return _F.pil_to_tensor(image)
def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image: def to_image_pil(
image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], mode: Optional[str] = None
) -> PIL.Image.Image:
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
if copy: if mode != image.mode:
return image.copy() return image.convert(mode)
else: else:
return image return image
return _F.to_pil_image(to_image_tensor(image, copy=False)) return _F.to_pil_image(image, mode=mode)
...@@ -173,7 +173,7 @@ def to_tensor(pic) -> Tensor: ...@@ -173,7 +173,7 @@ def to_tensor(pic) -> Tensor:
return img return img
def pil_to_tensor(pic): def pil_to_tensor(pic: Any) -> Tensor:
"""Convert a ``PIL Image`` to a tensor of the same type. """Convert a ``PIL Image`` to a tensor of the same type.
This function does not support torchscript. This function does not support torchscript.
...@@ -254,6 +254,7 @@ def to_pil_image(pic, mode=None): ...@@ -254,6 +254,7 @@ def to_pil_image(pic, mode=None):
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_pil_image) _log_api_usage_once(to_pil_image)
if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.") raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment