Unverified Commit 79098ad9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[proto] Fix kernel passthrough and types of Normalize (#6490)

* Fix pass-through and supported types of Normalize

* update error message on kernel

* Fix linter.

* Fix the tests.

* Update type.

* Update type.

* Remove unnecessary tests for bboxes and masks.
parent 35ee1dd8
......@@ -1844,22 +1844,12 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
def test_midlevel_normalize_output_type():
inpt = torch.rand(1, 3, 32, 32)
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output)
inpt = make_segmentation_mask()
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, features.SegmentationMask)
torch.testing.assert_close(inpt, output)
inpt = make_bounding_box(format="XYXY")
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, features.BoundingBox)
torch.testing.assert_close(inpt, output)
inpt = make_image(color_space=features.ColorSpace.RGB)
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output)
......
import functools
from typing import Any, Callable, Dict, List, Sequence, Type, Union
from typing import Any, Callable, Dict, Sequence, Type, Union
import PIL.Image
......@@ -10,6 +10,8 @@ from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms._utils import query_bounding_box
from torchvision.transforms.transforms import _setup_size
from ._utils import is_simple_tensor
class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
......@@ -91,10 +93,12 @@ class LinearTransformation(Transform):
class Normalize(Transform):
def __init__(self, mean: List[float], std: List[float]):
_transformed_types = (PIL.Image.Image, features.Image, is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float]):
super().__init__()
self.mean = mean
self.std = std
self.mean = list(mean)
self.std = list(std)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std)
......
......@@ -14,11 +14,11 @@ DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
normalize_image_tensor = _FT.normalize
def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
return inpt
elif isinstance(inpt, PIL.Image.Image):
raise TypeError("Unsupported input type")
def normalize(
inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False
) -> DType:
if not isinstance(inpt, torch.Tensor):
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
else:
# Image instance after normalization is not Image anymore due to unknown data range
# Thus we return Tensor for input Image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment