Unverified Commit d7e5b6a1 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Let Normalize() and RandomPhotometricDistort return datapoints instead of tensors (#7113)

parent c06d52b1
...@@ -426,7 +426,6 @@ DISPATCHER_INFOS = [ ...@@ -426,7 +426,6 @@ DISPATCHER_INFOS = [
datapoints.Video: F.normalize_video, datapoints.Video: F.normalize_video,
}, },
test_marks=[ test_marks=[
skip_dispatch_feature,
xfail_jit_python_scalar_arg("mean"), xfail_jit_python_scalar_arg("mean"),
xfail_jit_python_scalar_arg("std"), xfail_jit_python_scalar_arg("std"),
], ],
......
...@@ -13,7 +13,7 @@ import torch ...@@ -13,7 +13,7 @@ import torch
import torchvision.prototype.transforms.utils import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, ...@@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
def test_normalize_output_type():
inpt = torch.rand(1, 3, 32, 32)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output)
inpt = make_image(color_space=datapoints.ColorSpace.RGB)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt", "inpt",
[ [
......
...@@ -289,6 +289,10 @@ class Image(Datapoint): ...@@ -289,6 +289,10 @@ class Image(Datapoint):
) )
return Image.wrap_like(self, output) return Image.wrap_like(self, output)
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
return Image.wrap_like(self, output)
ImageType = Union[torch.Tensor, PIL.Image.Image, Image] ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
ImageTypeJIT = torch.Tensor ImageTypeJIT = torch.Tensor
......
...@@ -241,6 +241,10 @@ class Video(Datapoint): ...@@ -241,6 +241,10 @@ class Video(Datapoint):
output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma) output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
return Video.wrap_like(self, output) return Video.wrap_like(self, output)
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
return Video.wrap_like(self, output)
VideoType = Union[torch.Tensor, Video] VideoType = Union[torch.Tensor, Video]
VideoTypeJIT = torch.Tensor VideoTypeJIT = torch.Tensor
......
...@@ -82,6 +82,7 @@ class ColorJitter(Transform): ...@@ -82,6 +82,7 @@ class ColorJitter(Transform):
return output return output
# TODO: This class seems to be untested
class RandomPhotometricDistort(Transform): class RandomPhotometricDistort(Transform):
_transformed_types = ( _transformed_types = (
datapoints.Image, datapoints.Image,
...@@ -119,15 +120,14 @@ class RandomPhotometricDistort(Transform): ...@@ -119,15 +120,14 @@ class RandomPhotometricDistort(Transform):
def _permute_channels( def _permute_channels(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints.ImageType, datapoints.VideoType]:
if isinstance(inpt, PIL.Image.Image):
orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt) inpt = F.pil_to_tensor(inpt)
output = inpt[..., permutation, :, :] output = inpt[..., permutation, :, :]
if isinstance(inpt, (datapoints.Image, datapoints.Video)): if isinstance(orig_inpt, PIL.Image.Image):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type]
elif isinstance(inpt, PIL.Image.Image):
output = F.to_image_pil(output) output = F.to_image_pil(output)
return output return output
......
...@@ -60,18 +60,14 @@ def normalize( ...@@ -60,18 +60,14 @@ def normalize(
) -> torch.Tensor: ) -> torch.Tensor:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(normalize) _log_api_usage_once(normalize)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if isinstance(inpt, (datapoints.Image, datapoints.Video)): return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
inpt = inpt.as_subclass(torch.Tensor) elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
elif not is_simple_tensor(inpt): return inpt.normalize(mean=mean, std=std, inplace=inplace)
raise TypeError( else:
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " raise TypeError(
f"but got {type(inpt)} instead." f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
) )
# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment