Unverified Commit 79098ad9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[proto] Fix kernel passthrough and types of Normalize (#6490)

* Fix pass-through and supported types of Normalize

* update error message on kernel

* Fix linter.

* Fix the tests.

* Update type.

* Update type.

* Remove unnecessary tests for bboxes and masks.
parent 35ee1dd8
...@@ -1844,22 +1844,12 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): ...@@ -1844,22 +1844,12 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
def test_midlevel_normalize_output_type(): def test_midlevel_normalize_output_type():
inpt = torch.rand(1, 3, 32, 32) inpt = torch.rand(1, 3, 32, 32)
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
inpt = make_segmentation_mask()
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, features.SegmentationMask)
torch.testing.assert_close(inpt, output)
inpt = make_bounding_box(format="XYXY")
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, features.BoundingBox)
torch.testing.assert_close(inpt, output)
inpt = make_image(color_space=features.ColorSpace.RGB) inpt = make_image(color_space=features.ColorSpace.RGB)
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor) assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
......
import functools import functools
from typing import Any, Callable, Dict, List, Sequence, Type, Union from typing import Any, Callable, Dict, Sequence, Type, Union
import PIL.Image import PIL.Image
...@@ -10,6 +10,8 @@ from torchvision.prototype.transforms import functional as F, Transform ...@@ -10,6 +10,8 @@ from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms._utils import query_bounding_box from torchvision.prototype.transforms._utils import query_bounding_box
from torchvision.transforms.transforms import _setup_size from torchvision.transforms.transforms import _setup_size
from ._utils import is_simple_tensor
class Identity(Transform): class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
...@@ -91,10 +93,12 @@ class LinearTransformation(Transform): ...@@ -91,10 +93,12 @@ class LinearTransformation(Transform):
class Normalize(Transform): class Normalize(Transform):
def __init__(self, mean: List[float], std: List[float]): _transformed_types = (PIL.Image.Image, features.Image, is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float]):
super().__init__() super().__init__()
self.mean = mean self.mean = list(mean)
self.std = std self.std = list(std)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std) return F.normalize(inpt, mean=self.mean, std=self.std)
......
...@@ -14,11 +14,11 @@ DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] ...@@ -14,11 +14,11 @@ DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
normalize_image_tensor = _FT.normalize normalize_image_tensor = _FT.normalize
def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType: def normalize(
if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image): inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False
return inpt ) -> DType:
elif isinstance(inpt, PIL.Image.Image): if not isinstance(inpt, torch.Tensor):
raise TypeError("Unsupported input type") raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
else: else:
# Image instance after normalization is not Image anymore due to unknown data range # Image instance after normalization is not Image anymore due to unknown data range
# Thus we return Tensor for input Image # Thus we return Tensor for input Image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment