Unverified commit 9f6a189e, authored by Vasilis Vryniotis, committed by GitHub
Browse files

[proto] Use the proper `_transformed_types` in all Transforms and eliminate unnecessary dispatching

[proto] Use the proper `_transformed_types` in all Transforms and eliminate unnecessary dispatching (#6494)

* Update types in deprecated transforms.

* Update types in type conversion transforms.

* Fixing types in meta transforms.

* More changes on type conversion.

* Bug fix.

* Fix types

* Remove unnecessary conversions.

* Remove unnecessary import.

* Fixing tests

* Remove copy support from `to_image_tensor`

* restore test param

* Fix further tests
parent 368d1c6f
...@@ -1042,10 +1042,10 @@ class TestToImageTensor: ...@@ -1042,10 +1042,10 @@ class TestToImageTensor:
inpt = mocker.MagicMock(spec=inpt_type) inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImageTensor() transform = transforms.ToImageTensor()
transform(inpt) transform(inpt)
if inpt_type in (features.BoundingBox, str, int): if inpt_type in (features.BoundingBox, features.Image, str, int):
assert fn.call_count == 0 assert fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt, copy=transform.copy) fn.assert_called_once_with(inpt)
class TestToImagePIL: class TestToImagePIL:
...@@ -1059,7 +1059,7 @@ class TestToImagePIL: ...@@ -1059,7 +1059,7 @@ class TestToImagePIL:
inpt = mocker.MagicMock(spec=inpt_type) inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImagePIL() transform = transforms.ToImagePIL()
transform(inpt) transform(inpt)
if inpt_type in (features.BoundingBox, str, int): if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int):
assert fn.call_count == 0 assert fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt, mode=transform.mode) fn.assert_called_once_with(inpt, mode=transform.mode)
......
...@@ -1867,31 +1867,23 @@ def test_midlevel_normalize_output_type(): ...@@ -1867,31 +1867,23 @@ def test_midlevel_normalize_output_type():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt", "inpt",
[ [
torch.randint(0, 256, size=(3, 32, 32)),
127 * np.ones((32, 32, 3), dtype="uint8"), 127 * np.ones((32, 32, 3), dtype="uint8"),
PIL.Image.new("RGB", (32, 32), 122), PIL.Image.new("RGB", (32, 32), 122),
], ],
) )
@pytest.mark.parametrize("copy", [True, False]) def test_to_image_tensor(inpt):
def test_to_image_tensor(inpt, copy): output = F.to_image_tensor(inpt)
output = F.to_image_tensor(inpt, copy=copy)
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
assert np.asarray(inpt).sum() == output.sum().item() assert np.asarray(inpt).sum() == output.sum().item()
if isinstance(inpt, PIL.Image.Image) and not copy: if isinstance(inpt, PIL.Image.Image):
# we can't check this option # we can't check this option
# as PIL -> numpy is always copying # as PIL -> numpy is always copying
return return
if isinstance(inpt, PIL.Image.Image): inpt[0, 0, 0] = 11
inpt.putpixel((0, 0), 11) assert output[0, 0, 0] == 11
else:
inpt[0, 0, 0] = 11
if copy:
assert output[0, 0, 0] != 11
else:
assert output[0, 0, 0] == 11
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1899,7 +1891,6 @@ def test_to_image_tensor(inpt, copy): ...@@ -1899,7 +1891,6 @@ def test_to_image_tensor(inpt, copy):
[ [
torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8), torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
127 * np.ones((32, 32, 3), dtype="uint8"), 127 * np.ones((32, 32, 3), dtype="uint8"),
PIL.Image.new("RGB", (32, 32), 122),
], ],
) )
@pytest.mark.parametrize("mode", [None, "RGB"]) @pytest.mark.parametrize("mode", [None, "RGB"])
......
...@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional ...@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace from torchvision.prototype.features import ColorSpace
...@@ -15,9 +16,7 @@ from ._utils import is_simple_tensor ...@@ -15,9 +16,7 @@ from ._utils import is_simple_tensor
class ToTensor(Transform): class ToTensor(Transform):
_transformed_types = (PIL.Image.Image, np.ndarray)
# Updated transformed types for ToTensor
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self) -> None: def __init__(self) -> None:
warnings.warn( warnings.warn(
...@@ -26,14 +25,13 @@ class ToTensor(Transform): ...@@ -26,14 +25,13 @@ class ToTensor(Transform):
) )
super().__init__() super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
if isinstance(inpt, (PIL.Image.Image, np.ndarray)): return _F.to_tensor(inpt)
return _F.to_tensor(inpt)
else:
return inpt
class PILToTensor(Transform): class PILToTensor(Transform):
_transformed_types = (PIL.Image.Image,)
def __init__(self) -> None: def __init__(self) -> None:
warnings.warn( warnings.warn(
"The transform `PILToTensor()` is deprecated and will be removed in a future release. " "The transform `PILToTensor()` is deprecated and will be removed in a future release. "
...@@ -41,17 +39,12 @@ class PILToTensor(Transform): ...@@ -41,17 +39,12 @@ class PILToTensor(Transform):
) )
super().__init__() super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
if isinstance(inpt, PIL.Image.Image): return _F.pil_to_tensor(inpt)
return _F.pil_to_tensor(inpt)
else:
return inpt
class ToPILImage(Transform): class ToPILImage(Transform):
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
# Updated transformed types for ToPILImage
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None: def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn( warnings.warn(
...@@ -61,11 +54,8 @@ class ToPILImage(Transform): ...@@ -61,11 +54,8 @@ class ToPILImage(Transform):
super().__init__() super().__init__()
self.mode = mode self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image:
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)): return _F.to_pil_image(inpt, mode=self.mode)
return _F.to_pil_image(inpt, mode=self.mode)
else:
return inpt
class Grayscale(Transform): class Grayscale(Transform):
......
...@@ -11,6 +11,8 @@ from ._utils import is_simple_tensor ...@@ -11,6 +11,8 @@ from ._utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform): class ConvertBoundingBoxFormat(Transform):
_transformed_types = (features.BoundingBox,)
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
super().__init__() super().__init__()
if isinstance(format, str): if isinstance(format, str):
...@@ -18,30 +20,23 @@ class ConvertBoundingBoxFormat(Transform): ...@@ -18,30 +20,23 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format self.format = format
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.BoundingBox): output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"]) return features.BoundingBox.new_like(inpt, output, format=params["format"])
return features.BoundingBox.new_like(inpt, output, format=params["format"])
else:
return inpt
class ConvertImageDtype(Transform): class ConvertImageDtype(Transform):
_transformed_types = (is_simple_tensor, features.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None: def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__() super().__init__()
self.dtype = dtype self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image): output = convert_image_dtype(inpt, dtype=self.dtype)
output = convert_image_dtype(inpt, dtype=self.dtype) return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
return features.Image.new_like(inpt, output, dtype=self.dtype)
elif is_simple_tensor(inpt):
return convert_image_dtype(inpt, dtype=self.dtype)
else:
return inpt
class ConvertColorSpace(Transform): class ConvertColorSpace(Transform):
# F.convert_color_space does NOT handle `_Feature`'s in general
_transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image) _transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
def __init__( def __init__(
......
...@@ -11,12 +11,11 @@ from ._utils import is_simple_tensor ...@@ -11,12 +11,11 @@ from ._utils import is_simple_tensor
class DecodeImage(Transform): class DecodeImage(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: _transformed_types = (features.EncodedImage,)
if isinstance(inpt, features.EncodedImage):
output = F.decode_image_with_pil(inpt) def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
return features.Image(output) output = F.decode_image_with_pil(inpt)
else: return features.Image(output)
return inpt
class LabelToOneHot(Transform): class LabelToOneHot(Transform):
...@@ -41,33 +40,19 @@ class LabelToOneHot(Transform): ...@@ -41,33 +40,19 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform): class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
# Updated transformed types for ToImageTensor def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray) output = F.to_image_tensor(inpt)
return features.Image(output)
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
output = F.to_image_tensor(inpt, copy=self.copy)
return features.Image(output)
else:
return inpt
class ToImagePIL(Transform): class ToImagePIL(Transform):
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
# Updated transformed types for ToImagePIL
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, mode: Optional[str] = None) -> None: def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__() super().__init__()
self.mode = mode self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt): return F.to_image_pil(inpt, mode=self.mode)
return F.to_image_pil(inpt, mode=self.mode)
else:
return inpt
import unittest.mock import unittest.mock
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Tuple, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -21,26 +21,11 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor ...@@ -21,26 +21,11 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type] return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor: def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
if isinstance(image, np.ndarray): if isinstance(image, np.ndarray):
image = torch.from_numpy(image) return torch.from_numpy(image)
if isinstance(image, torch.Tensor):
if copy:
return image.clone()
else:
return image
return _F.pil_to_tensor(image) return _F.pil_to_tensor(image)
def to_image_pil( to_image_pil = _F.to_pil_image
image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], mode: Optional[str] = None
) -> PIL.Image.Image:
if isinstance(image, PIL.Image.Image):
if mode != image.mode:
return image.convert(mode)
else:
return image
return _F.to_pil_image(image, mode=mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment