Unverified Commit 57a0c32e authored by Philip Meier, committed by GitHub

introduce type failures in dispatchers (#6988)

* introduce type failures in dispatchers

* add type checks to all dispatchers

* add missing else

* add test

* fix convert_color_space
parent 0bd77df2
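
Summary of the change: previously, every dispatcher ended in a bare `else:` that assumed the input was a PIL image, so an unsupported input type failed with an opaque error deep inside a PIL kernel (or silently misbehaved). This commit makes the PIL branch an explicit `elif isinstance(inpt, PIL.Image.Image):` check and turns the final `else:` into a `TypeError` naming the offending type. A minimal sketch of the new behavior, assuming the prototype namespace as of this commit (the choice of `horizontal_flip` is arbitrary):

    import torch
    from torchvision.prototype.transforms import functional as F

    F.horizontal_flip(torch.rand(3, 32, 32))  # plain tensors still work

    F.horizontal_flip(object())
    # TypeError: Input can either be a plain tensor, one of the tensor subclasses
    # TorchVision provides, or a PIL image, but got <class 'object'> instead.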
@@ -422,6 +422,14 @@ class TestDispatchers:
         assert dispatcher_params == feature_params
 
+    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
+    def test_unkown_type(self, info):
+        unkown_input = object()
+        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")
+
+        with pytest.raises(TypeError, match=re.escape(str(type(unkown_input)))):
+            info.dispatcher(unkown_input, *other_args, **kwargs)
+
     @pytest.mark.parametrize(
         ("alias", "target"),
...
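
The test passes a bare `object()` to every dispatcher and asserts that the raised `TypeError` message names the offending type; `re.escape` is needed because `str(type(...))` can contain regex metacharacters, e.g. the dots in qualified class names. A self-contained sketch of the same pytest idiom, with a hypothetical `dispatch` function standing in for the `DISPATCHER_INFOS` machinery:

    import re

    import pytest

    def dispatch(inpt):
        # stand-in dispatcher: accepts only strings
        if not isinstance(inpt, str):
            raise TypeError(f"unsupported input of type {type(inpt)}")
        return inpt

    def test_unknown_type():
        unknown_input = object()
        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            dispatch(unknown_input)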
@@ -51,5 +51,10 @@ def erase(
     elif isinstance(inpt, features.Video):
         output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         return features.Video.wrap_like(inpt, output)
-    else:  # isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
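
Every dispatcher touched by this commit follows the same four-branch shape: plain tensors (and, under TorchScript, all tensors) go to the `*_image_tensor` kernel, feature subclasses dispatch through their own method or an unwrap/re-wrap of the tensor kernel, PIL images go to the `*_image_pil` kernel, and anything else now raises. A distilled, runnable sketch of the pattern; the `blur_*` names are hypothetical stand-ins, not TorchVision APIs:

    import PIL.Image
    import torch
    from torchvision.prototype import features

    def blur_image_tensor(image: torch.Tensor) -> torch.Tensor:
        return image  # placeholder kernel

    def blur_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
        return image  # placeholder kernel

    def blur(inpt):
        if isinstance(inpt, torch.Tensor) and (
            torch.jit.is_scripting() or not isinstance(inpt, features._Feature)
        ):
            # plain tensors; TorchScript never sees the subclasses, so it always lands here
            return blur_image_tensor(inpt)
        elif isinstance(inpt, features._Feature):
            # tensor subclasses (Image, Video, ...): unwrap, run the kernel, re-wrap,
            # mirroring e.g. the erase dispatcher above
            output = blur_image_tensor(inpt.as_subclass(torch.Tensor))
            return type(inpt).wrap_like(inpt, output)
        elif isinstance(inpt, PIL.Image.Image):
            return blur_image_pil(inpt)
        else:
            raise TypeError(f"unsupported input of type {type(inpt)}")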
+import PIL.Image
 import torch
 from torch.nn.functional import conv2d
 from torchvision.prototype import features
@@ -41,8 +42,13 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
     elif isinstance(inpt, features._Feature):
         return inpt.adjust_brightness(brightness_factor=brightness_factor)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
@@ -75,8 +81,13 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
     elif isinstance(inpt, features._Feature):
         return inpt.adjust_saturation(saturation_factor=saturation_factor)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
@@ -109,8 +120,13 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
     elif isinstance(inpt, features._Feature):
         return inpt.adjust_contrast(contrast_factor=contrast_factor)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
@@ -177,8 +193,13 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
     elif isinstance(inpt, features._Feature):
         return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
@@ -284,8 +305,13 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
     elif isinstance(inpt, features._Feature):
         return inpt.adjust_hue(hue_factor=hue_factor)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
@@ -319,8 +345,13 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
     elif isinstance(inpt, features._Feature):
         return inpt.adjust_gamma(gamma=gamma, gain=gain)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
@@ -348,8 +379,13 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
         return posterize_image_tensor(inpt, bits=bits)
     elif isinstance(inpt, features._Feature):
         return inpt.posterize(bits=bits)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return posterize_image_pil(inpt, bits=bits)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
@@ -371,8 +407,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
         return solarize_image_tensor(inpt, threshold=threshold)
     elif isinstance(inpt, features._Feature):
         return inpt.solarize(threshold=threshold)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return solarize_image_pil(inpt, threshold=threshold)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -416,8 +457,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
         return inpt.autocontrast()
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return autocontrast_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -501,8 +547,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return equalize_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
         return inpt.equalize()
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return equalize_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -527,5 +578,10 @@ def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return invert_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
         return inpt.invert()
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return invert_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
@@ -59,8 +59,13 @@ def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return horizontal_flip_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
         return inpt.horizontal_flip()
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return horizontal_flip_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -100,8 +105,13 @@ def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return vertical_flip_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
         return inpt.vertical_flip()
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return vertical_flip_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
@@ -221,10 +231,15 @@ def resize(
         return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
     elif isinstance(inpt, features._Feature):
         return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         if antialias is not None and not antialias:
             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
         return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def _affine_parse_args(
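
`resize` is the one dispatcher here whose PIL branch carries extra logic: `antialias` is meaningless for PIL inputs (anti-aliasing is always applied), so an explicit `antialias=False` only triggers a warning rather than changing behavior. A small sketch, assuming the prototype namespace:

    import PIL.Image
    from torchvision.prototype.transforms import functional as F

    pil_image = PIL.Image.new("RGB", (64, 48))
    # emits UserWarning: "Anti-alias option is always applied for PIL Image input. ..."
    resized = F.resize(pil_image, [24, 32], antialias=False)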
@@ -725,7 +740,7 @@ def affine(
         return inpt.affine(
             angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
         )
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return affine_image_pil(
             inpt,
             angle,
@@ -736,6 +751,11 @@ def affine(
             fill=fill,
             center=center,
         )
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def rotate_image_tensor(
@@ -889,8 +909,13 @@ def rotate(
         return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     elif isinstance(inpt, features._Feature):
         return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
@@ -1090,8 +1115,13 @@ def pad(
     elif isinstance(inpt, features._Feature):
         return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
@@ -1159,8 +1189,13 @@ def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT:
         return crop_image_tensor(inpt, top, left, height, width)
     elif isinstance(inpt, features._Feature):
         return inpt.crop(top, left, height, width)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return crop_image_pil(inpt, top, left, height, width)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -1411,10 +1446,15 @@ def perspective(
         return inpt.perspective(
             startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
         )
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return perspective_image_pil(
             inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
         )
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def elastic_image_tensor(
@@ -1560,8 +1600,13 @@ def elastic(
         return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, features._Feature):
         return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 elastic_transform = elastic
@@ -1665,8 +1710,13 @@ def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT:
         return center_crop_image_tensor(inpt, output_size)
     elif isinstance(inpt, features._Feature):
         return inpt.center_crop(output_size)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return center_crop_image_pil(inpt, output_size)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def resized_crop_image_tensor(
@@ -1753,8 +1803,13 @@ def resized_crop(
         )
     elif isinstance(inpt, features._Feature):
         return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def _parse_five_crop_size(size: List[int]) -> List[int]:
@@ -1831,8 +1886,13 @@ def five_crop(
     elif isinstance(inpt, features.Video):
         output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
         return tuple(features.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
-    else:  # isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
@@ -1879,5 +1939,10 @@ def ten_crop(
     elif isinstance(inpt, features.Video):
         output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
         return [features.Video.wrap_like(inpt, item) for item in output]
-    else:  # isinstance(inpt, PIL.Image.Image):
+    elif isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
@@ -23,17 +23,22 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
 get_dimensions_image_pil = _FP.get_dimensions
 
 
-def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
-    if isinstance(image, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
+def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
     ):
-        return get_dimensions_image_tensor(image)
-    elif isinstance(image, (features.Image, features.Video)):
-        channels = image.num_channels
-        height, width = image.spatial_size
+        return get_dimensions_image_tensor(inpt)
+    elif isinstance(inpt, (features.Image, features.Video)):
+        channels = inpt.num_channels
+        height, width = inpt.spatial_size
         return [channels, height, width]
+    elif isinstance(inpt, PIL.Image.Image):
+        return get_dimensions_image_pil(inpt)
     else:
-        return get_dimensions_image_pil(image)
+        raise TypeError(
+            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def get_num_channels_image_tensor(image: torch.Tensor) -> int:
@@ -54,15 +59,20 @@ def get_num_channels_video(video: torch.Tensor) -> int:
     return get_num_channels_image_tensor(video)
 
 
-def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
-    if isinstance(image, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
+def get_num_channels(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
     ):
-        return get_num_channels_image_tensor(image)
-    elif isinstance(image, (features.Image, features.Video)):
-        return image.num_channels
+        return get_num_channels_image_tensor(inpt)
+    elif isinstance(inpt, (features.Image, features.Video)):
+        return inpt.num_channels
+    elif isinstance(inpt, PIL.Image.Image):
+        return get_num_channels_image_pil(inpt)
     else:
-        return get_num_channels_image_pil(image)
+        raise TypeError(
+            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -103,8 +113,13 @@ def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
         return get_spatial_size_image_tensor(inpt)
     elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
         return list(inpt.spatial_size)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return get_spatial_size_image_pil(inpt)  # type: ignore[no-any-return]
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def get_num_frames_video(video: torch.Tensor) -> int:
@@ -117,7 +132,9 @@ def get_num_frames(inpt: features.VideoTypeJIT) -> int:
     elif isinstance(inpt, features.Video):
         return inpt.num_frames
     else:
-        raise TypeError(f"The video should be a Tensor. Got {type(inpt)}")
+        raise TypeError(
+            f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
+        )
 
 
 def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
@@ -315,8 +332,13 @@ def convert_color_space(
             inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
         )
         return features.Video.wrap_like(inpt, output, color_space=color_space)
+    elif isinstance(inpt, PIL.Image.Image):
+        return convert_color_space_image_pil(inpt, color_space=color_space)
     else:
-        return convert_color_space_image_pil(inpt, color_space)
+        raise TypeError(
+            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
 def _num_value_bits(dtype: torch.dtype) -> int:
@@ -402,6 +424,11 @@ def convert_dtype(
     elif isinstance(inpt, features.Image):
         output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
         return features.Image.wrap_like(inpt, output)
-    else:  # isinstance(inpt, features.Video):
+    elif isinstance(inpt, features.Video):
         output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
         return features.Video.wrap_like(inpt, output)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
+            f"but got {type(inpt)} instead."
+        )
@@ -53,13 +53,14 @@ def normalize(
     std: List[float],
     inplace: bool = False,
 ) -> torch.Tensor:
-    if torch.jit.is_scripting():
-        correct_type = isinstance(inpt, torch.Tensor)
-    else:
-        correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video))
-        inpt = inpt.as_subclass(torch.Tensor)
-    if not correct_type:
-        raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
+    if not torch.jit.is_scripting():
+        if features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video)):
+            inpt = inpt.as_subclass(torch.Tensor)
+        else:
+            raise TypeError(
+                f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
+                f"but got {type(inpt)} instead."
+            )
 
     # Image or Video type should not be retained after normalization due to unknown data range
     # Thus we return Tensor for input Image
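
Two details are worth noting here: the type check is skipped entirely under `torch.jit.is_scripting()`, since TorchScript only ever passes plain tensors, and, per the comment above, the subclass is deliberately not restored because the data range after normalization is unknown. A sketch of the observable behavior, assuming the prototype `Image` constructor accepts a raw tensor:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    image = features.Image(torch.rand(3, 8, 8))
    out = F.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    assert type(out) is torch.Tensor  # the Image subclass is not retained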
@@ -168,5 +169,10 @@ def gaussian_blur(
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
     elif isinstance(inpt, features._Feature):
         return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
-    else:
+    elif isinstance(inpt, PIL.Image.Image):
         return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
@@ -15,10 +15,14 @@ def uniform_temporal_subsample(
 ) -> features.VideoTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
         return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
-    else:  # isinstance(inpt, features.Video)
+    elif isinstance(inpt, features.Video):
         if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
             raise ValueError("Video inputs must have temporal_dim equivalent to -4")
         output = uniform_temporal_subsample_video(
             inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
         )
         return features.Video.wrap_like(inpt, output)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
+        )
@@ -8,13 +8,15 @@ from torchvision.transforms import functional as _F
 
 
 @torch.jit.unused
-def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
-    if isinstance(image, np.ndarray):
-        output = torch.from_numpy(image).permute((2, 0, 1)).contiguous()
-    elif isinstance(image, PIL.Image.Image):
-        output = pil_to_tensor(image)
-    else:  # isinstance(inpt, torch.Tensor):
-        output = image
+def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
+    if isinstance(inpt, np.ndarray):
+        output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
+    elif isinstance(inpt, PIL.Image.Image):
+        output = pil_to_tensor(inpt)
+    elif isinstance(inpt, torch.Tensor):
+        output = inpt
+    else:
+        raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
     return features.Image(output)
...
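
`to_image_tensor` accepts numpy arrays (assumed HWC, permuted to CHW), PIL images, and plain tensors, and now rejects everything else instead of passing it through unchanged. A usage sketch, assuming the function is exported from the prototype functional namespace:

    import numpy as np
    from torchvision.prototype.transforms import functional as F

    array = np.zeros((48, 64, 3), dtype=np.uint8)  # HWC numpy array
    image = F.to_image_tensor(array)               # features.Image with shape (3, 48, 64)
    assert tuple(image.shape) == (3, 48, 64)

    F.to_image_tensor("not an image")  # raises TypeError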