Unverified Commit a409c3fe authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixed to_image_tensor and to_image_pil (#6454)

* [proto] Fixed to_image_tensor and to_image_pil

* Fixed failing test
parent 6de7021e
......@@ -1071,7 +1071,7 @@ class TestToImagePIL:
if inpt_type in (features.BoundingBox, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt, copy=transform.copy)
fn.assert_called_once_with(inpt, mode=transform.mode)
class TestToPILImage:
......
......@@ -4,6 +4,7 @@ import math
import os
import numpy as np
import PIL.Image
import pytest
import torch.testing
import torchvision.prototype.transforms.functional as F
......@@ -1861,3 +1862,49 @@ def test_midlevel_normalize_output_type():
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(3, 32, 32)),
127 * np.ones((32, 32, 3), dtype="uint8"),
PIL.Image.new("RGB", (32, 32), 122),
],
)
@pytest.mark.parametrize("copy", [True, False])
def test_to_image_tensor(inpt, copy):
output = F.to_image_tensor(inpt, copy=copy)
assert isinstance(output, torch.Tensor)
assert np.asarray(inpt).sum() == output.sum().item()
if isinstance(inpt, PIL.Image.Image) and not copy:
# we can't check this option
# as PIL -> numpy is always copying
return
if isinstance(inpt, PIL.Image.Image):
inpt.putpixel((0, 0), 11)
else:
inpt[0, 0, 0] = 11
if copy:
assert output[0, 0, 0] != 11
else:
assert output[0, 0, 0] == 11
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
127 * np.ones((32, 32, 3), dtype="uint8"),
PIL.Image.new("RGB", (32, 32), 122),
],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_image_pil(inpt, mode):
output = F.to_image_pil(inpt, mode=mode)
assert isinstance(output, PIL.Image.Image)
assert np.asarray(inpt).sum() == np.asarray(output).sum()
from typing import Any, Dict
from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
......@@ -63,12 +63,12 @@ class ToImagePIL(Transform):
# Updated transformed types for ToImagePIL
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None:
def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__()
self.copy = copy
self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
return F.to_image_pil(inpt, copy=self.copy)
return F.to_image_pil(inpt, mode=self.mode)
else:
return inpt
import unittest.mock
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import PIL.Image
......@@ -27,20 +27,25 @@ def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tenso
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
if isinstance(image, np.ndarray):
image = torch.from_numpy(image)
if isinstance(image, torch.Tensor):
if copy:
return image.clone()
else:
return image
return _F.to_tensor(image)
return _F.pil_to_tensor(image)
def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image:
def to_image_pil(
image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], mode: Optional[str] = None
) -> PIL.Image.Image:
if isinstance(image, PIL.Image.Image):
if copy:
return image.copy()
if mode != image.mode:
return image.convert(mode)
else:
return image
return _F.to_pil_image(to_image_tensor(image, copy=False))
return _F.to_pil_image(image, mode=mode)
......@@ -173,7 +173,7 @@ def to_tensor(pic) -> Tensor:
return img
def pil_to_tensor(pic):
def pil_to_tensor(pic: Any) -> Tensor:
"""Convert a ``PIL Image`` to a tensor of the same type.
This function does not support torchscript.
......@@ -254,6 +254,7 @@ def to_pil_image(pic, mode=None):
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_pil_image)
if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment