Unverified Commit 9f6a189e authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[proto] Use the proper `_transformed_types` in all Transforms and eliminate...

[proto] Use the proper `_transformed_types` in all Transforms and eliminate unnecessary dispatching (#6494)

* Update types in deprecated transforms.

* Update types in type conversion transforms.

* Fixing types in meta transforms.

* More changes on type conversion.

* Bug fix.

* Fix types

* Remove unnecessary conversions.

* Remove unnecessary import.

* Fixing tests

* Remove copy support from `to_image_tensor`

* restore test param

* Fix further tests
parent 368d1c6f
......@@ -1042,10 +1042,10 @@ class TestToImageTensor:
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImageTensor()
transform(inpt)
if inpt_type in (features.BoundingBox, str, int):
if inpt_type in (features.BoundingBox, features.Image, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt, copy=transform.copy)
fn.assert_called_once_with(inpt)
class TestToImagePIL:
......@@ -1059,7 +1059,7 @@ class TestToImagePIL:
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImagePIL()
transform(inpt)
if inpt_type in (features.BoundingBox, str, int):
if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt, mode=transform.mode)
......
......@@ -1867,31 +1867,23 @@ def test_midlevel_normalize_output_type():
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(3, 32, 32)),
127 * np.ones((32, 32, 3), dtype="uint8"),
PIL.Image.new("RGB", (32, 32), 122),
],
)
@pytest.mark.parametrize("copy", [True, False])
def test_to_image_tensor(inpt, copy):
output = F.to_image_tensor(inpt, copy=copy)
def test_to_image_tensor(inpt):
output = F.to_image_tensor(inpt)
assert isinstance(output, torch.Tensor)
assert np.asarray(inpt).sum() == output.sum().item()
if isinstance(inpt, PIL.Image.Image) and not copy:
if isinstance(inpt, PIL.Image.Image):
# we can't check this option
# as PIL -> numpy is always copying
return
if isinstance(inpt, PIL.Image.Image):
inpt.putpixel((0, 0), 11)
else:
inpt[0, 0, 0] = 11
if copy:
assert output[0, 0, 0] != 11
else:
assert output[0, 0, 0] == 11
inpt[0, 0, 0] = 11
assert output[0, 0, 0] == 11
@pytest.mark.parametrize(
......@@ -1899,7 +1891,6 @@ def test_to_image_tensor(inpt, copy):
[
torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
127 * np.ones((32, 32, 3), dtype="uint8"),
PIL.Image.new("RGB", (32, 32), 122),
],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
......
......@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
import torch
import torchvision.prototype.transforms.functional as F
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
......@@ -15,9 +16,7 @@ from ._utils import is_simple_tensor
class ToTensor(Transform):
# Updated transformed types for ToTensor
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
_transformed_types = (PIL.Image.Image, np.ndarray)
def __init__(self) -> None:
warnings.warn(
......@@ -26,14 +25,13 @@ class ToTensor(Transform):
)
super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(inpt)
else:
return inpt
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
return _F.to_tensor(inpt)
class PILToTensor(Transform):
_transformed_types = (PIL.Image.Image,)
def __init__(self) -> None:
warnings.warn(
"The transform `PILToTensor()` is deprecated and will be removed in a future release. "
......@@ -41,17 +39,12 @@ class PILToTensor(Transform):
)
super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, PIL.Image.Image):
return _F.pil_to_tensor(inpt)
else:
return inpt
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
return _F.pil_to_tensor(inpt)
class ToPILImage(Transform):
# Updated transformed types for ToPILImage
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn(
......@@ -61,11 +54,8 @@ class ToPILImage(Transform):
super().__init__()
self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
return _F.to_pil_image(inpt, mode=self.mode)
else:
return inpt
def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image:
return _F.to_pil_image(inpt, mode=self.mode)
class Grayscale(Transform):
......
......@@ -11,6 +11,8 @@ from ._utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform):
_transformed_types = (features.BoundingBox,)
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
......@@ -18,30 +20,23 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.BoundingBox):
output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
return features.BoundingBox.new_like(inpt, output, format=params["format"])
else:
return inpt
output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
return features.BoundingBox.new_like(inpt, output, format=params["format"])
class ConvertImageDtype(Transform):
_transformed_types = (is_simple_tensor, features.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image):
output = convert_image_dtype(inpt, dtype=self.dtype)
return features.Image.new_like(inpt, output, dtype=self.dtype)
elif is_simple_tensor(inpt):
return convert_image_dtype(inpt, dtype=self.dtype)
else:
return inpt
output = convert_image_dtype(inpt, dtype=self.dtype)
return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
class ConvertColorSpace(Transform):
# F.convert_color_space does NOT handle `_Feature`'s in general
_transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
def __init__(
......
......@@ -11,12 +11,11 @@ from ._utils import is_simple_tensor
class DecodeImage(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.EncodedImage):
output = F.decode_image_with_pil(inpt)
return features.Image(output)
else:
return inpt
_transformed_types = (features.EncodedImage,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
output = F.decode_image_with_pil(inpt)
return features.Image(output)
class LabelToOneHot(Transform):
......@@ -41,33 +40,19 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
# Updated transformed types for ToImageTensor
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
output = F.to_image_tensor(inpt, copy=self.copy)
return features.Image(output)
else:
return inpt
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
output = F.to_image_tensor(inpt)
return features.Image(output)
class ToImagePIL(Transform):
# Updated transformed types for ToImagePIL
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
return F.to_image_pil(inpt, mode=self.mode)
else:
return inpt
def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode)
import unittest.mock
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
import numpy as np
import PIL.Image
......@@ -21,26 +21,11 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
if isinstance(image, np.ndarray):
image = torch.from_numpy(image)
if isinstance(image, torch.Tensor):
if copy:
return image.clone()
else:
return image
return torch.from_numpy(image)
return _F.pil_to_tensor(image)
def to_image_pil(
image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], mode: Optional[str] = None
) -> PIL.Image.Image:
if isinstance(image, PIL.Image.Image):
if mode != image.mode:
return image.convert(mode)
else:
return image
return _F.to_pil_image(image, mode=mode)
to_image_pil = _F.to_pil_image
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment