Unverified Commit 57a0c32e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

introduce type failures in dispatchers (#6988)

* introduce type failures in dispatchers

* add type checks to all dispatchers

* add missing else

* add test

* fix convert_color_space
parent 0bd77df2
......@@ -422,6 +422,14 @@ class TestDispatchers:
assert dispatcher_params == feature_params
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_unkown_type(self, info):
unkown_input = object()
(_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")
with pytest.raises(TypeError, match=re.escape(str(type(unkown_input)))):
info.dispatcher(unkown_input, *other_args, **kwargs)
@pytest.mark.parametrize(
("alias", "target"),
......
......@@ -51,5 +51,10 @@ def erase(
elif isinstance(inpt, features.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Video.wrap_like(inpt, output)
else: # isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"but got {type(inpt)} instead."
)
import PIL.Image
import torch
from torch.nn.functional import conv2d
from torchvision.prototype import features
......@@ -41,8 +42,13 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) ->
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
else:
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
......@@ -75,8 +81,13 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) ->
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
else:
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
......@@ -109,8 +120,13 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
else:
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
......@@ -177,8 +193,13 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
else:
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
......@@ -284,8 +305,13 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_hue(hue_factor=hue_factor)
else:
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
......@@ -319,8 +345,13 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, features._Feature):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
else:
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
......@@ -348,8 +379,13 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, features._Feature):
return inpt.posterize(bits=bits)
else:
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
......@@ -371,8 +407,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, features._Feature):
return inpt.solarize(threshold=threshold)
else:
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
......@@ -416,8 +457,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.autocontrast()
else:
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
......@@ -501,8 +547,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return equalize_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.equalize()
else:
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
......@@ -527,5 +578,10 @@ def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return invert_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.invert()
else:
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -59,8 +59,13 @@ def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.horizontal_flip()
else:
elif isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
......@@ -100,8 +105,13 @@ def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.vertical_flip()
else:
elif isinstance(inpt, PIL.Image.Image):
return vertical_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
......@@ -221,10 +231,15 @@ def resize(
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, features._Feature):
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
else:
elif isinstance(inpt, PIL.Image.Image):
if antialias is not None and not antialias:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _affine_parse_args(
......@@ -725,7 +740,7 @@ def affine(
return inpt.affine(
angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
)
else:
elif isinstance(inpt, PIL.Image.Image):
return affine_image_pil(
inpt,
angle,
......@@ -736,6 +751,11 @@ def affine(
fill=fill,
center=center,
)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def rotate_image_tensor(
......@@ -889,8 +909,13 @@ def rotate(
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, features._Feature):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
else:
elif isinstance(inpt, PIL.Image.Image):
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
......@@ -1090,8 +1115,13 @@ def pad(
elif isinstance(inpt, features._Feature):
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
else:
elif isinstance(inpt, PIL.Image.Image):
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
......@@ -1159,8 +1189,13 @@ def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: i
return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, features._Feature):
return inpt.crop(top, left, height, width)
else:
elif isinstance(inpt, PIL.Image.Image):
return crop_image_pil(inpt, top, left, height, width)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
......@@ -1411,10 +1446,15 @@ def perspective(
return inpt.perspective(
startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
else:
elif isinstance(inpt, PIL.Image.Image):
return perspective_image_pil(
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def elastic_image_tensor(
......@@ -1560,8 +1600,13 @@ def elastic(
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, features._Feature):
return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
else:
elif isinstance(inpt, PIL.Image.Image):
return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
elastic_transform = elastic
......@@ -1665,8 +1710,13 @@ def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, features._Feature):
return inpt.center_crop(output_size)
else:
elif isinstance(inpt, PIL.Image.Image):
return center_crop_image_pil(inpt, output_size)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def resized_crop_image_tensor(
......@@ -1753,8 +1803,13 @@ def resized_crop(
)
elif isinstance(inpt, features._Feature):
return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
else:
elif isinstance(inpt, PIL.Image.Image):
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _parse_five_crop_size(size: List[int]) -> List[int]:
......@@ -1831,8 +1886,13 @@ def five_crop(
elif isinstance(inpt, features.Video):
output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
return tuple(features.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
else: # isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return five_crop_image_pil(inpt, size)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"but got {type(inpt)} instead."
)
def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
......@@ -1879,5 +1939,10 @@ def ten_crop(
elif isinstance(inpt, features.Video):
output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
return [features.Video.wrap_like(inpt, item) for item in output]
else: # isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -23,17 +23,22 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
get_dimensions_image_pil = _FP.get_dimensions
def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
return get_dimensions_image_tensor(image)
elif isinstance(image, (features.Image, features.Video)):
channels = image.num_channels
height, width = image.spatial_size
return get_dimensions_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video)):
channels = inpt.num_channels
height, width = inpt.spatial_size
return [channels, height, width]
elif isinstance(inpt, PIL.Image.Image):
return get_dimensions_image_pil(inpt)
else:
return get_dimensions_image_pil(image)
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"but got {type(inpt)} instead."
)
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
......@@ -54,15 +59,20 @@ def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image_tensor(video)
def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
def get_num_channels(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
return get_num_channels_image_tensor(image)
elif isinstance(image, (features.Image, features.Video)):
return image.num_channels
return get_num_channels_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video)):
return inpt.num_channels
elif isinstance(inpt, PIL.Image.Image):
return get_num_channels_image_pil(inpt)
else:
return get_num_channels_image_pil(image)
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"but got {type(inpt)} instead."
)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
......@@ -103,8 +113,13 @@ def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
return get_spatial_size_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
return list(inpt.spatial_size)
else:
elif isinstance(inpt, PIL.Image.Image):
return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return]
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
def get_num_frames_video(video: torch.Tensor) -> int:
......@@ -117,7 +132,9 @@ def get_num_frames(inpt: features.VideoTypeJIT) -> int:
elif isinstance(inpt, features.Video):
return inpt.num_frames
else:
raise TypeError(f"The video should be a Tensor. Got {type(inpt)}")
raise TypeError(
f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
)
def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
......@@ -315,8 +332,13 @@ def convert_color_space(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
)
return features.Video.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space=color_space)
else:
return convert_color_space_image_pil(inpt, color_space)
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _num_value_bits(dtype: torch.dtype) -> int:
......@@ -402,6 +424,11 @@ def convert_dtype(
elif isinstance(inpt, features.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
return features.Image.wrap_like(inpt, output)
else: # isinstance(inpt, features.Video):
elif isinstance(inpt, features.Video):
output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
return features.Video.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
f"but got {type(inpt)} instead."
)
......@@ -53,13 +53,14 @@ def normalize(
std: List[float],
inplace: bool = False,
) -> torch.Tensor:
if torch.jit.is_scripting():
correct_type = isinstance(inpt, torch.Tensor)
else:
correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video))
inpt = inpt.as_subclass(torch.Tensor)
if not correct_type:
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
if not torch.jit.is_scripting():
if features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video)):
inpt = inpt.as_subclass(torch.Tensor)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
f"but got {type(inpt)} instead."
)
# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
......@@ -168,5 +169,10 @@ def gaussian_blur(
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, features._Feature):
return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
else:
elif isinstance(inpt, PIL.Image.Image):
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -15,10 +15,14 @@ def uniform_temporal_subsample(
) -> features.VideoTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
else: # isinstance(inpt, features.Video)
elif isinstance(inpt, features.Video):
if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
raise ValueError("Video inputs must have temporal_dim equivalent to -4")
output = uniform_temporal_subsample_video(
inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
)
return features.Video.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
)
......@@ -8,13 +8,15 @@ from torchvision.transforms import functional as _F
@torch.jit.unused
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
if isinstance(image, np.ndarray):
output = torch.from_numpy(image).permute((2, 0, 1)).contiguous()
elif isinstance(image, PIL.Image.Image):
output = pil_to_tensor(image)
else: # isinstance(inpt, torch.Tensor):
output = image
def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
elif isinstance(inpt, PIL.Image.Image):
output = pil_to_tensor(inpt)
elif isinstance(inpt, torch.Tensor):
output = inpt
else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
return features.Image(output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment