Unverified commit 6746986d, authored by Vasilis Vryniotis, committed by GitHub
Browse files

More prototype Transforms cleanups (#6502)

* `to_image_tensor` returns `features.Image`

* `Normalize` rejects PIL images in `forward`

* `decode_image_with_pil` returns `features.Image`

* Remove sample unpacking from Normalize

* Removing debug method that causes mypy to complain

* adding back helpful comment

* Undo the change to the normalize kernel to keep the helpful error message for users who call the kernel directly

* Remove unused import
parent 5737ed27
......@@ -9,7 +9,6 @@ import torch
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from ._feature import _Feature
from ._image import Image
D = TypeVar("D", bound="EncodedData")
......@@ -46,11 +45,6 @@ class EncodedImage(EncodedData):
return self._image_size
def decode(self) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
return Image(self._F.decode_image_with_pil(self))
class EncodedVideo(EncodedData):
pass
......@@ -94,7 +94,7 @@ class LinearTransformation(Transform):
class Normalize(Transform):
_transformed_types = (PIL.Image.Image, features.Image, is_simple_tensor)
_transformed_types = (features.Image, is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float]):
super().__init__()
......@@ -104,6 +104,11 @@ class Normalize(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std)
def forward(self, *inpts: Any) -> Any:
if has_any(inpts, PIL.Image.Image):
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
return super().forward(*inpts)
class GaussianBlur(Transform):
def __init__(
......
......@@ -14,8 +14,7 @@ class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
output = F.decode_image_with_pil(inpt)
return features.Image(output)
return F.decode_image_with_pil(inpt)
class LabelToOneHot(Transform):
......@@ -43,8 +42,7 @@ class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
output = F.to_image_tensor(inpt)
return features.Image(output)
return F.to_image_tensor(inpt)
class ToImagePIL(Transform):
......
......@@ -16,7 +16,7 @@ normalize_image_tensor = _FT.normalize
def normalize(
inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False
) -> DType:
) -> torch.Tensor:
if not isinstance(inpt, torch.Tensor):
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
else:
......
......@@ -5,15 +5,16 @@ import numpy as np
import PIL.Image
import torch
from torchvision.io.video import read_video
from torchvision.prototype import features
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as _F
def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
image = torch.as_tensor(np.array(PIL.Image.open(ReadOnlyTensorBuffer(encoded_image)), copy=True))
if image.ndim == 2:
image = image.unsqueeze(2)
return image.permute(2, 0, 1)
return features.Image(image.permute(2, 0, 1))
def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
......@@ -21,11 +22,12 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
if isinstance(image, np.ndarray):
return torch.from_numpy(image)
return _F.pil_to_tensor(image)
output = torch.from_numpy(image)
else:
output = _F.pil_to_tensor(image)
return features.Image(output)
to_image_pil = _F.to_pil_image
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment