Unverified Commit 4df1a85c authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Remove `_FT` aliases from functional (#6983)

* Remove `_FT` usages from functional

* Update error messages
parent 50b77fa7
......@@ -4,10 +4,17 @@ import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
erase_image_tensor = _FT.erase
def erase_image_tensor(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
if not inplace:
image = image.clone()
image[..., i : i + h, j : j + w] = v
return image
@torch.jit.unused
......
import torch
from torch.nn.functional import conv2d
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
......@@ -9,7 +10,7 @@ from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
bound = _FT._max_value(image1.dtype)
bound = _max_value(image1.dtype)
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
return output if fp else output.to(image1.dtype)
......@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
_FT._assert_channels(image, [1, 3])
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
fp = image.is_floating_point()
bound = _FT._max_value(image.dtype)
bound = _max_value(image.dtype)
output = image.mul(brightness_factor).clamp_(0, bound)
return output if fp else output.to(image.dtype)
......@@ -48,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
if c == 1: # Match PIL behaviour
return image
......@@ -82,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
fp = image.is_floating_point()
if c == 3:
grayscale_image = _rgb_to_gray(image, cast=False)
......@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
if image.numel() == 0 or height <= 2 or width <= 2:
return image
bound = _FT._max_value(image.dtype)
bound = _max_value(image.dtype)
fp = image.is_floating_point()
shape = image.shape
......@@ -248,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
if c == 1: # Match PIL behaviour
return image
......@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _FT._max_value(image.dtype):
if threshold > _max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
return torch.where(image >= threshold, invert_image_tensor(image), image)
......@@ -375,13 +378,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
if image.numel() == 0:
# exit earlier on empty images
return image
bound = _FT._max_value(image.dtype)
bound = _max_value(image.dtype)
fp = image.is_floating_point()
float_image = image if fp else image.to(torch.float32)
......
......@@ -8,7 +8,7 @@ import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional import (
_compute_resized_output_size as __compute_resized_output_size,
_get_perspective_coeffs,
......@@ -17,10 +17,15 @@ from torchvision.transforms.functional import (
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import _pad_symmetric
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
horizontal_flip_image_tensor = _FT.hflip
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1)
horizontal_flip_image_pil = _FP.hflip
......@@ -58,7 +63,10 @@ def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return horizontal_flip_image_pil(inpt)
vertical_flip_image_tensor = _FT.vflip
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2)
vertical_flip_image_pil = _FP.vflip
......@@ -975,7 +983,7 @@ def _pad_with_scalar_fill(
if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _FT._pad_symmetric(image, torch_padding)
image = _pad_symmetric(image, torch_padding)
new_height, new_width = image.shape[-2:]
......
......@@ -4,7 +4,8 @@ import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
......@@ -193,7 +194,7 @@ def clamp_bounding_box(
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
if not torch.all(alpha == _FT._max_value(alpha.dtype)):
if not torch.all(alpha == _max_value(alpha.dtype)):
raise RuntimeError(
"Stripping the alpha channel if it contains values other than the max value is not supported."
)
......@@ -204,7 +205,7 @@ def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> tor
if alpha is None:
shape = list(image.shape)
shape[-3] = 1
alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
return torch.cat((image, alpha), dim=-3)
......@@ -363,14 +364,14 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_FT._max_value(dtype))
max_value = float(_max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype))
return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
# int to int
num_value_bits_input = _num_value_bits(image.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment