Unverified commit 1b44be35, authored by vfdev, committed by GitHub

[proto] Fixed Normalize._transform and added mid-level normalize (#6331)

* Added Image.normalize and fixed Normalize._transform

* Updated code for normalize, removed Image.normalize
parent 1d0786b0
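
In short: transforms.Normalize now delegates to a new mid-level F.normalize, which dispatches on the input type instead of normalizing every tensor it sees. A minimal usage sketch of the intended behaviour, mirroring the test added below (constructor arguments here are illustrative, not taken from this diff):

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    mean, std = (0.5, 0.5, 0.5), (1.0, 1.0, 1.0)

    # Plain tensors (and images) go through the image kernel.
    out = F.normalize(torch.rand(1, 3, 32, 32), mean=mean, std=std)
    assert isinstance(out, torch.Tensor)

    # Non-image features such as segmentation masks pass through unchanged.
    mask = features.SegmentationMask(torch.zeros(1, 32, 32, dtype=torch.uint8))
    assert F.normalize(mask, mean=mean, std=std) is mask
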
@@ -1839,3 +1839,25 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
     torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1])
     torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]])
     torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1])
+
+
+def test_midlevel_normalize_output_type():
+    inpt = torch.rand(1, 3, 32, 32)
+    output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
+    assert isinstance(output, torch.Tensor)
+    torch.testing.assert_close(inpt - 0.5, output)
+
+    inpt = make_segmentation_mask()
+    output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
+    assert isinstance(output, features.SegmentationMask)
+    torch.testing.assert_close(inpt, output)
+
+    inpt = make_bounding_box(format="XYXY")
+    output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
+    assert isinstance(output, features.BoundingBox)
+    torch.testing.assert_close(inpt, output)
+
+    inpt = make_image(color_space=features.ColorSpace.RGB)
+    output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
+    assert isinstance(output, torch.Tensor)
+    torch.testing.assert_close(inpt - 0.5, output)
+from __future__ import annotations
+
 from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import torch
@@ -84,10 +86,10 @@ class _Feature(torch.Tensor):
         else:
             return output
 
-    def horizontal_flip(self) -> Any:
+    def horizontal_flip(self) -> _Feature:
         return self
 
-    def vertical_flip(self) -> Any:
+    def vertical_flip(self) -> _Feature:
         return self
 
     # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
@@ -98,13 +100,13 @@ class _Feature(torch.Tensor):
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         max_size: Optional[int] = None,
         antialias: bool = False,
-    ) -> Any:
+    ) -> _Feature:
         return self
 
-    def crop(self, top: int, left: int, height: int, width: int) -> Any:
+    def crop(self, top: int, left: int, height: int, width: int) -> _Feature:
         return self
 
-    def center_crop(self, output_size: List[int]) -> Any:
+    def center_crop(self, output_size: List[int]) -> _Feature:
         return self
 
     def resized_crop(
@@ -116,7 +118,7 @@ class _Feature(torch.Tensor):
         size: List[int],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         antialias: bool = False,
-    ) -> Any:
+    ) -> _Feature:
         return self
 
     def pad(
@@ -124,7 +126,7 @@ class _Feature(torch.Tensor):
         padding: Union[int, Sequence[int]],
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         padding_mode: str = "constant",
-    ) -> Any:
+    ) -> _Feature:
         return self
 
     def rotate(
@@ -134,7 +136,7 @@ class _Feature(torch.Tensor):
         expand: bool = False,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
-    ) -> Any:
+    ) -> _Feature:
         return self
 
     def affine(
@@ -146,7 +148,7 @@ class _Feature(torch.Tensor):
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
-    ) -> Any:
+    ) -> _Feature:
         return self
 
     def perspective(
@@ -154,7 +156,7 @@ class _Feature(torch.Tensor):
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
-    ) -> Any:
+    ) -> _Feature:
         return self
 
     def elastic(
@@ -162,41 +164,41 @@ class _Feature(torch.Tensor):
         displacement: torch.Tensor,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
-    ) -> Any:
+    ) -> _Feature:
         return self
 
-    def adjust_brightness(self, brightness_factor: float) -> Any:
+    def adjust_brightness(self, brightness_factor: float) -> _Feature:
         return self
 
-    def adjust_saturation(self, saturation_factor: float) -> Any:
+    def adjust_saturation(self, saturation_factor: float) -> _Feature:
         return self
 
-    def adjust_contrast(self, contrast_factor: float) -> Any:
+    def adjust_contrast(self, contrast_factor: float) -> _Feature:
         return self
 
-    def adjust_sharpness(self, sharpness_factor: float) -> Any:
+    def adjust_sharpness(self, sharpness_factor: float) -> _Feature:
         return self
 
-    def adjust_hue(self, hue_factor: float) -> Any:
+    def adjust_hue(self, hue_factor: float) -> _Feature:
         return self
 
-    def adjust_gamma(self, gamma: float, gain: float = 1) -> Any:
+    def adjust_gamma(self, gamma: float, gain: float = 1) -> _Feature:
         return self
 
-    def posterize(self, bits: int) -> Any:
+    def posterize(self, bits: int) -> _Feature:
         return self
 
-    def solarize(self, threshold: float) -> Any:
+    def solarize(self, threshold: float) -> _Feature:
         return self
 
-    def autocontrast(self) -> Any:
+    def autocontrast(self) -> _Feature:
         return self
 
-    def equalize(self) -> Any:
+    def equalize(self) -> _Feature:
         return self
 
-    def invert(self) -> Any:
+    def invert(self) -> _Feature:
         return self
 
-    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Any:
+    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature:
         return self
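
The change above from Any to _Feature is purely a typing improvement: the no-op fallbacks now keep their feature type, so chained dispatch calls stay typed for mypy. A hypothetical helper (not part of this diff) showing what the tighter annotation buys:

    from torchvision.prototype import features

    def flip_both(feat: features._Feature) -> features._Feature:
        # With "-> Any" the static type was lost after the first call; with
        # "-> _Feature" the chained result is still known to be a feature.
        return feat.horizontal_flip().vertical_flip()
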
@@ -9,7 +9,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_all, has_any, query_image
+from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
 
 
 class RandomErasing(_RandomApplyTransform):
@@ -86,7 +86,7 @@ class RandomErasing(_RandomApplyTransform):
         return dict(i=i, j=j, h=h, w=w, v=v)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (features.Image, torch.Tensor)):
+        if is_simple_tensor(inpt) or isinstance(inpt, features.Image):
             output = F.erase_image_tensor(inpt, **params)
             if isinstance(inpt, features.Image):
                 return features.Image.new_like(inpt, output)
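RandomErasing now distinguishes plain tensors from feature wrappers: every _Feature (bounding box, mask, ...) is also a torch.Tensor, so the old isinstance(inpt, (features.Image, torch.Tensor)) check matched them as well. The is_simple_tensor helper is imported from ._utils; a plausible sketch of its behaviour (an assumption, not code from this commit):

    import torch
    from torchvision.prototype import features

    def is_simple_tensor(inpt) -> bool:
        # A "simple" tensor is assumed to be any torch.Tensor that is not one
        # of the prototype feature wrappers.
        return isinstance(inpt, torch.Tensor) and not isinstance(inpt, features._Feature)
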
@@ -23,6 +23,8 @@ class ToTensor(Transform):
         super().__init__()
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        # TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
+        # so input as np.ndarray is not possible. We need to make it possible
         if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
             return _F.to_tensor(inpt)
         else:
@@ -54,6 +56,8 @@ class ToPILImage(Transform):
         self.mode = mode
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        # TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
+        # so input as np.ndarray is not possible. We need to make it possible
         if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
             return _F.to_pil_image(inpt, mode=self.mode)
         else:
@@ -39,12 +39,7 @@ class Normalize(Transform):
         self.std = std
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, torch.Tensor):
-            # We don't need to differentiate between vanilla tensors and features.Image's here, since the result of the
-            # normalization transform is no longer a features.Image
-            return F.normalize_image_tensor(inpt, mean=self.mean, std=self.std)
-        else:
-            return inpt
+        return F.normalize(inpt, mean=self.mean, std=self.std)
 
 
 class GaussianBlur(Transform):
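
With _transform reduced to a single call, the type handling lives entirely in the mid-level F.normalize (last hunk below): bounding boxes and segmentation masks are no longer normalized as if they were images, and PIL inputs now raise a TypeError instead of being returned silently. A short usage sketch of the transform after this change (mean/std values are illustrative):

    import torch
    from torchvision.prototype import features, transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    image = features.Image(torch.rand(3, 224, 224))
    out = normalize(image)
    # The result is handled as a plain tensor; since the data range after
    # normalization is unknown, it is not re-wrapped via features.Image.new_like.
    assert isinstance(out, torch.Tensor)
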
@@ -45,6 +45,8 @@ class ToImageTensor(Transform):
         self.copy = copy
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        # TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
+        # so input as np.ndarray is not possible. We need to make it possible
         if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
             output = F.to_image_tensor(inpt, copy=self.copy)
             return features.Image(output)
@@ -58,6 +60,8 @@ class ToImagePIL(Transform):
         self.copy = copy
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        # TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
+        # so input as np.ndarray is not possible. We need to make it possible
         if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
             return F.to_image_pil(inpt, copy=self.copy)
         else:
@@ -103,7 +103,7 @@ from ._geometry import (
     vertical_flip_image_tensor,
     vertical_flip_segmentation_mask,
 )
-from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize_image_tensor
+from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
 from ._type_conversion import (
     decode_image_with_pil,
     decode_video_with_av,
@@ -15,12 +15,14 @@ normalize_image_tensor = _FT.normalize
 def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
-    if isinstance(inpt, features.Image):
-        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
-    elif type(inpt) == torch.Tensor:
-        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
-    else:
+    if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
         return inpt
+    elif isinstance(inpt, PIL.Image.Image):
+        raise TypeError("Unsupported input type")
+    else:
+        # Image instance after normalization is not Image anymore due to unknown data range
+        # Thus we return Tensor for input Image
+        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
 
 
 def gaussian_blur_image_tensor(