Unverified commit d7e5b6a1, authored by Nicolas Hug, committed by GitHub
Browse files

Let Normalize() and RandomPhotometricDistort return datapoints instead of tensors (#7113)

parent c06d52b1
......@@ -426,7 +426,6 @@ DISPATCHER_INFOS = [
datapoints.Video: F.normalize_video,
},
test_marks=[
skip_dispatch_feature,
xfail_jit_python_scalar_arg("mean"),
xfail_jit_python_scalar_arg("std"),
],
......
......@@ -13,7 +13,7 @@ import torch
import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
......@@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
def test_normalize_output_type():
    """Check that ``F.normalize`` preserves the type of its input.

    A plain tensor must come back as a plain ``torch.Tensor``. An ``Image``
    datapoint is dispatched to ``inpt.normalize``, which re-wraps the result,
    so it must come back as a ``datapoints.Image``. In both cases the values
    must equal ``inpt - 0.5`` for ``mean=0.5`` and ``std=1.0``.
    """
    mean = [0.5, 0.5, 0.5]
    std = [1.0, 1.0, 1.0]

    # Plain tensor in -> plain tensor out.
    inpt = torch.rand(1, 3, 32, 32)
    output = F.normalize(inpt, mean=mean, std=std)
    assert type(output) is torch.Tensor
    torch.testing.assert_close(inpt - 0.5, output)

    # Datapoint in -> datapoint out: the dispatcher no longer unwraps ``Image`` inputs.
    inpt = make_image(color_space=datapoints.ColorSpace.RGB)
    output = F.normalize(inpt, mean=mean, std=std)
    assert type(output) is datapoints.Image
    # Compare the underlying data as plain tensors so the value check does not
    # depend on how datapoint subclasses propagate through arithmetic.
    torch.testing.assert_close(
        inpt.as_subclass(torch.Tensor) - 0.5,
        output.as_subclass(torch.Tensor),
    )
@pytest.mark.parametrize(
"inpt",
[
......
......@@ -289,6 +289,10 @@ class Image(Datapoint):
)
return Image.wrap_like(self, output)
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
    """Normalize this image and return the result as an ``Image``.

    The data is viewed as a plain ``torch.Tensor`` and handed to
    ``normalize_image_tensor`` together with ``mean``, ``std`` and ``inplace``.
    The result is then re-wrapped with ``Image.wrap_like``, using ``self`` as
    the reference.
    """
    plain = self.as_subclass(torch.Tensor)
    normalized = self._F.normalize_image_tensor(plain, mean=mean, std=std, inplace=inplace)
    return Image.wrap_like(self, normalized)
# Accepted image inputs: a plain tensor, a PIL image, or an ``Image`` datapoint.
ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
# Tensor-only counterpart of ``ImageType``.
# NOTE(review): presumably used where TorchScript cannot express the union above; confirm.
ImageTypeJIT = torch.Tensor
......
......@@ -241,6 +241,10 @@ class Video(Datapoint):
output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
return Video.wrap_like(self, output)
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
    """Normalize this video and return the result as a ``Video``.

    The data is viewed as a plain ``torch.Tensor`` and handed to
    ``normalize_video`` together with ``mean``, ``std`` and ``inplace``.
    The result is then re-wrapped with ``Video.wrap_like``, using ``self`` as
    the reference.
    """
    plain = self.as_subclass(torch.Tensor)
    normalized = self._F.normalize_video(plain, mean=mean, std=std, inplace=inplace)
    return Video.wrap_like(self, normalized)
# Accepted video inputs: a plain tensor or a ``Video`` datapoint.
VideoType = Union[torch.Tensor, Video]
# Tensor-only counterpart of ``VideoType``.
# NOTE(review): presumably used where TorchScript cannot express the union above; confirm.
VideoTypeJIT = torch.Tensor
......
......@@ -82,6 +82,7 @@ class ColorJitter(Transform):
return output
# TODO: This class seems to be untested
class RandomPhotometricDistort(Transform):
_transformed_types = (
datapoints.Image,
......@@ -119,15 +120,14 @@ class RandomPhotometricDistort(Transform):
# Index ``inpt`` with ``permutation`` along its third-from-last dimension (presumably the
# channel axis) and return the result. A PIL image is converted with ``F.pil_to_tensor``
# before indexing and converted back with ``F.to_image_pil`` afterwards.
#
# NOTE(review): this span reads like a rendered diff whose +/- markers and indentation were
# lost. The same checks appear twice, once against ``inpt`` and once against ``orig_inpt``.
# Confirm which variant is current before relying on the control flow below.
def _permute_channels(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if isinstance(inpt, PIL.Image.Image):
# Keep a reference to the caller's object: ``inpt`` is rebound below, so a later PIL check
# against ``inpt`` would no longer see the object that was originally passed in.
orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
# Reorder along dim -3 using the index tensor; the last two dims are left untouched.
output = inpt[..., permutation, :, :]
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type]
elif isinstance(inpt, PIL.Image.Image):
# Convert back to PIL only when the caller passed a PIL image.
if isinstance(orig_inpt, PIL.Image.Image):
output = F.to_image_pil(output)
return output
......
......@@ -60,19 +60,15 @@ def normalize(
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(normalize)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
elif not is_simple_tensor(inpt):
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.normalize(mean=mean, std=std, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f"but got {type(inpt)} instead."
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
)
# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment