Unverified Commit 4df1a85c authored by Vasilis Vryniotis, committed by GitHub

[prototype] Remove `_FT` aliases from functional (#6983)

* Remove `_FT` usages from functional

* Update error messages
parent 50b77fa7
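
In short, the commit drops the module-level `_FT` alias for the private `torchvision.transforms.functional_tensor` module: trivial aliases such as `erase`/`hflip`/`vflip` are inlined as small functions, and the few helpers still needed (`_max_value`, `_pad_symmetric`) are imported directly. A minimal sketch of the new import pattern; note these are private, version-specific helpers, and the printed values are an assumption inferred from how `_max_value` is used as a clamp bound throughout the diff below:

    import torch

    # Before: from torchvision.transforms import functional_tensor as _FT
    # After: import only the helpers that are actually used, directly.
    from torchvision.transforms.functional_tensor import _max_value

    print(_max_value(torch.uint8))    # 255
    print(_max_value(torch.float32))  # 1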
@@ -4,10 +4,17 @@ import PIL.Image
 import torch
 from torchvision.prototype import features
-from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

-erase_image_tensor = _FT.erase
+
+def erase_image_tensor(
+    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
+) -> torch.Tensor:
+    if not inplace:
+        image = image.clone()
+
+    image[..., i : i + h, j : j + w] = v
+    return image


 @torch.jit.unused
...
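As a sanity check, the inlined `erase_image_tensor` logic above can be exercised standalone (the helper is re-declared here so the snippet only depends on `torch`; the toy tensors are illustrative):

    import torch

    def erase_image_tensor(image, i, j, h, w, v, inplace=False):
        # Overwrite the h x w patch at offset (i, j) with v, cloning first
        # unless the caller explicitly opts into in-place modification.
        if not inplace:
            image = image.clone()
        image[..., i : i + h, j : j + w] = v
        return image

    img = torch.ones(3, 8, 8)
    out = erase_image_tensor(img, i=2, j=2, h=4, w=4, v=torch.zeros(3, 4, 4))
    assert out[..., 2:6, 2:6].eq(0).all() and img.eq(1).all()  # input left intact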
 import torch
 from torch.nn.functional import conv2d
 from torchvision.prototype import features
-from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
+from torchvision.transforms import functional_pil as _FP
+from torchvision.transforms.functional_tensor import _max_value

 from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
@@ -9,7 +10,7 @@ from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
-    bound = _FT._max_value(image1.dtype)
+    bound = _max_value(image1.dtype)
     output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
     return output if fp else output.to(image1.dtype)
@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

-    _FT._assert_channels(image, [1, 3])
+    c = image.shape[-3]
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")

     fp = image.is_floating_point()
-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     output = image.mul(brightness_factor).clamp_(0, bound)
     return output if fp else output.to(image.dtype)
@@ -48,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")

     if c == 1:  # Match PIL behaviour
         return image
@@ -82,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")

     fp = image.is_floating_point()
     if c == 3:
         grayscale_image = _rgb_to_gray(image, cast=False)
@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     if image.numel() == 0 or height <= 2 or width <= 2:
         return image

-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     fp = image.is_floating_point()

     shape = image.shape
@@ -248,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")

     if c == 1:  # Match PIL behaviour
         return image
@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    if threshold > _FT._max_value(image.dtype):
+    if threshold > _max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
     return torch.where(image >= threshold, invert_image_tensor(image), image)
@@ -375,13 +378,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
 def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")

     if image.numel() == 0:
         # exit earlier on empty images
         return image

-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     fp = image.is_floating_point()

     float_image = image if fp else image.to(torch.float32)
...
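All the `_max_value` call sites in this file follow the same clamp-to-dtype-range pattern. Below is a self-contained sketch of the `_blend` helper with a simplified stand-in for the private `_max_value` (the stand-in and the toy tensors are assumptions for illustration; the real helper is written differently for torchscript compatibility):

    import torch

    def _max_value(dtype: torch.dtype) -> float:
        # Simplified stand-in: floats clamp at 1.0, integer dtypes at their
        # maximum representable value.
        return 1.0 if dtype.is_floating_point else float(torch.iinfo(dtype).max)

    def blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
        fp = image1.is_floating_point()
        bound = _max_value(image1.dtype)
        output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
        return output if fp else output.to(image1.dtype)

    a = torch.full((1, 2, 2), 200, dtype=torch.uint8)
    b = torch.full((1, 2, 2), 100, dtype=torch.uint8)
    print(blend(a, b, ratio=2.0))  # 200*2 - 100 = 300, clamped to 255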
@@ -8,7 +8,7 @@ import torch
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
 from torchvision.prototype import features
-from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
+from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional import (
     _compute_resized_output_size as __compute_resized_output_size,
     _get_perspective_coeffs,
@@ -17,10 +17,15 @@ from torchvision.transforms.functional import (
     pil_to_tensor,
     to_pil_image,
 )
+from torchvision.transforms.functional_tensor import _pad_symmetric

 from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

-horizontal_flip_image_tensor = _FT.hflip
+
+def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
+    return image.flip(-1)
+
+
 horizontal_flip_image_pil = _FP.hflip
@@ -58,7 +63,10 @@ def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return horizontal_flip_image_pil(inpt)


-vertical_flip_image_tensor = _FT.vflip
+def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
+    return image.flip(-2)
+
+
 vertical_flip_image_pil = _FP.vflip
@@ -975,7 +983,7 @@ def _pad_with_scalar_fill(
         if needs_cast:
             image = image.to(dtype)
     else:  # padding_mode == "symmetric"
-        image = _FT._pad_symmetric(image, torch_padding)
+        image = _pad_symmetric(image, torch_padding)

     new_height, new_width = image.shape[-2:]
...
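The flip aliases above reduce to single tensor ops: `flip(-1)` reverses the last (width) dimension and `flip(-2)` the height dimension, so both work unchanged on CHW and batched NCHW inputs. A quick check with a toy tensor:

    import torch

    image = torch.arange(12).reshape(1, 3, 4)  # toy CHW image

    hflipped = image.flip(-1)  # horizontal flip: reverse columns
    vflipped = image.flip(-2)  # vertical flip: reverse rows

    assert torch.equal(hflipped[0, 0], torch.tensor([3, 2, 1, 0]))
    assert torch.equal(vflipped[0, :, 0], torch.tensor([8, 4, 0]))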
@@ -4,7 +4,8 @@ import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
-from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
+from torchvision.transforms import functional_pil as _FP
+from torchvision.transforms.functional_tensor import _max_value


 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
@@ -193,7 +194,7 @@ def clamp_bounding_box(
 def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
     image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-    if not torch.all(alpha == _FT._max_value(alpha.dtype)):
+    if not torch.all(alpha == _max_value(alpha.dtype)):
         raise RuntimeError(
             "Stripping the alpha channel if it contains values other than the max value is not supported."
         )
@@ -204,7 +205,7 @@ def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> tor
     if alpha is None:
         shape = list(image.shape)
         shape[-3] = 1
-        alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
+        alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
     return torch.cat((image, alpha), dim=-3)
@@ -363,14 +364,14 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
         # Instead, we can also multiply by the maximum value plus something close to `1`. See
         # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
         eps = 1e-3
-        max_value = float(_FT._max_value(dtype))
+        max_value = float(_max_value(dtype))
         # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
         # discrete set `{0, 1}`.
         return image.mul(max_value + 1.0 - eps).to(dtype)
     else:
         # int to float
         if float_output:
-            return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype))
+            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))

         # int to int
         num_value_bits_input = _num_value_bits(image.dtype)
...
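The `eps` trick in the last hunk is worth spelling out: scaling floats in `[0.0, 1.0]` by `max_value` alone would map only exactly `1.0` to the dtype's maximum, while scaling by `max_value + 1` would map `1.0` out of range; `max_value + 1 - eps` keeps `1.0` in range while spreading the float range into (nearly) equal-width integer buckets. A worked check for `uint8`, with 255 hard-coded in place of `_max_value(torch.uint8)`:

    import torch

    image = torch.tensor([0.0, 0.5, 0.999, 1.0])
    max_value = 255.0  # stands in for _max_value(torch.uint8)
    eps = 1e-3

    out = image.mul(max_value + 1.0 - eps).to(torch.uint8)
    print(out)  # tensor([  0, 127, 255, 255], dtype=torch.uint8)

    # Naive scaling by 255 would map 0.999 to 254 and reserve 255 for 1.0 alone.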