Unverified commit 55d3ba62, authored by Nicolas Hug, committed by GitHub

Deprecate functional_pil and functional_tensor and make them private (#7269)

parent 3991ab99
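
For downstream code, the fix is usually a one-line import change. A minimal migration sketch (hypothetical usage; the public torchvision.transforms.functional API exposes the same operations):

import torch
import torchvision.transforms.functional as F

# Before (deprecated in 0.15, removed in 0.17):
#   import torchvision.transforms.functional_tensor as F_t
#   out = F_t.adjust_brightness(img, 1.5)

img = torch.rand(3, 32, 32)          # CHW float image in [0, 1]
out = F.adjust_brightness(img, 1.5)  # public, stable equivalent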
@@ -17,7 +17,7 @@ from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import datapoints
-from torchvision.transforms.functional_tensor import _max_value as get_max_value
+from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor

 __all__ = [
...
@@ -29,7 +29,7 @@ from prototype_common_utils import (
 )
 from torch.utils._pytree import tree_map
 from torchvision import datapoints
-from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
+from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding

 __all__ = ["KernelInfo", "KERNEL_INFOS"]
...
@@ -11,9 +11,9 @@ import PIL.Image
 import pytest
 import torch
 import torchvision.transforms as T
+import torchvision.transforms._functional_pil as F_pil
+import torchvision.transforms._functional_tensor as F_t
 import torchvision.transforms.functional as F
-import torchvision.transforms.functional_pil as F_pil
-import torchvision.transforms.functional_tensor as F_t
 from common_utils import (
     _assert_approx_equal_tensor_to_pil,
     _assert_equal_tensor_to_pil,
...
@@ -9,8 +9,8 @@ import numpy as np
 import pytest
 import torch
 import torchvision.transforms as transforms
+import torchvision.transforms._functional_tensor as F_t
 import torchvision.transforms.functional as F
-import torchvision.transforms.functional_tensor as F_t
 from PIL import Image
 from torch._utils_internal import get_file_path_2
...
torchvision/transforms/_functional_pil.py (new file; the former functional_pil.py made private)

import numbers
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps
try:
import accimage
except ImportError:
accimage = None
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
if _is_pil_image(img):
if hasattr(img, "getbands"):
channels = len(img.getbands())
else:
channels = img.channels
width, height = img.size
return [channels, height, width]
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
if _is_pil_image(img):
return list(img.size)
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
if hasattr(img, "getbands"):
return len(img.getbands())
else:
return img.channels
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.transpose(Image.FLIP_LEFT_RIGHT)
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.transpose(Image.FLIP_TOP_BOTTOM)
@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
input_mode = img.mode
if input_mode in {"L", "1", "I", "F"}:
return img
h, s, v = img.convert("HSV").split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition takes care of rotation across boundaries
with np.errstate(over="ignore"):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, "L")
img = Image.merge("HSV", (h, s, v)).convert(input_mode)
return img
@torch.jit.unused
def adjust_gamma(
img: Image.Image,
gamma: float,
gain: float = 1.0,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
input_mode = img.mode
img = img.convert("RGB")
gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
return img
@torch.jit.unused
def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, list):
padding = tuple(padding)
if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
if isinstance(padding, tuple) and len(padding) == 1:
# Compatibility with `functional_tensor.pad`
padding = padding[0]
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if padding_mode == "constant":
opts = _parse_fill(fill, img, name="fill")
if img.mode == "P":
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, **opts)
image.putpalette(palette)
return image
return ImageOps.expand(img, border=padding, **opts)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, tuple) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, tuple) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_top, pad_right, pad_bottom]
cropping = -np.minimum(p, 0)
if cropping.any():
crop_left, crop_top, crop_right, crop_bottom = cropping
img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))
pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
if img.mode == "P":
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img
img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
return Image.fromarray(img)
@torch.jit.unused
def crop(
img: Image.Image,
top: int,
left: int,
height: int,
width: int,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.crop((left, top, left + width, top + height))
@torch.jit.unused
def resize(
img: Image.Image,
size: Union[List[int], int],
interpolation: int = Image.BILINEAR,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not (isinstance(size, list) and len(size) == 2):
raise TypeError(f"Got inappropriate size arg: {size}")
return img.resize(tuple(size[::-1]), interpolation)
@torch.jit.unused
def _parse_fill(
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
img: Image.Image,
name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
# Process fill color for affine transforms
num_channels = get_image_num_channels(img)
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_channels > 1:
fill = tuple([fill] * num_channels)
if isinstance(fill, (list, tuple)):
if len(fill) != num_channels:
msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
raise ValueError(msg.format(len(fill), num_channels))
fill = tuple(fill)
if img.mode != "F":
if isinstance(fill, (list, tuple)):
fill = tuple(int(x) for x in fill)
else:
fill = int(fill)
return {name: fill}
@torch.jit.unused
def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = Image.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
output_size = img.size
opts = _parse_fill(fill, img)
return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)
@torch.jit.unused
def rotate(
img: Image.Image,
angle: float,
interpolation: int = Image.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
opts = _parse_fill(fill, img)
return img.rotate(angle, interpolation, expand, center, **opts)
@torch.jit.unused
def perspective(
img: Image.Image,
perspective_coeffs: List[float],
interpolation: int = Image.BICUBIC,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
opts = _parse_fill(fill, img)
return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if num_output_channels == 1:
img = img.convert("L")
elif num_output_channels == 3:
img = img.convert("L")
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, "RGB")
else:
raise ValueError("num_output_channels should be either 1 or 3")
return img
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.invert(img)
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
return img
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.equalize(img)
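
As an aside on the adjust_hue wrap-around above: a tiny standalone sketch (illustrative values, numpy only) of why overflow is deliberately ignored; uint8 addition rotates the hue channel modulo 256 instead of saturating at 255.

import numpy as np

np_h = np.array([0, 100, 200, 250], dtype=np.uint8)  # sample hue channel
with np.errstate(over="ignore"):
    np_h += np.uint8(0.1 * 255)  # hue_factor = 0.1 -> shift by 25
print(np_h)  # [ 25 125 225  19] -- 250 wraps past 255 down to 19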
torchvision/transforms/_functional_tensor.py (new file; the former functional_tensor.py made private)

import warnings
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad
def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2
def _assert_image_tensor(img: Tensor) -> None:
if not _is_tensor_a_torch_image(img):
raise TypeError("Tensor is not a torch image.")
def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
height, width = img.shape[-2:]
return [channels, height, width]
def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]
def get_image_num_channels(img: Tensor) -> int:
_assert_image_tensor(img)
if img.ndim == 2:
return 1
elif img.ndim > 2:
return img.shape[-3]
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
def _max_value(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 255
elif dtype == torch.int8:
return 127
elif dtype == torch.int16:
return 32767
elif dtype == torch.int32:
return 2147483647
elif dtype == torch.int64:
return 9223372036854775807
else:
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
# easy.
return 1
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = get_dimensions(img)[0]
if c not in permitted:
raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image
if image.is_floating_point():
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
max_val = float(_max_value(dtype))
result = image.mul(max_val + 1.0 - eps)
return result.to(dtype)
else:
input_max = float(_max_value(image.dtype))
# int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
image = image.to(dtype)
return image / input_max
output_max = float(_max_value(dtype))
# int to int
if input_max > output_max:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode="floor")
return image.to(dtype)
else:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image * factor can produce different results
factor = int((output_max + 1) // (input_max + 1))
image = image.to(dtype)
return image * factor
def vflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-2)
def hflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-1)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)
_, h, w = get_dimensions(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [
max(-left + min(0, right), 0),
max(-top + min(0, bottom), 0),
max(right - max(w, left), 0),
max(bottom - max(h, top), 0),
]
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return img[..., top:bottom, left:right]
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
if num_output_channels not in (1, 3):
raise ValueError("num_output_channels should be either 1 or 3")
if img.shape[-3] == 3:
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
else:
l_img = img.clone()
if num_output_channels == 3:
return l_img.expand(img.shape)
return l_img
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
return _blend(img, torch.zeros_like(img), brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [3, 1])
c = get_dimensions(img)[0]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
orig_dtype = img.dtype
img = convert_image_dtype(img, torch.float32)
img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
return convert_image_dtype(img_hue_adj, orig_dtype)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError("Input img should be a Tensor.")
_assert_channels(img, [1, 3])
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = convert_image_dtype(result, torch.float32)
result = (gain * result**gamma).clamp(0, 1)
result = convert_image_dtype(result, dtype)
return result
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = _max_value(img1.dtype)
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img: Tensor) -> Tensor:
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
# + H channel has division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occurring, so
# we don't need to deal with it in case we save the NaN in a buffer in
# backprop, if it is ever supported, but it doesn't hurt to do so.
eqc = maxc == minc
cr = maxc - minc
# Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` is true is fine.
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing denominator with 1 when `eqc` is fine.
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img: Tensor) -> Tensor:
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
neg_min_padding = [-min(x, 0) for x in padding]
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
_x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
_y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
if torch.jit.is_scripting():
# This may be unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
return [pad_left, pad_right, pad_top, pad_bottom]
def pad(
img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
_assert_image_tensor(img)
if fill is None:
fill = 0
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list):
# TODO: Jit is failing on loading this op when scripted and saved
# https://github.com/pytorch/pytorch/issues/81100
if len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
p = _parse_pad_padding(padding)
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporarily cast input tensor to float
# until pytorch issue is resolved :
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
if padding_mode in ("reflect", "replicate"):
img = torch_pad(img, p, mode=padding_mode)
else:
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img
def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
# TODO: in v0.17, change the default to True. This will be a private function
# by then, so we don't care about warning here.
antialias: Optional[bool] = None,
) -> Tensor:
_assert_image_tensor(img)
if isinstance(size, tuple):
size = list(size)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]:
# We manually set it to False to avoid an error downstream in interpolate()
# This behaviour is documented: the parameter is irrelevant for modes
# that are not bilinear or bicubic. We used to raise an error here, but
# now we don't as True is the default.
antialias = False
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
return img
def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[Union[int, float, List[float]]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor")
_assert_image_tensor(img)
if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
if matrix is not None and len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill should be either int, float, tuple or list")
# Check fill
num_channels = get_dimensions(img)[0]
if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# it is better to round before cast
img = torch.round(img)
img = img.to(out_dtype)
return img
def _apply_grid_transform(
img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
if img.shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, mask), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == "nearest":
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
img = img * mask + (1.0 - mask) * fill_img
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def _gen_affine_grid(
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Difference with AffineGridGenerator is that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
# Inspired by the PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# Points are shifted due to affine matrix torch convention about
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
pts = torch.tensor(
[
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
min_vals += torch.tensor((w * 0.5, h * 0.5))
max_vals += torch.tensor((w * 0.5, h * 0.5))
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1]) # w, h
def rotate(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.view(1, oh, ow, 2)
def perspective(
img: Tensor,
perspective_coeffs: List[float],
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor.")
_assert_image_tensor(img)
_assert_grid_transform_inputs(
img,
matrix=None,
interpolation=interpolation,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs,
)
ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
_assert_image_tensor(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
img = torch_pad(img, padding, mode="reflect")
img = conv2d(img, kernel, groups=img.shape[-3])
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def invert(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
return _max_value(img.dtype) - img
def posterize(img: Tensor, bits: int) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
if threshold > _max_value(img.dtype):
raise TypeError("Threshold should be less than bound of img.")
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
def _blurred_degenerate_image(img: Tensor) -> Tensor:
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel[1, 1] = 5.0
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
result = img.clone()
result[..., 1:-1, 1:-1] = result_tmp
return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if img.size(-1) <= 2 or img.size(-2) <= 2:
return img
return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
bound = _max_value(img.dtype)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
scale = bound / (maximum - minimum)
eq_idxs = torch.isfinite(scale).logical_not()
minimum[eq_idxs] = 0
scale[eq_idxs] = 1
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan: Tensor) -> Tensor:
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.reshape(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
if step == 0:
return img_chan
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
def equalize(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if not (3 <= img.ndim <= 4):
raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
if img.ndim == 3:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
_assert_image_tensor(tensor)
if not tensor.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
if tensor.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
)
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
return tensor.sub_(mean).div_(std)
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
_assert_image_tensor(img)
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img
def _create_identity_grid(size: List[int]) -> Tensor:
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
def elastic_transform(
img: Tensor,
displacement: Tensor,
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
size = list(img.shape[-2:])
displacement = displacement.to(img.device)
identity_grid = _create_identity_grid(size)
grid = identity_grid.to(img.device) + displacement
return _apply_grid_transform(img, grid, interpolation, fill)
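
A short round-trip sketch of the convert_image_dtype semantics implemented above, via the public torchvision.transforms.functional.convert_image_dtype wrapper (values illustrative):

import torch
import torchvision.transforms.functional as F

img_u8 = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)  # 1x1x3 image
img_f32 = F.convert_image_dtype(img_u8, torch.float32)       # / 255 -> 0.0, ~0.502, 1.0
img_back = F.convert_image_dtype(img_f32, torch.uint8)       # * (256 - 1e-3), then truncate
print(img_back)  # tensor([[[  0, 128, 255]]], dtype=torch.uint8)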
torchvision/transforms/functional.py
@@ -15,7 +15,7 @@ except ImportError:
 accimage = None

 from ..utils import _log_api_usage_once
-from . import functional_pil as F_pil, functional_tensor as F_t
+from . import _functional_pil as F_pil, _functional_tensor as F_t

 class InterpolationMode(Enum):
...
torchvision/transforms/functional_pil.py (rewritten as a deprecation shim; its previous body moved verbatim to _functional_pil.py above)

import warnings

from torchvision.transforms._functional_pil import *  # noqa

warnings.warn(
    "The torchvision.transforms.functional_pil module is deprecated "
    "in 0.15 and will be **removed in 0.17**. Please don't rely on it. "
    "You probably just need to use APIs in "
    "torchvision.transforms.functional or in "
    "torchvision.transforms.v2.functional."
)
torchvision/transforms/functional_tensor.py (rewritten as a deprecation shim; its previous body moved verbatim to _functional_tensor.py above)

import warnings

from torchvision.transforms._functional_tensor import *  # noqa

warnings.warn(
    "The torchvision.transforms.functional_tensor module is deprecated "
    "in 0.15 and will be **removed in 0.17**. Please don't rely on it. "
    "You probably just need to use APIs in "
    "torchvision.transforms.functional or in "
    "torchvision.transforms.v2.functional."
)
def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
height, width = img.shape[-2:]
return [channels, height, width]
def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]
def get_image_num_channels(img: Tensor) -> int:
_assert_image_tensor(img)
if img.ndim == 2:
return 1
elif img.ndim > 2:
return img.shape[-3]
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
def _max_value(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 255
elif dtype == torch.int8:
return 127
elif dtype == torch.int16:
return 32767
elif dtype == torch.int32:
return 2147483647
elif dtype == torch.int64:
return 9223372036854775807
else:
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
# easy.
return 1
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = get_dimensions(img)[0]
if c not in permitted:
raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image
if image.is_floating_point():
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
max_val = float(_max_value(dtype))
result = image.mul(max_val + 1.0 - eps)
return result.to(dtype)
else:
input_max = float(_max_value(image.dtype))
# int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
image = image.to(dtype)
return image / input_max
output_max = float(_max_value(dtype))
# int to int
if input_max > output_max:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode="floor")
return image.to(dtype)
else:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image * factor can produce different results
factor = int((output_max + 1) // (input_max + 1))
image = image.to(dtype)
return image * factor
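# For illustration, the `max + 1 - eps` scaling lets a uint8 image round-trip
# through float32 without losing the endpoints:
#
#   >>> x = torch.tensor([0, 128, 255], dtype=torch.uint8)
#   >>> f = convert_image_dtype(x, torch.float32)  # tensor([0.0000, 0.5020, 1.0000])
#   >>> convert_image_dtype(f, torch.uint8)
#   tensor([  0, 128, 255], dtype=torch.uint8)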
def vflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-2)
def hflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-1)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)
_, h, w = get_dimensions(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [
max(-left + min(0, right), 0),
max(-top + min(0, bottom), 0),
max(right - max(w, left), 0),
max(bottom - max(h, top), 0),
]
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return img[..., top:bottom, left:right]
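# For illustration, a crop window that leaves the image bounds is zero-padded
# instead of raising:
#
#   >>> x = torch.ones(1, 4, 4)
#   >>> crop(x, top=-1, left=-1, height=3, width=3)[0]
#   tensor([[0., 0., 0.],
#           [0., 1., 1.],
#           [0., 1., 1.]])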
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
if num_output_channels not in (1, 3):
raise ValueError("num_output_channels should be either 1 or 3")
if img.shape[-3] == 3:
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
else:
l_img = img.clone()
if num_output_channels == 3:
return l_img.expand(img.shape)
return l_img
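# For illustration, the ITU-R 601 weights can be read off a pure-red pixel:
#
#   >>> img = torch.zeros(3, 1, 1)
#   >>> img[0] = 1.0
#   >>> rgb_to_grayscale(img)
#   tensor([[[0.2989]]])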
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
return _blend(img, torch.zeros_like(img), brightness_factor)
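# For illustration, brightness adjustment is a blend with an all-black image,
# so a factor of 0.5 halves every value:
#
#   >>> adjust_brightness(torch.full((3, 1, 1), 0.4), 0.5)
#   tensor([[[0.2000]],
#           [[0.2000]],
#           [[0.2000]]])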
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [3, 1])
c = get_dimensions(img)[0]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
orig_dtype = img.dtype
img = convert_image_dtype(img, torch.float32)
img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
return convert_image_dtype(img_hue_adj, orig_dtype)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError("Input img should be a Tensor.")
_assert_channels(img, [1, 3])
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = convert_image_dtype(result, torch.float32)
result = (gain * result**gamma).clamp(0, 1)
result = convert_image_dtype(result, dtype)
return result
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = _max_value(img1.dtype)
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img: Tensor) -> Tensor:
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
# + H channel has division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occurring, so
# we don't need to deal with it in case we save the NaN in a buffer in
# backprop, if it is ever supported, but it doesn't hurt to do so.
eqc = maxc == minc
cr = maxc - minc
    # Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` holds is fine.
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing the denominator with 1 when `eqc` holds is fine.
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img: Tensor) -> Tensor:
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
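# For illustration, _rgb2hsv and _hsv2rgb are mutual inverses on [0, 1] images,
# up to float32 rounding:
#
#   >>> x = torch.rand(3, 4, 4)
#   >>> torch.allclose(_hsv2rgb(_rgb2hsv(x)), x, atol=1e-4)
#   True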
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
neg_min_padding = [-min(x, 0) for x in padding]
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
_x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
_y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
if torch.jit.is_scripting():
            # This may be unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
return [pad_left, pad_right, pad_top, pad_bottom]
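# For illustration, note the reordering from the user-facing (left, top, right,
# bottom) convention to torch_pad's (left, right, top, bottom):
#
#   >>> _parse_pad_padding([1, 2, 3, 4])
#   [1, 3, 2, 4]
#   >>> _parse_pad_padding([1, 2])
#   [1, 1, 2, 2]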
def pad(
img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
_assert_image_tensor(img)
if fill is None:
fill = 0
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list):
# TODO: Jit is failing on loading this op when scripted and saved
# https://github.com/pytorch/pytorch/issues/81100
if len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
p = _parse_pad_padding(padding)
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporarily cast input tensor to float
# until pytorch issue is resolved :
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
if padding_mode in ("reflect", "replicate"):
img = torch_pad(img, p, mode=padding_mode)
else:
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img
def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
    # TODO: in v0.17, change the default to True. This will be a private function
# by then, so we don't care about warning here.
antialias: Optional[bool] = None,
) -> Tensor:
_assert_image_tensor(img)
if isinstance(size, tuple):
size = list(size)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]:
# We manually set it to False to avoid an error downstream in interpolate()
# This behaviour is documented: the parameter is irrelevant for modes
# that are not bilinear or bicubic. We used to raise an error here, but
# now we don't as True is the default.
antialias = False
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
return img
def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[Union[int, float, List[float]]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor")
_assert_image_tensor(img)
if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
if matrix is not None and len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill should be either int, float, tuple or list")
# Check fill
num_channels = get_dimensions(img)[0]
if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# it is better to round before cast
img = torch.round(img)
img = img.to(out_dtype)
return img
def _apply_grid_transform(
img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
if img.shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, mask), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == "nearest":
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
img = img * mask + (1.0 - mask) * fill_img
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def _gen_affine_grid(
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Difference with AffineGridGenerator is that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired by the PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# Points are shifted due to affine matrix torch convention about
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
pts = torch.tensor(
[
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
min_vals += torch.tensor((w * 0.5, h * 0.5))
max_vals += torch.tensor((w * 0.5, h * 0.5))
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1]) # w, h
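# For illustration, the identity matrix keeps the input size, which also shows
# the (w, h) return order:
#
#   >>> _compute_affine_output_size([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 10, 8)
#   (10, 8)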
def rotate(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.view(1, oh, ow, 2)
def perspective(
img: Tensor,
perspective_coeffs: List[float],
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor.")
_assert_image_tensor(img)
_assert_grid_transform_inputs(
img,
matrix=None,
interpolation=interpolation,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs,
)
ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
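# For illustration, the separable kernel is normalized, so blurring preserves
# the mean intensity (up to float32 rounding):
#
#   >>> k = _get_gaussian_kernel2d([3, 3], [1.0, 1.0], torch.float32, torch.device("cpu"))
#   >>> round(float(k.sum()), 6)
#   1.0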
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
_assert_image_tensor(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
img = torch_pad(img, padding, mode="reflect")
img = conv2d(img, kernel, groups=img.shape[-3])
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def invert(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
return _max_value(img.dtype) - img
def posterize(img: Tensor, bits: int) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
if threshold > _max_value(img.dtype):
raise TypeError("Threshold should be less than bound of img.")
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
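# For illustration, only pixels at or above the threshold are inverted:
#
#   >>> solarize(torch.tensor([[[0, 100, 200]]], dtype=torch.uint8), 128)
#   tensor([[[  0, 100,  55]]], dtype=torch.uint8)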
def _blurred_degenerate_image(img: Tensor) -> Tensor:
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel[1, 1] = 5.0
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
result = img.clone()
result[..., 1:-1, 1:-1] = result_tmp
return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if img.size(-1) <= 2 or img.size(-2) <= 2:
return img
return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
bound = _max_value(img.dtype)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
scale = bound / (maximum - minimum)
eq_idxs = torch.isfinite(scale).logical_not()
minimum[eq_idxs] = 0
scale[eq_idxs] = 1
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan: Tensor) -> Tensor:
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.reshape(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
if step == 0:
return img_chan
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
def equalize(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if not (3 <= img.ndim <= 4):
raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
if img.ndim == 3:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
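# For illustration, a constant image occupies a single histogram bin, so `step`
# is 0 in _scale_channel and the image is returned unchanged:
#
#   >>> flat = torch.full((1, 4, 4), 7, dtype=torch.uint8)
#   >>> torch.equal(equalize(flat), flat)
#   True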
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
_assert_image_tensor(tensor)
if not tensor.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
if tensor.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
)
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
return tensor.sub_(mean).div_(std)
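# For illustration, mean=0.5 and std=0.5 map a [0, 1] image onto [-1, 1]:
#
#   >>> normalize(torch.tensor([[[0.0, 1.0]]]), mean=[0.5], std=[0.5])
#   tensor([[[-1.,  1.]]])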
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
_assert_image_tensor(img)
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img
def _create_identity_grid(size: List[int]) -> Tensor:
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
def elastic_transform(
img: Tensor,
displacement: Tensor,
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
size = list(img.shape[-2:])
displacement = displacement.to(img.device)
identity_grid = _create_identity_grid(size)
grid = identity_grid.to(img.device) + displacement
return _apply_grid_transform(img, grid, interpolation, fill)
...
@@ -6,7 +6,7 @@ import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
 from torchvision import datapoints, transforms as _transforms
-from torchvision.transforms import functional_tensor as _FT
+from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_spatial_size
...
@@ -4,8 +4,8 @@ import PIL.Image
 import torch
 from torch.nn.functional import conv2d
 from torchvision import datapoints
-from torchvision.transforms import functional_pil as _FP
-from torchvision.transforms.functional_tensor import _max_value
+from torchvision.transforms import _functional_pil as _FP
+from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once
...
@@ -8,7 +8,8 @@ import torch
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
 from torchvision import datapoints
-from torchvision.transforms import functional_pil as _FP
+from torchvision.transforms import _functional_pil as _FP
+from torchvision.transforms._functional_tensor import _pad_symmetric
 from torchvision.transforms.functional import (
     _check_antialias,
     _compute_resized_output_size as __compute_resized_output_size,
@@ -19,7 +20,6 @@ from torchvision.transforms.functional import (
     pil_to_tensor,
     to_pil_image,
 )
-from torchvision.transforms.functional_tensor import _pad_symmetric
 from torchvision.utils import _log_api_usage_once
...
@@ -4,8 +4,8 @@ import PIL.Image
 import torch
 from torchvision import datapoints
 from torchvision.datapoints import BoundingBoxFormat
-from torchvision.transforms import functional_pil as _FP
-from torchvision.transforms.functional_tensor import _max_value
+from torchvision.transforms import _functional_pil as _FP
+from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once
...