Unverified Commit 1b44be35 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixed Normalize._transform and added mid-level normalize (#6331)

* Added Image.normalize and fixed Normalize._transform

* Updated code for normalize, removed Image.normalize
parent 1d0786b0
...@@ -1839,3 +1839,25 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): ...@@ -1839,3 +1839,25 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1]) torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1])
torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]]) torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]])
torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1]) torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1])
def test_midlevel_normalize_output_type():
    """F.normalize dispatches on input type: tensors/Images are normalized
    (and come back as plain tensors), other features pass through unchanged."""
    mean = (0.5, 0.5, 0.5)
    std = (1.0, 1.0, 1.0)

    # A plain tensor is normalized and stays a plain tensor.
    image_tensor = torch.rand(1, 3, 32, 32)
    result = F.normalize(image_tensor, mean=mean, std=std)
    assert isinstance(result, torch.Tensor)
    torch.testing.assert_close(image_tensor - 0.5, result)

    # A segmentation mask is returned untouched, keeping its feature type.
    mask = make_segmentation_mask()
    result = F.normalize(mask, mean=mean, std=std)
    assert isinstance(result, features.SegmentationMask)
    torch.testing.assert_close(mask, result)

    # A bounding box is returned untouched, keeping its feature type.
    box = make_bounding_box(format="XYXY")
    result = F.normalize(box, mean=mean, std=std)
    assert isinstance(result, features.BoundingBox)
    torch.testing.assert_close(box, result)

    # A features.Image is normalized; the result is a plain tensor because the
    # data range after normalization is no longer that of an image.
    image = make_image(color_space=features.ColorSpace.RGB)
    result = F.normalize(image, mean=mean, std=std)
    assert isinstance(result, torch.Tensor)
    torch.testing.assert_close(image - 0.5, result)
from __future__ import annotations
from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
import torch import torch
...@@ -84,10 +86,10 @@ class _Feature(torch.Tensor): ...@@ -84,10 +86,10 @@ class _Feature(torch.Tensor):
else: else:
return output return output
def horizontal_flip(self) -> Any: def horizontal_flip(self) -> _Feature:
return self return self
def vertical_flip(self) -> Any: def vertical_flip(self) -> _Feature:
return self return self
# TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
...@@ -98,13 +100,13 @@ class _Feature(torch.Tensor): ...@@ -98,13 +100,13 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: bool = False, antialias: bool = False,
) -> Any: ) -> _Feature:
return self return self
def crop(self, top: int, left: int, height: int, width: int) -> Any: def crop(self, top: int, left: int, height: int, width: int) -> _Feature:
return self return self
def center_crop(self, output_size: List[int]) -> Any: def center_crop(self, output_size: List[int]) -> _Feature:
return self return self
def resized_crop( def resized_crop(
...@@ -116,7 +118,7 @@ class _Feature(torch.Tensor): ...@@ -116,7 +118,7 @@ class _Feature(torch.Tensor):
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False, antialias: bool = False,
) -> Any: ) -> _Feature:
return self return self
def pad( def pad(
...@@ -124,7 +126,7 @@ class _Feature(torch.Tensor): ...@@ -124,7 +126,7 @@ class _Feature(torch.Tensor):
padding: Union[int, Sequence[int]], padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> Any: ) -> _Feature:
return self return self
def rotate( def rotate(
...@@ -134,7 +136,7 @@ class _Feature(torch.Tensor): ...@@ -134,7 +136,7 @@ class _Feature(torch.Tensor):
expand: bool = False, expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Any: ) -> _Feature:
return self return self
def affine( def affine(
...@@ -146,7 +148,7 @@ class _Feature(torch.Tensor): ...@@ -146,7 +148,7 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Any: ) -> _Feature:
return self return self
def perspective( def perspective(
...@@ -154,7 +156,7 @@ class _Feature(torch.Tensor): ...@@ -154,7 +156,7 @@ class _Feature(torch.Tensor):
perspective_coeffs: List[float], perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Any: ) -> _Feature:
return self return self
def elastic( def elastic(
...@@ -162,41 +164,41 @@ class _Feature(torch.Tensor): ...@@ -162,41 +164,41 @@ class _Feature(torch.Tensor):
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Any: ) -> _Feature:
return self return self
def adjust_brightness(self, brightness_factor: float) -> Any: def adjust_brightness(self, brightness_factor: float) -> _Feature:
return self return self
def adjust_saturation(self, saturation_factor: float) -> Any: def adjust_saturation(self, saturation_factor: float) -> _Feature:
return self return self
def adjust_contrast(self, contrast_factor: float) -> Any: def adjust_contrast(self, contrast_factor: float) -> _Feature:
return self return self
def adjust_sharpness(self, sharpness_factor: float) -> Any: def adjust_sharpness(self, sharpness_factor: float) -> _Feature:
return self return self
def adjust_hue(self, hue_factor: float) -> Any: def adjust_hue(self, hue_factor: float) -> _Feature:
return self return self
def adjust_gamma(self, gamma: float, gain: float = 1) -> Any: def adjust_gamma(self, gamma: float, gain: float = 1) -> _Feature:
return self return self
def posterize(self, bits: int) -> Any: def posterize(self, bits: int) -> _Feature:
return self return self
def solarize(self, threshold: float) -> Any: def solarize(self, threshold: float) -> _Feature:
return self return self
def autocontrast(self) -> Any: def autocontrast(self) -> _Feature:
return self return self
def equalize(self) -> Any: def equalize(self) -> _Feature:
return self return self
def invert(self) -> Any: def invert(self) -> _Feature:
return self return self
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Any: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature:
return self return self
...@@ -9,7 +9,7 @@ from torchvision.prototype import features ...@@ -9,7 +9,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_all, has_any, query_image from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
class RandomErasing(_RandomApplyTransform): class RandomErasing(_RandomApplyTransform):
...@@ -86,7 +86,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -86,7 +86,7 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v) return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (features.Image, torch.Tensor)): if is_simple_tensor(inpt) or isinstance(inpt, features.Image):
output = F.erase_image_tensor(inpt, **params) output = F.erase_image_tensor(inpt, **params)
if isinstance(inpt, features.Image): if isinstance(inpt, features.Image):
return features.Image.new_like(inpt, output) return features.Image.new_like(inpt, output)
......
...@@ -23,6 +23,8 @@ class ToTensor(Transform): ...@@ -23,6 +23,8 @@ class ToTensor(Transform):
super().__init__() super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (PIL.Image.Image, np.ndarray)): if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(inpt) return _F.to_tensor(inpt)
else: else:
...@@ -54,6 +56,8 @@ class ToPILImage(Transform): ...@@ -54,6 +56,8 @@ class ToPILImage(Transform):
self.mode = mode self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)): if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
return _F.to_pil_image(inpt, mode=self.mode) return _F.to_pil_image(inpt, mode=self.mode)
else: else:
......
...@@ -39,12 +39,7 @@ class Normalize(Transform): ...@@ -39,12 +39,7 @@ class Normalize(Transform):
self.std = std self.std = std
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Apply normalization to one input of the sample.

    Type dispatch is delegated to the mid-level ``F.normalize``, which
    normalizes plain tensors and ``features.Image`` inputs and passes every
    other feature type through unchanged.
    """
    return F.normalize(inpt, mean=self.mean, std=self.std)
class GaussianBlur(Transform): class GaussianBlur(Transform):
......
...@@ -45,6 +45,8 @@ class ToImageTensor(Transform): ...@@ -45,6 +45,8 @@ class ToImageTensor(Transform):
self.copy = copy self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt): if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
output = F.to_image_tensor(inpt, copy=self.copy) output = F.to_image_tensor(inpt, copy=self.copy)
return features.Image(output) return features.Image(output)
...@@ -58,6 +60,8 @@ class ToImagePIL(Transform): ...@@ -58,6 +60,8 @@ class ToImagePIL(Transform):
self.copy = copy self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt): if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
return F.to_image_pil(inpt, copy=self.copy) return F.to_image_pil(inpt, copy=self.copy)
else: else:
......
...@@ -103,7 +103,7 @@ from ._geometry import ( ...@@ -103,7 +103,7 @@ from ._geometry import (
vertical_flip_image_tensor, vertical_flip_image_tensor,
vertical_flip_segmentation_mask, vertical_flip_segmentation_mask,
) )
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize_image_tensor from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
from ._type_conversion import ( from ._type_conversion import (
decode_image_with_pil, decode_image_with_pil,
decode_video_with_av, decode_video_with_av,
......
...@@ -15,12 +15,14 @@ normalize_image_tensor = _FT.normalize ...@@ -15,12 +15,14 @@ normalize_image_tensor = _FT.normalize
def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
    """Normalize an image tensor with the given per-channel mean and std.

    Non-image features (e.g. bounding boxes, segmentation masks) are returned
    unchanged, since normalization only makes sense for image data. PIL images
    are not supported and raise ``TypeError``.

    NOTE(review): an ``features.Image`` input is returned as a plain tensor —
    after normalization the data range is unknown, so the result is no longer
    an Image.
    """
    if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
        # Pass non-image features through untouched.
        return inpt
    elif isinstance(inpt, PIL.Image.Image):
        raise TypeError("Unsupported input type")
    else:
        # Plain tensors and features.Image: normalize and return a plain tensor.
        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
def gaussian_blur_image_tensor( def gaussian_blur_image_tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment