Commit 1345fab2 authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #1263 canceled with stages
import warnings
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad
def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2
def _assert_image_tensor(img: Tensor) -> None:
    """Raise ``TypeError`` unless ``img`` passes ``_is_tensor_a_torch_image``."""
    if _is_tensor_a_torch_image(img):
        return
    raise TypeError("Tensor is not a torch image.")
def get_dimensions(img: Tensor) -> List[int]:
    """Return ``[channels, height, width]`` of an image tensor.

    Height and width are the last two dimensions. A 2-D tensor is reported
    as having a single channel; otherwise the channel count is the
    third-to-last dimension.
    """
    _assert_image_tensor(img)
    height, width = img.shape[-2:]
    if img.ndim == 2:
        channels = 1
    else:
        channels = img.shape[-3]
    return [channels, height, width]
def get_image_size(img: Tensor) -> List[int]:
    """Return ``[width, height]`` of an image tensor (note: width first)."""
    _assert_image_tensor(img)
    width = img.shape[-1]
    height = img.shape[-2]
    return [width, height]
def get_image_num_channels(img: Tensor) -> int:
    """Return the number of channels of an image tensor.

    A 2-D tensor counts as one channel; for higher ranks the channel count
    is the third-to-last dimension.

    Raises:
        TypeError: if the tensor has fewer than two dimensions.
    """
    _assert_image_tensor(img)
    ndim = img.ndim
    if ndim > 2:
        return img.shape[-3]
    if ndim == 2:
        return 1
    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
def _max_value(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 255
elif dtype == torch.int8:
return 127
elif dtype == torch.int16:
return 32767
elif dtype == torch.int32:
return 2147483647
elif dtype == torch.int64:
return 9223372036854775807
else:
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
# easy.
return 1
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    """Raise ``TypeError`` unless the channel count of ``img`` is in ``permitted``."""
    c = get_dimensions(img)[0]
    if c in permitted:
        return
    raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype``, rescaling values to the target range.

    Args:
        image: tensor to convert.
        dtype: desired output dtype.

    Returns:
        ``image`` itself when it already has ``dtype``; otherwise a converted
        tensor whose values are scaled using ``_max_value`` of the source
        and target dtypes.

    Raises:
        RuntimeError: for float32 -> int32/int64 and float64 -> int64, which
            cannot be performed safely.
    """
    if image.dtype == dtype:
        return image

    # TODO: replace with dtype.is_floating_point when torchscript supports it
    target_is_float = torch.tensor(0, dtype=dtype).is_floating_point()

    if image.is_floating_point():
        # float -> float is a plain cast
        if target_is_float:
            return image.to(dtype)

        # float -> int
        unsafe_from_f32 = image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)
        unsafe_from_f64 = image.dtype == torch.float64 and dtype == torch.int64
        if unsafe_from_f32 or unsafe_from_f64:
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # With inputs in [0, 1], (float * 255).to(uint) only reaches 255 when
        # the float is exactly 1.0. Multiplying by `max + 1 - epsilon` spreads
        # the float ranges more evenly over the integers.
        eps = 1e-3
        max_val = float(_max_value(dtype))
        return image.mul(max_val + 1.0 - eps).to(dtype)

    input_max = float(_max_value(image.dtype))

    # int -> float
    if target_is_float:
        return image.to(dtype) / input_max

    # int -> int
    output_max = float(_max_value(dtype))
    if input_max > output_max:
        # factor is forced to int for torch jit script; as a float,
        # image // factor could produce different results
        factor = int((input_max + 1) // (output_max + 1))
        return torch.div(image, factor, rounding_mode="floor").to(dtype)

    # factor is forced to int for torch jit script; as a float,
    # image * factor could produce different results
    factor = int((output_max + 1) // (input_max + 1))
    return image.to(dtype) * factor
def vflip(img: Tensor) -> Tensor:
    """Flip an image tensor along its second-to-last (height) dimension."""
    _assert_image_tensor(img)
    return torch.flip(img, dims=[-2])
def hflip(img: Tensor) -> Tensor:
    """Flip an image tensor along its last (width) dimension."""
    _assert_image_tensor(img)
    return torch.flip(img, dims=[-1])
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the region ``[top:top+height, left:left+width]`` from ``img``.

    When the requested window reaches outside the image, the part inside
    the image is sliced and the remainder is zero-padded through ``pad``.
    """
    _assert_image_tensor(img)

    _, h, w = get_dimensions(img)
    right = left + width
    bottom = top + height

    inside = left >= 0 and top >= 0 and right <= w and bottom <= h
    if inside:
        return img[..., top:bottom, left:right]

    # Amount of padding needed on each side, in (left, top, right, bottom) order.
    padding_ltrb = [
        max(-left + min(0, right), 0),
        max(-top + min(0, bottom), 0),
        max(right - max(w, left), 0),
        max(bottom - max(h, top), 0),
    ]
    visible = img[..., max(top, 0) : bottom, max(left, 0) : right]
    return pad(visible, padding_ltrb, fill=0)
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert an image tensor with 1 or 3 channels to grayscale.

    Args:
        img: tensor with at least 3 dimensions; channels are at dim ``-3``.
        num_output_channels: ``1`` or ``3``. With ``3`` the single luminance
            channel is expanded to the shape of ``img``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or an unsupported
            channel count.
        ValueError: if ``num_output_channels`` is neither 1 nor 3.
    """
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    if img.shape[-3] == 3:
        r, g, b = img.unbind(dim=-3)
        # Weights closely follow the TF implementation:
        # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
        luminance = 0.2989 * r + 0.587 * g + 0.114 * b
        l_img = luminance.to(img.dtype).unsqueeze(dim=-3)
    else:
        # Already single-channel: return a copy.
        l_img = img.clone()

    return l_img.expand(img.shape) if num_output_channels == 3 else l_img
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale brightness by blending ``img`` with an all-zero image.

    Raises:
        ValueError: if ``brightness_factor`` is negative.
    """
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    black = torch.zeros_like(img)
    return _blend(img, black, brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast by blending ``img`` with its mean gray level.

    The mean is taken over the channel, height and width dimensions of the
    grayscale version of the image (or of the image itself when it has one
    channel).

    Raises:
        ValueError: if ``contrast_factor`` is negative.
    """
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [3, 1])

    c = get_dimensions(img)[0]
    # Integer images are averaged in float32.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    gray = rgb_to_grayscale(img) if c == 3 else img
    mean = torch.mean(gray.to(dtype), dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Shift the hue channel of an RGB image by ``hue_factor``.

    The image is converted to float32 HSV, the hue is shifted modulo 1,
    and the result is converted back to RGB in the original dtype.
    Single-channel images are returned unchanged.

    Raises:
        ValueError: if ``hue_factor`` is outside ``[-0.5, 0.5]``.
        TypeError: if ``img`` is not a tensor.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    orig_dtype = img.dtype
    hsv = _rgb2hsv(convert_image_dtype(img, torch.float32))
    h, s, v = hsv.unbind(dim=-3)
    shifted = torch.stack(((h + hue_factor) % 1.0, s, v), dim=-3)
    return convert_image_dtype(_hsv2rgb(shifted), orig_dtype)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust saturation by blending ``img`` with its grayscale version.

    Single-channel images are returned unchanged.

    Raises:
        ValueError: if ``saturation_factor`` is negative.
    """
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    is_single_channel = get_dimensions(img)[0] == 1
    if is_single_channel:  # Match PIL behaviour
        return img

    gray = rgb_to_grayscale(img)
    return _blend(img, gray, saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    """Apply ``gain * img ** gamma``, clamped to ``[0, 1]``.

    Non-float images are converted to float32 for the computation, and the
    result is converted back to the input dtype.

    Raises:
        TypeError: if ``img`` is not a tensor.
        ValueError: if ``gamma`` is negative.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    dtype = img.dtype
    result = img if torch.is_floating_point(img) else convert_image_dtype(img, torch.float32)
    result = (gain * result**gamma).clamp(0, 1)
    return convert_image_dtype(result, dtype)
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    """Return ``ratio * img1 + (1 - ratio) * img2``.

    The result is clamped to ``[0, _max_value(img1.dtype)]`` and cast to the
    dtype of ``img1``.
    """
    ratio = float(ratio)
    bound = _max_value(img1.dtype)
    mixed = ratio * img1 + (1.0 - ratio) * img2
    return mixed.clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img: Tensor) -> Tensor:
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
# + H channel has division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occurring, so
# we don't need to deal with it in case we save the NaN in a buffer in
# backprop, if it is ever supported, but it doesn't hurt to do so.
eqc = maxc == minc
cr = maxc - minc
# Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing denominator with 1 when `eqc` is fine.
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img: Tensor) -> Tensor:
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
neg_min_padding = [-min(x, 0) for x in padding]
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
_x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
_y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
if torch.jit.is_scripting():
# This maybe unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
return [pad_left, pad_right, pad_top, pad_bottom]
def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
    """Pad the last two dimensions of an image tensor.

    Args:
        img: image tensor.
        padding: an int, or a sequence of 1, 2 or 4 ints (see
            ``_parse_pad_padding`` for how they map to the four sides).
        fill: fill value for ``"constant"`` mode; ``None`` is treated as 0.
        padding_mode: ``"constant"``, ``"edge"``, ``"reflect"`` or
            ``"symmetric"``.

    Raises:
        TypeError: if ``padding``, ``fill`` or ``padding_mode`` has an
            inappropriate type.
        ValueError: if a padding sequence has a length other than 1, 2 or 4,
            or if ``padding_mode`` is unknown.
    """
    _assert_image_tensor(img)

    if fill is None:
        fill = 0

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    # TODO: Jit is failing on loading this op when scripted and saved
    # https://github.com/pytorch/pytorch/issues/81100
    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(
            f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
        )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    p = _parse_pad_padding(padding)

    if padding_mode == "symmetric":
        # symmetric padding has its own implementation
        return _pad_symmetric(img, p)
    if padding_mode == "edge":
        # torch's name for edge padding
        padding_mode = "replicate"

    # torch_pad is called on a batched tensor; add a batch dim if needed.
    need_squeeze = img.ndim < 4
    if need_squeeze:
        img = img.unsqueeze(dim=0)

    out_dtype = img.dtype
    # Non-constant modes are run in float32 for non-float inputs until
    # https://github.com/pytorch/pytorch/issues/40763 is resolved.
    need_cast = (padding_mode != "constant") and out_dtype not in (torch.float32, torch.float64)
    if need_cast:
        img = img.to(torch.float32)

    if padding_mode == "constant":
        img = torch_pad(img, p, mode=padding_mode, value=float(fill))
    else:
        img = torch_pad(img, p, mode=padding_mode)

    if need_squeeze:
        img = img.squeeze(dim=0)
    if need_cast:
        img = img.to(out_dtype)

    return img
def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    # TODO: in v0.17, change the default to True. This will be a private function
    # by then, so we don't care about warning here.
    antialias: Optional[bool] = None,
) -> Tensor:
    """Resize an image tensor to ``size`` with ``interpolate``.

    Args:
        img: image tensor.
        size: target size passed to ``interpolate``.
        interpolation: interpolation mode passed to ``interpolate``.
        antialias: antialias flag; ``None`` means ``False``. It is forced
            to ``False`` for modes other than bilinear and bicubic.

    Returns:
        The resized image, in the dtype of the input.
    """
    _assert_image_tensor(img)

    if isinstance(size, tuple):
        size = list(size)

    # Antialiasing only applies to bilinear/bicubic. It is switched off for
    # every other mode rather than erroring in interpolate(): this is
    # documented behaviour, as the parameter is irrelevant for those modes.
    smooth_mode = interpolation in ["bilinear", "bicubic"]
    if antialias is None or not smooth_mode:
        antialias = False

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners explicitly to avoid warnings
    if smooth_mode:
        align_corners = False
    else:
        align_corners = None

    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

    # Bicubic results are clamped to the uint8 range before casting back.
    if interpolation == "bicubic" and out_dtype == torch.uint8:
        img = img.clamp(min=0, max=255)

    return _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[Union[int, float, List[float]]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments shared by the grid-based transforms.

    Raises:
        TypeError: if ``img`` is not a tensor image, or ``matrix`` is given
            but is not a list.
        ValueError: if ``matrix`` does not have 6 values, ``coeffs`` does
            not have 8 values, ``fill`` has more than one element but not
            one per channel, or ``interpolation`` is unsupported.

    A ``fill`` of an unexpected type only triggers a warning.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        if len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # A multi-element fill must provide one value per channel.
    num_channels = get_dimensions(img)[0]
    if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# it is better to round before cast
img = torch.round(img)
img = img.to(out_dtype)
return img
def _apply_grid_transform(
    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
    """Sample ``img`` at ``grid`` with ``grid_sample`` and optionally fill the outside.

    Args:
        img: image tensor; batched and cast to ``grid.dtype`` internally.
        grid: sampling grid for ``grid_sample``.
        mode: interpolation mode passed to ``grid_sample``.
        fill: colour for locations that sample outside the input, or
            ``None`` to leave them as the zeros produced by ``grid_sample``.

    Returns:
        The transformed image in the input's dtype and rank.
    """
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    batch = img.shape[0]
    if batch > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(batch, grid.shape[1], grid.shape[2], grid.shape[3])

    if fill is None:
        img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
        return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)

    # Append a dummy all-ones channel and sample it together with the image.
    # It ends up below 1 where the input was missing, which should be faster
    # than calling grid_sample() twice.
    ones = torch.ones((batch, 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
    sampled = grid_sample(torch.cat((img, ones), dim=1), grid, mode=mode, padding_mode="zeros", align_corners=False)

    mask = sampled[:, -1:, :, :]  # N * 1 * H * W
    img = sampled[:, :-1, :, :]  # N * C * H * W
    mask = mask.expand_as(img)

    if isinstance(fill, (tuple, list)):
        fill_list = fill
        len_fill = len(fill)
    else:
        fill_list = [float(fill)]
        len_fill = 1
    fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)

    if mode == "nearest":
        # Hard switch: pixels with no source get the fill colour.
        mask = mask < 0.5
        img[mask] = fill_img[mask]
    else:  # 'bilinear'
        # Soft blend at the borders, weighted by the sampled mask.
        img = img * mask + (1.0 - mask) * fill_img

    return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
def _gen_affine_grid(
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Difference with AffineGridGenerator is that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Apply the affine transform given by the 6-element ``matrix`` to ``img``.

    The output has the same spatial size as the input. Supported
    interpolation modes are ``"nearest"`` and ``"bilinear"``.
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    # Integer images are transformed using a float32 grid.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)

    height, width = img.shape[-2], img.shape[-1]
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
# Inspired of PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# Points are shifted due to affine matrix torch convention about
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
pts = torch.tensor(
[
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
min_vals += torch.tensor((w * 0.5, h * 0.5))
max_vals += torch.tensor((w * 0.5, h * 0.5))
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1]) # w, h
def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Rotate ``img`` using the 6-element affine ``matrix``.

    With ``expand=True`` the output size is computed by
    ``_compute_affine_output_size``; otherwise the output keeps the input
    size. Supported interpolation modes are ``"nearest"`` and ``"bilinear"``.
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    w, h = img.shape[-1], img.shape[-2]
    if expand:
        ow, oh = _compute_affine_output_size(matrix, w, h)
    else:
        ow, oh = w, h

    # Integer images are transformed using a float32 grid.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)

    return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.view(1, oh, ow, 2)
def perspective(
    img: Tensor,
    perspective_coeffs: List[float],
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Apply a perspective transform described by 8 coefficients to ``img``.

    The output keeps the spatial size of the input. Supported interpolation
    modes are ``"nearest"`` and ``"bilinear"``.

    Raises:
        TypeError: if ``img`` is not a tensor.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh = img.shape[-2]
    ow = img.shape[-1]
    # Integer images are transformed using a float32 grid.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Return a 2-D Gaussian kernel as the outer product of two 1-D kernels.

    Index 0 of ``kernel_size`` and ``sigma`` is used for the x direction
    (columns), index 1 for the y direction (rows).
    """
    kernel_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    column = kernel_y[:, None]
    row = kernel_x[None, :]
    return torch.mm(column, row)
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """Blur ``img`` with a Gaussian kernel, one convolution group per channel.

    The image is reflect-padded by half the kernel size on each side before
    the convolution.

    Raises:
        TypeError: if ``img`` is not a tensor.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    # Integer images are blurred in float32.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    # One copy of the kernel per channel, for the grouped convolution.
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # padding = (left, right, top, bottom)
    half_x = kernel_size[0] // 2
    half_y = kernel_size[1] // 2
    img = torch_pad(img, [half_x, half_x, half_y, half_y], mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
def invert(img: Tensor) -> Tensor:
    """Invert an image: return ``_max_value(img.dtype) - img``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or an unsupported
            channel count.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = _max_value(img.dtype)
    return bound - img
def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize a uint8 image by bitwise-ANDing every value with a mask derived from ``bits``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions, is not uint8, or
            has an unsupported channel count.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    low_bits = 8 - bits
    mask = -int(2**low_bits)  # JIT-friendly for: ~(2 ** (8 - bits) - 1)
    return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Invert every pixel whose value is greater than or equal to ``threshold``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or an unsupported
            channel count, or if ``threshold`` exceeds
            ``_max_value(img.dtype)``.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = _max_value(img.dtype)
    if threshold > bound:
        raise TypeError("Threshold should be less than bound of img.")

    at_or_above = img >= threshold
    return torch.where(at_or_above, invert(img), img)
def _blurred_degenerate_image(img: Tensor) -> Tensor:
    """Return a copy of ``img`` whose interior is smoothed with a 3x3 kernel.

    The kernel has weight 5 at the centre and 1 elsewhere, normalized to
    sum to 1. The convolution is unpadded, so the outermost one-pixel
    border keeps the original values.
    """
    # Integer images are filtered in float32.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    weights = torch.ones((3, 3), dtype=dtype, device=img.device)
    weights[1, 1] = 5.0
    weights /= weights.sum()
    # One copy of the kernel per channel, for the grouped convolution.
    kernel = weights.expand(img.shape[-3], 1, weights.shape[0], weights.shape[1])

    batched, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
    smoothed = conv2d(batched, kernel, groups=batched.shape[-3])
    smoothed = _cast_squeeze_out(smoothed, need_cast, need_squeeze, out_dtype)

    result = img.clone()
    result[..., 1:-1, 1:-1] = smoothed
    return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust sharpness by blending ``img`` with a blurred copy of itself.

    Images whose height or width is 2 or less are returned unchanged.

    Raises:
        ValueError: if ``sharpness_factor`` is negative.
    """
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    too_small = img.size(-1) <= 2 or img.size(-2) <= 2
    if too_small:
        return img

    blurred = _blurred_degenerate_image(img)
    return _blend(img, blurred, sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
    """Stretch each channel so its minimum maps to 0 and its maximum to the dtype bound.

    Channels where the scale is not finite (for example a constant channel)
    are left unchanged.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or an unsupported
            channel count.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = _max_value(img.dtype)
    # Integer images are processed in float32.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    minimum = torch.amin(img, dim=(-2, -1), keepdim=True).to(dtype)
    maximum = torch.amax(img, dim=(-2, -1), keepdim=True).to(dtype)
    scale = bound / (maximum - minimum)

    # Use the identity mapping (offset 0, scale 1) where the scale is not finite.
    eq_idxs = torch.isfinite(scale).logical_not()
    minimum[eq_idxs] = 0
    scale[eq_idxs] = 1

    stretched = (img - minimum) * scale
    return stretched.clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan: Tensor) -> Tensor:
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.reshape(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
if step == 0:
return img_chan
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
    """Equalize each channel (first dimension) of one image and restack the results."""
    channels = []
    for c in range(img.size(0)):
        channels.append(_scale_channel(img[c]))
    return torch.stack(channels)
def equalize(img: Tensor) -> Tensor:
    """Histogram-equalize a uint8 image (3-D) or batch of images (4-D).

    Raises:
        TypeError: if ``img`` is not 3-D or 4-D, is not uint8, or has an
            unsupported channel count.
    """
    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 3:
        return _equalize_single_image(img)

    # Batched input: equalize each image independently.
    equalized = [_equalize_single_image(x) for x in img]
    return torch.stack(equalized)
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float image tensor: ``(tensor - mean) / std``.

    Args:
        tensor: float tensor of size ``(..., C, H, W)``.
        mean: per-channel means.
        std: per-channel standard deviations.
        inplace: when ``False`` (default) the input is cloned first.

    Raises:
        TypeError: if ``tensor`` is not a float tensor.
        ValueError: if ``tensor`` has fewer than 3 dimensions, or any
            ``std`` entry is zero after conversion to the tensor dtype.
    """
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    device = tensor.device
    mean = torch.as_tensor(mean, dtype=dtype, device=device)
    std = torch.as_tensor(std, dtype=dtype, device=device)
    if torch.any(std == 0):
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")

    # Reshape 1-D statistics to (C, 1, 1) so they broadcast over H and W.
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)

    tensor.sub_(mean)
    tensor.div_(std)
    return tensor
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Overwrite the region ``[i:i+h, j:j+w]`` of ``img`` with ``v``.

    The input is cloned first unless ``inplace`` is ``True``.
    """
    _assert_image_tensor(img)

    target = img if inplace else img.clone()
    target[..., i : i + h, j : j + w] = v
    return target
def _create_identity_grid(size: List[int]) -> Tensor:
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Warp ``img`` by adding ``displacement`` to an identity sampling grid.

    Raises:
        TypeError: if ``img`` is not a tensor.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    device = img.device
    size = list(img.shape[-2:])
    displacement = displacement.to(device)

    # The identity grid is built by the helper and then moved to the image's device.
    identity_grid = _create_identity_grid(size)
    grid = identity_grid.to(device) + displacement
    return _apply_grid_transform(img, grid, interpolation, fill)
import warnings
import torch
# Emitted once at import time: this module is deprecated and callers are
# pointed to 'torchvision.transforms.functional' instead.
warnings.warn(
    "The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in the future. "
    "Please use the 'torchvision.transforms.functional' module instead."
)
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def crop(clip, i, j, h, w):
    """
    Crop the spatial region ``[i:i+h, j:j+w]`` from a video clip.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): first row of the crop.
        j (int): first column of the crop.
        h (int): height of the crop.
        w (int): width of the crop.

    Raises:
        ValueError: if ``clip`` is not 4-dimensional.
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")

    bottom = i + h
    right = j + w
    return clip[..., i:bottom, j:right]
def resize(clip, target_size, interpolation_mode):
    """
    Resize a video clip to ``target_size`` with ``torch.nn.functional.interpolate``.

    Args:
        clip (torch.tensor): Video clip to be resized; passed to ``interpolate``
            as-is, which resizes the trailing spatial dimensions.
        target_size (tuple(int, int)): target (height, width).
        interpolation_mode (str): mode passed to ``interpolate``.

    Returns:
        clip (torch.tensor): The resized clip.

    Raises:
        ValueError: if ``target_size`` does not have exactly two elements.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")

    # interpolate() only accepts an explicit align_corners for the
    # interpolating modes and raises ValueError otherwise (e.g. "nearest",
    # "area"). Passing None there makes those modes usable as well.
    if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = False
    else:
        align_corners = None

    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode, align_corners=align_corners
    )
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Crop a spatial region out of the video clip, then resize it.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
        interpolation_mode (str): mode forwarded to ``resize``.
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    is_video_clip = _is_tensor_video_clip(clip)
    if not is_video_clip:
        raise ValueError("clip should be a 4D torch.tensor")

    cropped = crop(clip, i, j, h, w)
    return resize(cropped, size, interpolation_mode)
def center_crop(clip, crop_size):
    """Crop the central ``crop_size`` window from the last two dimensions of ``clip``.

    Args:
        clip: 4D video clip tensor; its last two dimensions are read as height and width.
        crop_size: pair ``(target_height, target_width)``.
    Raises:
        ValueError: if the clip is smaller than ``crop_size`` in either dimension.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")

    height = clip.size(-2)
    width = clip.size(-1)
    target_h, target_w = crop_size
    if height < target_h or width < target_w:
        raise ValueError("height and width must be no smaller than crop_size")

    # Offsets of the window's upper-left corner, rounded to the nearest integer.
    top = int(round((height - target_h) / 2.0))
    left = int(round((width - target_w) / 2.0))
    return crop(clip, top, left, target_h, target_w)
def to_tensor(clip):
    """
    Cast a uint8 clip to float, move its last dimension to the front and
    scale the values by 1 / 255.0.

    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    Raises:
        TypeError: if ``clip`` is not of dtype ``torch.uint8``.
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))

    as_float = clip.float()
    return as_float.permute(3, 0, 1, 2) / 255.0
def normalize(clip, mean, std, inplace=False):
    """
    Subtract a per-channel mean and divide by a per-channel standard deviation.

    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
        inplace (bool): when False (default) the clip is cloned first, so the
            caller's tensor is left untouched.
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")

    target = clip if inplace else clip.clone()
    mean = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std = torch.as_tensor(std, dtype=target.dtype, device=target.device)

    # Index with three trailing ``None`` axes so the statistics broadcast over
    # every dimension after the first.
    target.sub_(mean[:, None, None, None])
    target.div_(std[:, None, None, None])
    return target
def hflip(clip):
    """
    Reverse the clip along its last dimension.

    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")

    last_dim = -1
    return clip.flip(last_dim)
"""
This file is part of the private API. Please do not use these classes directly, as they will be modified in
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import Optional, Tuple, Union
import torch
from torch import nn, Tensor
from . import functional as F, InterpolationMode
__all__ = [
"ObjectDetection",
"ImageClassification",
"VideoClassification",
"SemanticSegmentation",
"OpticalFlow",
]
class ObjectDetection(nn.Module):
    """Preset that turns an input image into a ``torch.float`` tensor.

    Inputs that are not already a ``Tensor`` are first passed through
    ``F.pil_to_tensor``; the result is then converted with
    ``F.convert_image_dtype``.
    """

    def forward(self, img: Tensor) -> Tensor:
        # Kept as a plain ``if`` statement so the isinstance check narrows ``img``.
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        as_float = F.convert_image_dtype(img, torch.float)
        return as_float

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return name + "()"

    def describe(self) -> str:
        """Return a human-readable description of what the preset does."""
        description = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[0.0, 1.0]``."
        )
        return description
class ImageClassification(nn.Module):
    """Preset: resize, center-crop, convert to float and normalize an image.

    Args:
        crop_size (int): size handed to ``F.center_crop`` (stored as a one-element list).
        resize_size (int): size handed to ``F.resize`` (stored as a one-element list).
        mean (tuple of float): per-channel mean passed to ``F.normalize``.
        std (tuple of float): per-channel std passed to ``F.normalize``.
        interpolation (InterpolationMode): interpolation passed to ``F.resize``.
        antialias: forwarded unchanged to ``F.resize``.
    """

    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        # Scalar sizes are wrapped in lists; sequences are copied into lists.
        self.crop_size = [crop_size]
        self.resize_size = [resize_size]
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        resized = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        img = F.center_crop(resized, self.crop_size)
        # Resize and crop run before the tensor conversion, so non-Tensor
        # inputs are only converted at this point.
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        return F.normalize(img, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        pieces = [
            self.__class__.__name__ + "(",
            f"\n crop_size={self.crop_size}",
            f"\n resize_size={self.resize_size}",
            f"\n mean={self.mean}",
            f"\n std={self.std}",
            f"\n interpolation={self.interpolation}",
            "\n)",
        ]
        return "".join(pieces)

    def describe(self) -> str:
        """Return a human-readable description of what the preset does."""
        description = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
        )
        return description
class VideoClassification(nn.Module):
    """Preset: resize, center-crop, convert to float, normalize and permute video frames.

    Args:
        crop_size (tuple of int): (height, width) handed to ``F.center_crop``.
        resize_size (tuple of int): size handed to ``F.resize``.
        mean (tuple of float): per-channel mean passed to ``F.normalize``.
        std (tuple of float): per-channel std passed to ``F.normalize``.
        interpolation (InterpolationMode): interpolation passed to ``F.resize``.
    """

    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Tuple[int, int],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        # Inputs with fewer than 5 dims get a leading batch dim that is removed again at the end.
        added_batch_dim = False
        if vid.ndim < 5:
            vid = vid.unsqueeze(dim=0)
            added_batch_dim = True

        batch, frames, channels, in_h, in_w = vid.shape
        # Fold batch and time together so the image ops see a 4D tensor.
        vid = vid.view(-1, channels, in_h, in_w)
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)

        out_h, out_w = self.crop_size
        vid = vid.view(batch, frames, channels, out_h, out_w)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if added_batch_dim:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        pieces = [
            self.__class__.__name__ + "(",
            f"\n crop_size={self.crop_size}",
            f"\n resize_size={self.resize_size}",
            f"\n mean={self.mean}",
            f"\n std={self.std}",
            f"\n interpolation={self.interpolation}",
            "\n)",
        ]
        return "".join(pieces)

    def describe(self) -> str:
        """Return a human-readable description of what the preset does."""
        description = (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
            "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
        )
        return description
class SemanticSegmentation(nn.Module):
    """Preset: optionally resize, then convert to float and normalize an image.

    Args:
        resize_size (int or None): size handed to ``F.resize``; ``None`` skips the resize.
        mean (tuple of float): per-channel mean passed to ``F.normalize``.
        std (tuple of float): per-channel std passed to ``F.normalize``.
        interpolation (InterpolationMode): interpolation passed to ``F.resize``.
        antialias: forwarded unchanged to ``F.resize``.
    """

    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        # A given size is stored as a one-element list; None is stored as None.
        self.resize_size = [resize_size] if resize_size is not None else None
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        # ``resize_size`` is a list only when a size was supplied at construction.
        if isinstance(self.resize_size, list):
            img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        return F.normalize(img, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        pieces = [
            self.__class__.__name__ + "(",
            f"\n resize_size={self.resize_size}",
            f"\n mean={self.mean}",
            f"\n std={self.std}",
            f"\n interpolation={self.interpolation}",
            "\n)",
        ]
        return "".join(pieces)

    def describe(self) -> str:
        """Return a human-readable description of what the preset does."""
        description = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
            f"``std={self.std}``."
        )
        return description
class OpticalFlow(nn.Module):
    """Preset that prepares an image pair: float conversion, mapping to [-1, 1], contiguous memory."""

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        if not isinstance(img1, Tensor):
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        img1 = F.convert_image_dtype(img1, torch.float)
        img2 = F.convert_image_dtype(img2, torch.float)

        # map [0, 1] into [-1, 1]
        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        return img1.contiguous(), img2.contiguous()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return name + "()"

    def describe(self) -> str:
        """Return a human-readable description of what the preset does."""
        description = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[-1.0, 1.0]``."
        )
        return description
#!/usr/bin/env python3
import numbers
import random
import warnings
from transforms import RandomCrop, RandomResizedCrop
from . import _functional_video as F
__all__ = [
"RandomCropVideo",
"RandomResizedCropVideo",
"CenterCropVideo",
"NormalizeVideo",
"ToTensorVideo",
"RandomHorizontalFlipVideo",
]
warnings.warn(
"The 'torchvision.transforms._transforms_video' module is deprecated since 0.12 and will be removed in the future. "
"Please use the 'torchvision.transforms' module instead."
)
class RandomCropVideo(RandomCrop):
    """Randomly crop a video clip to ``size``.

    Args:
        size: a number (used for both height and width after ``int``
            conversion) or a value that is stored as given.
    """

    def __init__(self, size):
        # NOTE: the base-class ``__init__`` is not called here; only ``size`` is stored.
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, OH, OW)
        """
        top, left, height, width = self.get_params(clip, self.size)
        return F.crop(clip, top, left, height, width)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(size={self.size})"
class RandomResizedCropVideo(RandomResizedCrop):
    """Randomly crop a video clip and resize the crop to ``size``.

    Args:
        size: a 2-tuple (height, width), or any other value, which is
            duplicated into a 2-tuple.
        scale: forwarded to ``get_params``.
        ratio: forwarded to ``get_params``.
        interpolation_mode (str): forwarded to ``F.resized_crop``.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation_mode="bilinear",
    ):
        # NOTE: the base-class ``__init__`` is not called here; attributes are set directly.
        if not isinstance(size, tuple):
            size = (size, size)
        elif len(size) != 2:
            raise ValueError(f"size should be tuple (height, width), instead got {size}")
        self.size = size

        self.interpolation_mode = interpolation_mode
        self.scale = scale
        self.ratio = ratio

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, H, W)
        """
        top, left, height, width = self.get_params(clip, self.scale, self.ratio)
        return F.resized_crop(clip, top, left, height, width, self.size, self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
class CenterCropVideo:
    """Crop the central region of a video clip.

    Args:
        crop_size: a number (used for both height and width after ``int``
            conversion) or a value that is stored as given.
    """

    def __init__(self, crop_size):
        is_scalar = isinstance(crop_size, numbers.Number)
        self.crop_size = (int(crop_size), int(crop_size)) if is_scalar else crop_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: central cropping of video clip. Size is
                (C, T, crop_size, crop_size)
        """
        return F.center_crop(clip, self.crop_size)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(crop_size={self.crop_size})"
class NormalizeVideo:
    """
    Normalize the video clip by mean subtraction and division by standard deviation

    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std = mean, std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        normalized = F.normalize(clip, self.mean, self.std, self.inplace)
        return normalized

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
    """
    Callable wrapper around ``F.to_tensor``: convert tensor data type from
    uint8 to float, divide value by 255.0 and permute the dimensions of clip tensor
    """

    def __init__(self):
        # Stateless transform: nothing to configure.
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
        """
        converted = F.to_tensor(clip)
        return converted

    def __repr__(self) -> str:
        # Unlike the sibling transforms, the name is returned without parentheses.
        cls = self.__class__
        return cls.__name__
class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability

    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Return:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        # One draw from ``random.random()`` per call decides whether to flip.
        if random.random() < self.p:
            return F.hflip(clip)
        return clip

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
import functools
import numbers
from collections import defaultdict
from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union
from util import datapoints
from util.datapoints import _FillType, _FillTypeJIT
from transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
if not isinstance(arg, (float, Sequence)):
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) != req_size:
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
if isinstance(arg, Sequence):
for element in arg:
if not isinstance(element, float):
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
if isinstance(arg, float):
arg = [float(arg), float(arg)]
if isinstance(arg, (list, tuple)) and len(arg) == 1:
arg = [arg[0], arg[0]]
return arg
def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
    """Recursively validate a ``fill`` argument.

    A dict is accepted when every value (and, for a ``defaultdict`` with a
    callable factory, the factory's product) is itself acceptable. Anything
    else must be ``None``, a number, a tuple or a list.

    Raises:
        TypeError: if a non-dict value is of an unsupported type.
    """
    if not isinstance(fill, dict):
        if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
        return

    # Keys are not type-checked; only the values are validated.
    for value in fill.values():
        _check_fill_arg(value)

    if isinstance(fill, defaultdict) and callable(fill.default_factory):
        _check_fill_arg(fill.default_factory())
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
    """Build a ``defaultdict`` whose missing keys all resolve to ``default``."""
    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
    # If it were possible, we could replace this with `defaultdict(lambda: default)`
    factory = functools.partial(_default_arg, default)
    return defaultdict(factory)
def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
    """Normalize a single fill value.

    ``None``, ``int`` and ``float`` values are returned unchanged; anything
    else is iterated and converted to a list of floats.
    """
    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
    # So, we can't reassign fill to 0; None is passed through as-is.
    if fill is None or isinstance(fill, (int, float)):
        return fill  # type: ignore[return-value]

    return [float(component) for component in fill]  # type: ignore[return-value]
def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
    """Validate ``fill`` and return it as a mapping of converted fill values.

    A non-dict ``fill`` is converted and wrapped in a ``defaultdict`` that
    returns it for every key. A dict ``fill`` is updated IN PLACE (values, and
    the default factory of a ``defaultdict``) and the same object is returned.
    """
    _check_fill_arg(fill)

    if not isinstance(fill, dict):
        return _get_defaultdict(_convert_fill_arg(fill))

    # Only existing keys are reassigned, so the dict size stays constant during iteration.
    for key, raw_value in fill.items():
        fill[key] = _convert_fill_arg(raw_value)

    if isinstance(fill, defaultdict) and callable(fill.default_factory):
        sanitized_default = _convert_fill_arg(fill.default_factory())
        fill.default_factory = functools.partial(_default_arg, sanitized_default)

    return fill  # type: ignore[return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
import random
import copy
from typing import Dict, List
import cv2
import numpy as np
from albumentations import DualTransform
from albumentations.augmentations.crops import functional as FCrops
from albumentations.augmentations.geometric import functional as FGeometric
from albumentations.core.bbox_utils import denormalize_bboxes, normalize_bboxes
class RandomSizeCrop(DualTransform):
    """Crop a random window whose height and width are each drawn independently.

    Each crop side is drawn uniformly from ``[min_size, min(image_side, max_size)]``.
    If an image side is smaller than ``min_size``, the lower bound is clamped
    to that side, so the full side is kept instead of failing.

    Args:
        min_size (int): smallest crop side to draw.
        max_size (int): largest crop side to draw.
        always_apply (bool): forwarded to ``DualTransform``.
        p (float): forwarded to ``DualTransform``.
    """

    def __init__(self, min_size, max_size, always_apply=False, p=1.0):
        super().__init__(always_apply, p)
        self.min_size = min_size
        self.max_size = max_size

    @property
    def targets_as_params(self) -> List[str]:
        # The crop size depends on the image size, so the image is needed as a parameter.
        return ["image"]

    def get_params(self):
        """Draw the two start fractions in [0, 1) that are passed on to the crop function."""
        return {"h_start": random.random(), "w_start": random.random()}

    def get_params_dependent_on_targets(self, params):
        """Draw the crop width and height (in that order) for the current image."""
        img_h, img_w = params["image"].shape[:2]

        # ``random.randint(a, b)`` raises ValueError when a > b, which used to
        # happen for images smaller than ``min_size``. Clamp the lower bound to
        # the image side only, so a misconfigured ``min_size > max_size`` still
        # surfaces as an error.
        crop_width = random.randint(min(self.min_size, img_w), min(img_w, self.max_size))
        crop_height = random.randint(min(self.min_size, img_h), min(img_h, self.max_size))
        return {"crop_height": crop_height, "crop_width": crop_width}

    def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, **params):
        return FCrops.random_crop(img, crop_height, crop_width, h_start, w_start)

    def apply_to_bbox(self, bbox, **params):
        # All drawn parameters are forwarded unchanged via ``**params``.
        return FCrops.bbox_random_crop(bbox, **params)

    def get_transform_init_args_names(self):
        return ("min_size", "max_size")
class RandomShortestSize(DualTransform):
    """Resize so the shorter image side matches a randomly chosen target size.

    Args:
        min_size (int or iterable of int): candidate sizes for the shorter side.
        max_size (int or None): optional cap on the longer side.
        interpolation: interpolation flag stored on the instance.
        always_apply (bool): forwarded to ``DualTransform``.
        p (float): forwarded to ``DualTransform``.
    """

    def __init__(
        self, min_size, max_size, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0
    ):
        super().__init__(always_apply, p)
        # A single int becomes a one-element list so it can be sampled uniformly.
        if isinstance(min_size, int):
            self.min_size = [min_size]
        else:
            self.min_size = list(min_size)
        self.max_size = max_size
        self.interpolation = interpolation

    @property
    def targets_as_params(self) -> List[str]:
        return ["image"]

    def get_params_dependent_on_targets(self, params):
        """Pick a target size and derive the output height and width."""
        img_h, img_w = params["image"].shape[:2]

        choice = random.randint(0, len(self.min_size) - 1)
        target_short_side = self.min_size[choice]

        scale = target_short_side / min(img_h, img_w)
        if self.max_size is not None:
            # Shrink the scale if the longer side would exceed max_size.
            scale = min(scale, self.max_size / max(img_h, img_w))

        # Width is computed before height; both are truncated to int.
        new_width = int(img_w * scale)
        new_height = int(img_h * scale)
        return {"height": new_height, "width": new_width}

    def apply(self, img, height=0, width=0, interpolation=cv2.INTER_LINEAR, **params):
        return FGeometric.resize(img, height=height, width=width, interpolation=interpolation)

    def apply_to_bbox(self, bbox, **params):
        # Bounding box coordinates are scale invariant
        return bbox

    def apply_to_keypoint(self, keypoint, **params):
        # Ratio of the new size (``width``/``height``) to ``cols``/``rows`` from params.
        scale_x = params["width"] / params["cols"]
        scale_y = params["height"] / params["rows"]
        return FGeometric.keypoint_scale(keypoint, scale_x, scale_y)

    def get_transform_init_args_names(self):
        return ("min_size", "max_size", "interpolation")
class CachedMosaic(DualTransform):
    """Mosaic of the current sample and three samples drawn from a cache.

    Every sample seen is deep-copied into ``results_cache``. When the mosaic is
    applied, four images are pasted around a random centre on a canvas of side
    ``2 * image_size``, which is then resized to ``image_size``.

    Args:
        n (int): number of images per mosaic; only 4 is supported.
        max_cached_images (int): cache size above which a random entry is evicted.
        always_apply (bool): stored as ``_always_apply`` and used only in the
            mosaic decision.
        p (float): probability used in the mosaic decision.
    """

    def __init__(self, n=4, max_cached_images=40, always_apply=False, p=1.0):
        # The base class is always constructed with always_apply=True: each
        # sample has to reach this transform so the results cache is updated.
        super().__init__(True, p)
        assert n == 4, "currenly we only support mosaic_4"
        self.n = n
        # override always_apply to _always_apply
        # it is only used for mosaic_transform not in __call__
        self._always_apply = always_apply
        self.results_cache = []
        self.max_cached_images = max_cached_images

    @property
    def targets_as_params(self) -> List[str]:
        return ["image", "bboxes"]

    def update_results_cache(self, inputs):
        """Store a deep copy of ``inputs``; evict one random entry if the cache is over capacity."""
        cache = self.results_cache
        cache.append(copy.deepcopy(inputs))
        if len(cache) > self.max_cached_images:
            evicted = random.randint(0, len(cache) - 1)
            cache.pop(evicted)

    def get_params_dependent_on_targets(self, params):
        """Decide whether to build a mosaic and, if so, compute its layout."""
        self.update_results_cache(params)

        # judge whether to apply mosaic transform
        wants_mosaic = (random.random() < self.p) or self._always_apply
        if not (wants_mosaic and len(self.results_cache) > self.n):
            # All-None parameters make apply/apply_to_bboxes return their input unchanged.
            return dict.fromkeys(
                ("image_sizes", "image_size", "coordinates", "extra_images", "extra_bboxes")
            )

        # get images and bboxes, and extra images and bboxes
        picked = [random.randint(0, len(self.results_cache) - 1) for _ in range(self.n - 1)]
        samples = [params] + [self.results_cache[idx] for idx in picked]
        images = tuple(sample["image"] for sample in samples)
        bboxes = tuple(sample["bboxes"] for sample in samples)

        # get other parameters
        image_sizes = [image.shape[:2] for image in images]
        # Mean of each image's longer side, truncated to int.
        image_size = int(sum(max(size) for size in image_sizes) / len(image_sizes))
        center_y, center_x = [
            int(random.uniform(0.5 * image_size, 1.5 * image_size)) for _ in range(2)
        ]
        canvas_side = image_size * 2

        # get transformed coordinates
        # "a" coordinates are relative to the large image (canvas);
        # "b" coordinates are relative to the small (source) image.
        coordinates = []
        for i, (h, w) in enumerate(image_sizes):
            if i == 0:
                x1a, y1a = max(center_x - w, 0), max(center_y - h, 0)
                x2a, y2a = center_x, center_y
                x1b, y1b = w - (x2a - x1a), h - (y2a - y1a)
                x2b, y2b = w, h
            elif i == 1:
                x1a, y1a = center_x, max(center_y - h, 0)
                x2a, y2a = min(center_x + w, canvas_side), center_y
                x1b, y1b = 0, h - (y2a - y1a)
                x2b, y2b = min(w, x2a - x1a), h
            elif i == 2:
                x1a, y1a = max(center_x - w, 0), center_y
                x2a, y2a = center_x, min(canvas_side, center_y + h)
                x1b, y1b = w - (x2a - x1a), 0
                x2b, y2b = w, min(y2a - y1a, h)
            elif i == 3:
                x1a, y1a = center_x, center_y
                x2a, y2a = min(center_x + w, canvas_side), min(canvas_side, center_y + h)
                x1b, y1b = 0, 0
                x2b, y2b = min(w, x2a - x1a), min(y2a - y1a, h)
            coordinates.append([x1a, y1a, x2a, y2a, x1b, y1b, x2b, y2b])

        return {
            "image_sizes": image_sizes,
            "image_size": image_size,
            "coordinates": coordinates,
            "extra_images": images[1:],
            "extra_bboxes": bboxes[1:],
        }

    def apply(self, img, image_size=0, coordinates=None, extra_images=None, **params):
        """Paste the images onto the canvas and resize it to ``image_size``."""
        if coordinates is None:
            return img

        canvas_side = image_size * 2
        canvas = np.zeros((canvas_side, canvas_side, img.shape[-1]), dtype=img.dtype)
        sources = [img, *extra_images]
        for (x1a, y1a, x2a, y2a, x1b, y1b, x2b, y2b), source in zip(coordinates, sources):
            canvas[y1a:y2a, x1a:x2a, :] = source[y1b:y2b, x1b:x2b, :]
        return FGeometric.resize(canvas, image_size, image_size)

    def apply_to_bboxes(
        self, bboxes, coordinates=None, extra_bboxes=None, image_size=0, image_sizes=None, **params
    ):
        """Shift each image's boxes onto the canvas, keeping only boxes strictly inside the pasted region."""
        if coordinates is None:
            return bboxes

        canvas_side = image_size * 2
        box_groups = [bboxes, *extra_bboxes]
        merged = []
        for (x1a, y1a, x2a, y2a, x1b, y1b, x2b, y2b), group, (rows, cols) in zip(
            coordinates, box_groups, image_sizes
        ):
            absolute = denormalize_bboxes(group, rows, cols)
            shifted = [
                (b[0] + x1a - x1b, b[1] + y1a - y1b, b[2] + x1a - x1b, b[3] + y1a - y1b, *b[4:])
                for b in absolute
                if b[0] >= x1b and b[1] >= y1b and b[2] < x2b and b[3] < y2b
            ]
            merged.extend(normalize_bboxes(shifted, canvas_side, canvas_side))
        return merged
class CachedMixup(DualTransform):
    """Blend the current sample with one sample drawn from a cache.

    Every sample seen is deep-copied into ``results_cache``. When the mixup is
    applied, the two images are accumulated on a zero canvas large enough for
    both, and the two box lists are merged.

    Args:
        max_cached_images (int): cache size above which a random entry is evicted.
        always_apply (bool): stored as ``_always_apply`` and used only in the
            mixup decision.
        p (float): probability used in the mixup decision.
    """

    def __init__(self, max_cached_images=40, always_apply=False, p=1.0):
        # The base class is always constructed with always_apply=True: each
        # sample has to reach this transform so the results cache is updated.
        super().__init__(True, p)
        # override always_apply to _always_apply
        # it is only used for the mixup decision, not in __call__
        self._always_apply = always_apply
        self.results_cache = []
        self.max_cached_images = max_cached_images

    def get_params(self):
        """Draw two Beta(32, 32) weights and normalise them to sum to 1."""
        raw = [random.betavariate(32.0, 32.0) for _ in range(2)]
        total = sum(raw)
        return {"ratios": [weight / total for weight in raw]}

    @property
    def targets_as_params(self) -> List[str]:
        return ["image", "bboxes"]

    def update_results_cache(self, inputs):
        """Store a deep copy of ``inputs``; evict one random entry if the cache is over capacity."""
        cache = self.results_cache
        cache.append(copy.deepcopy(inputs))
        if len(cache) > self.max_cached_images:
            evicted = random.randint(0, len(cache) - 1)
            cache.pop(evicted)

    def get_params_dependent_on_targets(self, params):
        """Decide whether to mix and, if so, pick the cached partner sample."""
        self.update_results_cache(params)

        # judge whether to apply mixup transform
        wants_mixup = (random.random() < self.p) or self._always_apply
        if not (wants_mixup and len(self.results_cache) > 2):
            # None parameters make apply/apply_to_bboxes return their input unchanged.
            return dict.fromkeys(("extra_images", "extra_bboxes"))

        # get images and bboxes, and extra images and bboxes
        partner = self.results_cache[random.randint(0, len(self.results_cache) - 1)]
        return {
            "extra_images": partner["image"],
            "extra_bboxes": partner["bboxes"],
        }

    def apply(self, img, extra_images=None, ratios=None, **params):
        """Accumulate both weighted images on a shared canvas."""
        if extra_images is None:
            return img

        pair = [img, extra_images]
        canvas_h = max(im.shape[0] for im in pair)
        canvas_w = max(im.shape[1] for im in pair)
        blended = np.zeros((canvas_h, canvas_w, img.shape[-1]))
        for im, weight in zip(pair, ratios):
            blended[:im.shape[0], :im.shape[1], :] += im * weight
        # NOTE(review): ``ratios`` already sum to 1 (see get_params), so this
        # extra halving looks like it darkens the result; kept as-is, confirm intent.
        blended /= 2
        return blended.astype(img.dtype)

    def apply_to_bboxes(self, bboxes, extra_bboxes=None, **params):
        """Merge both box lists, removing exact duplicates (order is not preserved)."""
        if extra_bboxes is None:
            return bboxes
        # NOTE(review): boxes are merged without rescaling to the padded canvas
        # built in ``apply``; presumably both images share one size - verify.
        combined = bboxes + extra_bboxes
        return list(set(combined))
import warnings
from typing import Any
import torch
from torch import nn
from util import datapoints
class AlbumentationsWrapper(nn.Module):
    """Run an albumentations pipeline on a tuple of datapoints.

    The wrapper converts the image, boxes, optional mask and labels to numpy,
    calls the wrapped pipeline, and wraps the results back into datapoints.
    """

    def __init__(self, albumentation_transforms):
        """
        :param albumentation_transforms: callable albumentations pipeline. It is
            called with the keyword arguments ``image``, ``bboxes``, ``labels``
            and, when a mask is present, ``mask``.
        """
        super().__init__()
        self.albumentation_transforms = albumentation_transforms

    def forward(self, input: Any) -> Any:
        """Apply the wrapped pipeline.

        Args:
            input: sequence containing exactly one ``datapoints.Image``, exactly
                one ``datapoints.BoundingBox``, optionally a ``datapoints.Mask``,
                and the labels as its LAST element.

        Returns:
            tuple: ``(image, boxes, [mask,] labels)``.

        Raises:
            ValueError: if ``input`` does not hold exactly one image and one
                bounding-box entry.
        """
        # get image, box, mask, label from input
        labels = input[-1]
        supported = (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)
        not_allowed_data = [item for item in input if not isinstance(item, supported)]
        # Exactly one non-datapoint entry (the labels) is expected.
        if len(not_allowed_data) != 1:
            not_allowed_data_type = {type(item) for item in not_allowed_data}
            warnings.warn(
                "current we only support images, bounding boxes and masks "
                f"transformation for albumentations, but got {not_allowed_data_type}"
            )

        images = [item for item in input if isinstance(item, datapoints.Image)]
        boxes = [item for item in input if isinstance(item, datapoints.BoundingBox)]
        masks = [item for item in input if isinstance(item, datapoints.Mask)]
        if len(images) != 1 or len(boxes) != 1:
            raise ValueError(
                "expected exactly one image and one bounding box entry, "
                f"got {len(images)} image(s) and {len(boxes)} bounding box entries"
            )

        # prepare albumentations input format (channels-last image)
        images = images[0].data.numpy().transpose(1, 2, 0)
        boxes = boxes[0].data.numpy()
        # Drop boxes with non-positive width or height.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])  # TODO: change into a function
        input_dict = {
            "image": images,
            "bboxes": boxes[keep],
            "labels": labels.numpy()[keep],
        }
        if len(masks) != 0:
            masks = masks[0].data.numpy()
            if masks.ndim == 3:
                # Filter along the leading axis BEFORE transposing: ``keep``
                # has one entry per box, and after the transpose axis 0 is no
                # longer the axis it refers to.
                masks = masks[keep].transpose(1, 2, 0)
            input_dict.update({"mask": masks})

        # perform albumentations transforms
        transformed = self.albumentation_transforms(**input_dict)
        images = transformed["image"]
        boxes = transformed["bboxes"]
        labels = transformed["labels"]
        if "mask" in transformed:
            masks = transformed["mask"]
            if masks.ndim == 3:
                # Move the last axis back to the front.
                masks = masks.transpose(2, 0, 1)
            masks = datapoints.Mask(masks)
        else:
            masks = None

        # prepare output data format
        images = datapoints.Image(images.transpose(2, 0, 1))
        boxes = datapoints.BoundingBox(
            torch.as_tensor(boxes).reshape(-1, 4),  # in case of empty boxes after transforms
            dtype=torch.float,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=images.shape[-2:],
        )
        output = [images, boxes]
        if masks is not None:
            output.append(masks)
        output.append(torch.as_tensor(labels, dtype=torch.long))
        return tuple(output)

    def __str__(self):
        return str(self.albumentation_transforms)
import math
from enum import Enum
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
from . import functional as F, InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply the augmentation operation named ``op_name`` to ``img``.

    Args:
        img: image handed to the matching ``F.*`` function.
        op_name: one of "ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
            "Brightness", "Color", "Contrast", "Sharpness", "Posterize", "Solarize",
            "AutoContrast", "Equalize", "Invert" or "Identity".
        magnitude: strength of the operation; how it is used depends on ``op_name``.
        interpolation: forwarded to the affine and rotate operations.
        fill: forwarded to the affine and rotate operations.

    Returns:
        The transformed image ("Identity" returns ``img`` untouched).

    Raises:
        ValueError: if ``op_name`` is not recognized.
    """
    if op_name == "ShearX":
        # magnitude should be arctan(magnitude)
        # official autoaug: (1, level, 0, 0, 1, 0)
        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
        # compared to
        # torchvision: (1, tan(level), 0, 0, 1, 0)
        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
        shear_x_deg = math.degrees(math.atan(magnitude))
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[shear_x_deg, 0.0],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    elif op_name == "ShearY":
        # magnitude should be arctan(magnitude)
        # Same conversion as for ShearX.
        shear_y_deg = math.degrees(math.atan(magnitude))
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[0.0, shear_y_deg],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    elif op_name == "TranslateX":
        # The shift is truncated to an integer.
        shift_x = int(magnitude)
        img = F.affine(
            img,
            angle=0.0,
            translate=[shift_x, 0],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "TranslateY":
        shift_y = int(magnitude)
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, shift_y],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "Rotate":
        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    elif op_name == "Brightness":
        # The colour-style adjustments take a factor of 1.0 + magnitude.
        img = F.adjust_brightness(img, 1.0 + magnitude)
    elif op_name == "Color":
        img = F.adjust_saturation(img, 1.0 + magnitude)
    elif op_name == "Contrast":
        img = F.adjust_contrast(img, 1.0 + magnitude)
    elif op_name == "Sharpness":
        img = F.adjust_sharpness(img, 1.0 + magnitude)
    elif op_name == "Posterize":
        img = F.posterize(img, int(magnitude))
    elif op_name == "Solarize":
        img = F.solarize(img, magnitude)
    elif op_name == "AutoContrast":
        img = F.autocontrast(img)
    elif op_name == "Equalize":
        img = F.equalize(img)
    elif op_name == "Invert":
        img = F.invert(img)
    elif op_name == "Identity":
        # Nothing to do: the image is returned as received.
        pass
    else:
        raise ValueError(f"The provided operator {op_name} is not recognized.")
    return img
class AutoAugmentPolicy(Enum):
    """Names of the datasets an AutoAugment policy was learned on.

    The members are IMAGENET, CIFAR10 and SVHN; each value is the lower-case
    dataset name.
    """

    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
class AutoAugment(torch.nn.Module):
    r"""Apply one randomly selected AutoAugment sub-policy to an image.

    Implements the augmentation scheme of
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    Tensor inputs are expected to be of type torch.uint8 with shape [..., 1 or 3, H, W]
    (any number of leading dimensions). PIL inputs are expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Which pre-defined policy set to draw sub-policies from, see
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`.
            Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Interpolation forwarded to the applied operations, see
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. A single number is used for all bands.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Resolve the sub-policy table once, at construction time.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        """Return the sub-policy table for ``policy``.

        Every sub-policy is a pair of ``(op_name, probability, magnitude_bin)`` triples;
        ``magnitude_bin`` is ``None`` for operations that take no magnitude.

        Raises:
            ValueError: If ``policy`` is not one of the known policies.
        """
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]

        if policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]

        if policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]

        raise ValueError(f"The provided policy {policy} is not recognized.")

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the per-operation magnitude table.

        Args:
            num_bins: Number of magnitude bins per operation.
            image_size: ``(height, width)`` of the image; scales the translation ranges.

        Returns:
            Mapping ``op_name -> (magnitudes, signed)``. ``signed`` marks operations whose
            magnitude may be negated; operations without a magnitude carry a 0-dim tensor.
        """
        space = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            # TranslateX scales with the width (index 1), TranslateY with the height (index 0).
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }
        return space

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Draw the random parameters for one AutoAugment application.

        Args:
            transform_num: Number of available sub-policies.

        Returns:
            ``(policy_id, probs, signs)``: the index of the chosen sub-policy, two uniform
            samples in ``[0, 1)`` and two random integers in ``{0, 1}`` (one per operation).
        """
        chosen = torch.randint(transform_num, (1,))
        policy_id = int(chosen.item())
        probs = torch.rand((2,))
        signs = torch.randint(2, (2,))
        return policy_id, probs, signs

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        fill_value = self.fill
        num_channels, img_h, img_w = F.get_dimensions(img)
        # For tensor inputs the fill is normalised to a list of floats.
        if isinstance(img, Tensor):
            if isinstance(fill_value, (int, float)):
                fill_value = [float(fill_value)] * num_channels
            elif fill_value is not None:
                fill_value = [float(v) for v in fill_value]

        policy_index, draws, directions = self.get_params(len(self.policies))
        space = self._augmentation_space(10, (img_h, img_w))
        sub_policy = self.policies[policy_index]

        for step, (op_name, prob, bin_index) in enumerate(sub_policy):
            # Each operation of the sub-policy fires independently with its own probability.
            if draws[step] <= prob:
                bins, is_signed = space[op_name]
                if bin_index is not None:
                    strength = float(bins[bin_index].item())
                else:
                    strength = 0.0
                # A sign draw of 0 flips the direction of signed operations.
                if is_signed and directions[step] == 0:
                    strength *= -1.0
                img = _apply_op(img, op_name, strength, interpolation=self.interpolation, fill=fill_value)

        return img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(policy={self.policy}, fill={self.fill})"
class RandAugment(torch.nn.Module):
    r"""Apply a fixed number of randomly chosen operations at a fixed magnitude.

    Implements the augmentation scheme of
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    Tensor inputs are expected to be of type torch.uint8 with shape [..., 1 or 3, H, W]
    (any number of leading dimensions). PIL inputs are expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude bin used for all the transformations.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Interpolation forwarded to the applied operations, see
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. A single number is used for all bands.
    """

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the per-operation magnitude table.

        Args:
            num_bins: Number of magnitude bins per operation.
            image_size: ``(height, width)`` of the image; scales the translation ranges.

        Returns:
            Mapping ``op_name -> (magnitudes, signed)``. ``signed`` marks operations whose
            magnitude may be negated; operations without a magnitude carry a 0-dim tensor.
        """
        space = {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            # TranslateX scales with the width (index 1), TranslateY with the height (index 0).
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        return space

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill_value = self.fill
        num_channels, img_h, img_w = F.get_dimensions(img)
        # For tensor inputs the fill is normalised to a list of floats.
        if isinstance(img, Tensor):
            if isinstance(fill_value, (int, float)):
                fill_value = [float(fill_value)] * num_channels
            elif fill_value is not None:
                fill_value = [float(v) for v in fill_value]

        space = self._augmentation_space(self.num_magnitude_bins, (img_h, img_w))
        op_names = list(space.keys())

        for _ in range(self.num_ops):
            # Pick the operation uniformly at random.
            picked = int(torch.randint(len(space), (1,)).item())
            op_name = op_names[picked]
            bins, is_signed = space[op_name]
            # 0-dim entries belong to operations that take no magnitude.
            if bins.ndim > 0:
                strength = float(bins[self.magnitude].item())
            else:
                strength = 0.0
            # The sign is only sampled for signed operations (short-circuit).
            if is_signed and torch.randint(2, (1,)):
                strength *= -1.0
            img = _apply_op(img, op_name, strength, interpolation=self.interpolation, fill=fill_value)

        return img

    def __repr__(self) -> str:
        fields = [
            f"num_ops={self.num_ops}",
            f"magnitude={self.magnitude}",
            f"num_magnitude_bins={self.num_magnitude_bins}",
            f"interpolation={self.interpolation}",
            f"fill={self.fill}",
        ]
        return f"{self.__class__.__name__}({', '.join(fields)})"
class TrivialAugmentWide(torch.nn.Module):
    r"""Apply a single random operation at a random magnitude (TrivialAugment Wide).

    Dataset-independent data augmentation as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
    Tensor inputs are expected to be of type torch.uint8 with shape [..., 1 or 3, H, W]
    (any number of leading dimensions). PIL inputs are expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Interpolation forwarded to the applied operation, see
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. A single number is used for all bands.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the per-operation magnitude table.

        Args:
            num_bins: Number of magnitude bins per operation.

        Returns:
            Mapping ``op_name -> (magnitudes, signed)``. ``signed`` marks operations whose
            magnitude may be negated; operations without a magnitude carry a 0-dim tensor.
        """
        space = {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        return space

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill_value = self.fill
        num_channels, img_h, img_w = F.get_dimensions(img)
        # For tensor inputs the fill is normalised to a list of floats.
        if isinstance(img, Tensor):
            if isinstance(fill_value, (int, float)):
                fill_value = [float(fill_value)] * num_channels
            elif fill_value is not None:
                fill_value = [float(v) for v in fill_value]

        space = self._augmentation_space(self.num_magnitude_bins)
        # Pick the operation uniformly at random.
        picked = int(torch.randint(len(space), (1,)).item())
        op_name = list(space.keys())[picked]
        bins, is_signed = space[op_name]

        # 0-dim entries belong to operations that take no magnitude; no bin is sampled for them.
        if bins.ndim > 0:
            bin_index = torch.randint(len(bins), (1,), dtype=torch.long)
            strength = float(bins[bin_index].item())
        else:
            strength = 0.0

        # The sign is only sampled for signed operations (short-circuit).
        if is_signed and torch.randint(2, (1,)):
            strength *= -1.0

        return _apply_op(img, op_name, strength, interpolation=self.interpolation, fill=fill_value)

    def __repr__(self) -> str:
        fields = [
            f"num_magnitude_bins={self.num_magnitude_bins}",
            f"interpolation={self.interpolation}",
            f"fill={self.fill}",
        ]
        return f"{self.__class__.__name__}({', '.join(fields)})"
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Must lie in ``[1, 10]``,
            otherwise a ``ValueError`` is raised. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. Any value that is not positive denotes
            stochastic depth sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Upper bound for ``severity``; also the number of magnitude bins requested in ``forward``.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the per-operation magnitude table.

        Args:
            num_bins: Number of magnitude bins per operation.
            image_size: ``(height, width)`` of the image; scales the translation ranges.

        Returns:
            Mapping ``op_name -> (magnitudes, signed)``. ``signed`` marks operations whose
            magnitude may be negated; operations without a magnitude carry a 0-dim tensor.
            Brightness, Color, Contrast and Sharpness are only included when ``self.all_ops`` is set.
        """
        s = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            # Translations range up to a third of the width (index 1) / height (index 0).
            "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        if self.all_ops:
            s.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return s

    # Thin wrapper around F.pil_to_tensor, excluded from TorchScript compilation.
    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        return F.pil_to_tensor(img)

    # Thin wrapper around F.to_pil_image, excluded from TorchScript compilation.
    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(orig_img)
        if isinstance(orig_img, Tensor):
            img = orig_img
            # For tensor inputs the fill is normalised to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]
        else:
            # Non-tensor (PIL) inputs are processed as tensors and converted back at the end.
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))

        orig_dims = list(img.shape)
        # Prepend singleton dimensions so that the working tensor has at least 4 dimensions.
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        # Shape [B, 1, 1, ...] used to broadcast per-sample weights against ``batch``.
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])

        # Start from the weighted original; each augmentation chain is accumulated on top of it.
        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # A positive chain_depth is used as-is; otherwise the depth is drawn from {1, 2, 3}.
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                op_index = int(torch.randint(len(op_meta), (1,)).item())
                op_name = list(op_meta.keys())[op_index]
                magnitudes, signed = op_meta[op_name]
                # The magnitude bin is drawn from the first ``severity`` bins; 0-dim entries belong to
                # operations that take no magnitude.
                magnitude = (
                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
                    if magnitudes.ndim > 0
                    else 0.0
                )
                # The sign is only sampled for signed operations (short-circuit).
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            # Accumulate this chain in place, weighted by its mixing coefficient.
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        # Restore the input's shape and dtype.
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"severity={self.severity}"
            f", mixture_width={self.mixture_width}"
            f", chain_depth={self.chain_depth}"
            f", alpha={self.alpha}"
            f", all_ops={self.all_ops}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            f")"
        )
        return s
import PIL.Image
import numpy as np
import torch
from pycocotools import mask as coco_mask
class ConvertCocoPolysToMask(object):
    """Convert a COCO-style annotation list into a dictionary of tensors.

    Called with an ``(image, target)`` tuple where ``target`` holds ``"image_id"`` and
    ``"annotations"``. Annotations whose ``"iscrowd"`` is non-zero are dropped, boxes are
    converted from ``[x, y, w, h]`` to ``[x1, y1, x2, y2]`` and clamped to the image, and
    boxes without positive width and height are removed together with their labels and
    optional masks / keypoints / scores.

    Args:
        return_masks (bool): If True, segmentations are rasterised via
            ``convert_coco_poly_to_mask`` and returned under ``"masks"``.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    @staticmethod
    def _spatial_size(image):
        """Return ``(height, width)`` of a tensor / ndarray (read as C, H, W) or PIL image."""
        if isinstance(image, (torch.Tensor, np.ndarray)):
            assert len(image.shape) == 3, "only one image is accepted"
            assert image.shape[-3] in [1, 3], "channels of images must be 1 or 3"
            _, height, width = image.shape
            return height, width
        if isinstance(image, PIL.Image.Image):
            width, height = image.size
            return height, width
        raise TypeError(
            f"Now only torch.Tensor, PIL.Image.Image and np.ndarray "
            f"of an image is accepted but got type {type(image)}"
        )

    def __call__(self, image_target_tuple):
        image, target = image_target_tuple
        height, width = self._spatial_size(image)

        # Crowd annotations are ignored entirely.
        records = [
            obj for obj in target["annotations"] if "iscrowd" not in obj or obj["iscrowd"] == 0
        ]

        # The reshape keeps the (N, 4) layout even when there are no boxes at all.
        boxes = torch.as_tensor([obj["bbox"] for obj in records], dtype=torch.float32).reshape(-1, 4)
        # xywh -> xyxy, then clamp x to [0, width] and y to [0, height].
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=width)
        boxes[:, 1::2].clamp_(min=0, max=height)

        labels = torch.as_tensor([obj["category_id"] for obj in records], dtype=torch.int64)

        masks = None
        if self.return_masks:
            masks = convert_coco_poly_to_mask([obj["segmentation"] for obj in records], height, width)

        # Keypoints and scores are only collected when the first annotation carries them.
        keypoints = None
        if records and "keypoints" in records[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in records], dtype=torch.float32)
            count = keypoints.shape[0]
            if count:
                keypoints = keypoints.view(count, -1, 3)

        # adapt to result_file
        scores = None
        if records and "score" in records[0]:
            scores = torch.as_tensor([obj["score"] for obj in records], dtype=torch.float32)

        # Keep only boxes with strictly positive height and width.
        valid = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[valid]
        labels = labels[valid]
        optional = {
            key: value[valid]
            for key, value in (("masks", masks), ("keypoints", keypoints), ("scores", scores))
            if value is not None
        }

        converted = {"boxes": boxes, "labels": labels, "image_id": target["image_id"]}
        converted.update(optional)

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in records], dtype=torch.float32)
        crowd_flags = torch.tensor([obj.get("iscrowd", 0) for obj in records], dtype=torch.long)
        converted["area"] = area[valid]
        converted["iscrowd"] = crowd_flags[valid]

        return image, converted
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterise per-object segmentations into a stacked mask tensor.

    Each entry of ``segmentations`` is encoded with ``coco_mask.frPyObjects`` and decoded
    with ``coco_mask.decode``; the last axis of the decoded array is collapsed with ``any``
    so that every object yields a single 2-D mask.

    Args:
        segmentations: Iterable with one segmentation per object.
        height: Image height passed to the encoder.
        width: Image width passed to the encoder.

    Returns:
        Tensor: The masks stacked along dim 0, or an empty ``(0, height, width)`` uint8
        tensor when there are no segmentations.
    """

    def _rasterize(polygons):
        encoded = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(encoded)
        # Make sure there is a trailing axis to reduce over.
        if len(decoded.shape) < 3:
            decoded = decoded[..., None]
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_object = [_rasterize(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
import random
from typing import Any, Dict, List
import torch
from torchvision.ops import box_iou
from transforms.v2 import Transform
from transforms.v2 import functional as F
from transforms.v2.utils import query_bounding_box, query_spatial_size
from util import datapoints
class RandomSizeCrop(Transform):
    """Crop a randomly sized, randomly placed window out of the inputs.

    The crop height and width are drawn independently from
    ``[min_size, min(image side, max_size)]`` and the window is placed uniformly at random
    inside the image. If an image side is shorter than ``min_size`` the crop simply spans
    that whole side.

    Args:
        min_size (int): Smallest crop side length.
        max_size (int): Largest crop side length.
    """

    def __init__(self, min_size: int, max_size: int):
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the crop window (``left``, ``top``, ``height``, ``width``)."""
        orig_h, orig_w = query_spatial_size(flat_inputs)

        # random.randint raises ValueError on an empty range, which used to happen whenever an
        # image side was shorter than min_size; clamp the lower bound to the image side instead.
        crop_h = random.randint(min(self.min_size, orig_h), min(orig_h, self.max_size))
        crop_w = random.randint(min(self.min_size, orig_w), min(orig_w, self.max_size))

        # get crop region
        top = torch.randint(0, orig_h - crop_h + 1, size=(1,)).item()
        left = torch.randint(0, orig_w - crop_w + 1, size=(1,)).item()
        return {"left": left, "top": top, "height": crop_h, "width": crop_w}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.crop(inpt, **params)
class BoxCenteredRandomSizeCrop(Transform):
    """Random-size crop that tries to keep at least one box centre inside the window.

    Up to ``trials`` candidate windows are sampled. A candidate is accepted immediately when
    the best IoU between the window and the boxes whose centres fall inside it reaches a
    randomly chosen threshold from ``sampler_options``. If no candidate is accepted, the
    candidate with the highest IoU seen so far is used; if no candidate contained any box
    centre, the inputs are returned unchanged.

    Args:
        min_size (int): Smallest crop side length.
        max_size (int): Largest crop side length.
        sampler_options (sequence of float, optional): Candidate minimum IoU thresholds.
            Defaults to ``[0.1, 0.3, 0.5, 0.7, 0.9, 1.0]``.
        trials (int): Maximum number of candidate windows to sample.
    """

    def __init__(self, min_size: int, max_size: int, sampler_options=None, trials: int = 40):
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size
        self.trials = trials
        if sampler_options is None:
            sampler_options = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample a crop window; returns an empty dict when no suitable window was found."""
        orig_h, orig_w = query_spatial_size(flat_inputs)
        bboxes = query_bounding_box(flat_inputs)

        # The boxes and their centres do not depend on the candidate window: compute them once.
        all_xyxy = F.convert_format_bounding_box(
            bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
        )
        cx = 0.5 * (all_xyxy[..., 0] + all_xyxy[..., 2])
        cy = 0.5 * (all_xyxy[..., 1] + all_xyxy[..., 3])

        # random.randint raises ValueError on an empty range; clamp the lower bound so that
        # images with a side shorter than min_size are handled.
        min_h, max_h = min(self.min_size, orig_h), min(orig_h, self.max_size)
        min_w, max_w = min(self.min_size, orig_w), min(orig_w, self.max_size)

        best_iou = 0.0
        # Stays empty when no candidate contains a box centre; _transform then is a no-op.
        best_region: Dict[str, Any] = {}
        for _ in range(self.trials):
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[idx]

            crop_h = random.randint(min_h, max_h)
            crop_w = random.randint(min_w, max_w)

            # get crop region
            top = torch.randint(0, orig_h - crop_h + 1, size=(1,)).item()
            left = torch.randint(0, orig_w - crop_w + 1, size=(1,)).item()
            right = left + crop_w
            bottom = top + crop_h

            # check for any valid boxes with centers within the crop area
            is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
            if not is_within_crop_area.any():
                continue

            ious = box_iou(
                all_xyxy[is_within_crop_area],
                all_xyxy.new_tensor([[left, top, right, bottom]]),
            )
            max_iou = float(ious.max())
            cur_region = dict(
                top=top,
                left=left,
                height=crop_h,
                width=crop_w,
                is_within_crop_area=is_within_crop_area,
            )

            # Accept as soon as the sampled overlap threshold is met.
            if max_iou >= min_jaccard_overlap:
                return cur_region

            # Otherwise remember the best candidate so far as a fallback.
            if max_iou > best_iou:
                best_iou = max_iou
                best_region = cur_region

        return best_region

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Crop ``inpt`` to the sampled window; pass it through when no window was found."""
        if len(params) < 1:
            return inpt

        output = F.crop(
            inpt,
            top=params["top"],
            left=params["left"],
            height=params["height"],
            width=params["width"],
        )

        if isinstance(output, datapoints.BoundingBox):
            # We "mark" the invalid boxes as degenerate, and they can be
            # removed by a later call to SanitizeBoundingBox()
            output[~params["is_within_crop_area"]] = 0

        return output
import math
import numbers
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from torch import Tensor
try:
import accimage
except ImportError:
accimage = None
from . import _functional_pil as F_pil, _functional_tensor as F_t
class InterpolationMode(Enum):
    """Enumeration of the supported interpolation modes.

    The members are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``,
    ``hamming`` and ``lanczos``.
    """

    NEAREST = "nearest"
    NEAREST_EXACT = "nearest-exact"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"
# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
    """Translate an integer code into the matching ``InterpolationMode``.

    Args:
        i (int): Integer code, one of 0-5.

    Returns:
        InterpolationMode: The mode registered for ``i``.

    Raises:
        KeyError: If ``i`` is not a known code.
    """
    code_to_mode = {
        0: InterpolationMode.NEAREST,
        1: InterpolationMode.LANCZOS,
        2: InterpolationMode.BILINEAR,
        3: InterpolationMode.BICUBIC,
        4: InterpolationMode.BOX,
        5: InterpolationMode.HAMMING,
    }
    return code_to_mode[i]
# Integer code associated with each InterpolationMode member. This is the inverse of the table in
# ``_interpolation_modes_from_int``, except that NEAREST_EXACT also maps to 0 (the code of NEAREST).
# NOTE(review): presumably these are PIL's resampling-filter constants -- confirm against Pillow.
pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.NEAREST_EXACT: 0,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}
# Module-level alias of the PIL backend's image type check.
_is_pil_image = F_pil._is_pil_image
def get_dimensions(img: Tensor) -> List[int]:
    """Return the dimensions of an image as ``[channels, height, width]``.

    Tensors are handled by the tensor backend, everything else by the PIL backend.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image dimensions.
    """
    if isinstance(img, torch.Tensor):
        dims = F_t.get_dimensions(img)
    else:
        dims = F_pil.get_dimensions(img)
    return dims
def get_image_size(img: Tensor) -> List[int]:
    """Return the size of an image as ``[width, height]``.

    Tensors are handled by the tensor backend, everything else by the PIL backend.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size.
    """
    if isinstance(img, torch.Tensor):
        size = F_t.get_image_size(img)
    else:
        size = F_pil.get_image_size(img)
    return size
def get_image_num_channels(img: Tensor) -> int:
    """Return the number of channels of an image.

    Tensors are handled by the tensor backend, everything else by the PIL backend.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        int: The number of channels.
    """
    if isinstance(img, torch.Tensor):
        num_channels = F_t.get_image_num_channels(img)
    else:
        num_channels = F_pil.get_image_num_channels(img)
    return num_channels
@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}
def to_tensor(pic) -> Tensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    The result is laid out channels-first. uint8 data is converted to the default float
    dtype and divided by 255; any other dtype is returned without rescaling.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor an ndarray.
        ValueError: If ``pic`` is an ndarray that is not 2- or 3-dimensional.
    """
    if not F_pil._is_pil_image(pic) and not _is_numpy(pic):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # Give 2-D arrays a trailing channel axis, then move channels first.
        hwc = pic[:, :, None] if pic.ndim == 2 else pic
        chw = torch.from_numpy(hwc.transpose((2, 0, 1))).contiguous()
        # backward compatibility: byte images are rescaled by 1/255
        if isinstance(chw, torch.ByteTensor):
            return chw.to(dtype=float_dtype).div(255)
        return chw

    if accimage is not None and isinstance(pic, accimage.Image):
        buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(buffer)
        return torch.from_numpy(buffer).to(dtype=float_dtype)

    # From here on ``pic`` is a PIL Image: pick the numpy dtype from its mode.
    np_dtype = {"I": np.int32, "I;16": np.int16, "F": np.float32}.get(pic.mode, np.uint8)
    hwc = torch.from_numpy(np.array(pic, np_dtype, copy=True))

    if pic.mode == "1":
        hwc = 255 * hwc
    hwc = hwc.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
    # HWC -> CHW
    chw = hwc.permute((2, 0, 1)).contiguous()
    if isinstance(chw, torch.ByteTensor):
        return chw.to(dtype=float_dtype).div(255)
    return chw
def pil_to_tensor(pic: Any) -> Tensor:
    """Convert a ``PIL Image`` to a tensor of the same type.

    The values are not rescaled; the result is laid out channels-first.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is not a PIL Image.
    """
    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(buffer)
        return torch.as_tensor(buffer)

    # Regular PIL image: copy the pixel data, view it as HWC and move channels first.
    pixels = torch.as_tensor(np.array(pic, copy=True))
    num_channels = F_pil.get_image_num_channels(pic)
    hwc = pixels.view(pic.size[1], pic.size[0], num_channels)
    return hwc.permute((2, 0, 1))
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    The conversion itself is delegated to the tensor backend.
    This function does not support PIL Image.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        TypeError: When ``image`` is not a :class:`torch.Tensor`.
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if isinstance(image, torch.Tensor):
        return F_t.convert_image_dtype(image, dtype)
    raise TypeError("Input img should be Tensor Image")
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. Tensors are read as
            (C, H, W) or (H, W); ndarrays are read as (H, W, C) or (H, W).
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            When omitted, the mode is inferred from the channel count and the dtype.

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: If ``pic`` is neither a tensor nor an ndarray, or if no mode can be inferred
            for its dtype.
        ValueError: If ``pic`` is not 2- or 3-dimensional, has more than 4 channels, or ``mode``
            is not permitted for the number of channels.
    """
    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

        # check number of channels
        if pic.shape[-3] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

        # check number of channels
        if pic.shape[-1] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic
    if isinstance(pic, torch.Tensor):
        # Float tensors are scaled by 255 and cast to uint8 unless a float ("F") image is requested.
        if pic.is_floating_point() and mode != "F":
            pic = pic.mul(255).byte()
        # CHW -> HWC
        npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        # Fixed: the message was missing its f-prefix, so the placeholder was printed literally.
        raise TypeError(f"Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")

    if npimg.shape[2] == 1:
        # Single channel: the dtype alone determines the only valid mode.
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            # Fixed: report the array's dtype rather than the ``numpy.dtype`` class itself.
            raise ValueError(
                f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}"
            )
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"
    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    # No mode was supplied and none could be inferred from the dtype.
    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.

    This transform does not support PIL Image.

    .. note::

        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.

    Raises:
        TypeError: If ``tensor`` is not a :class:`torch.Tensor`.
    """
    # Only tensors are accepted; the computation itself is done by the tensor backend.
    if isinstance(tensor, torch.Tensor):
        return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
    raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")
def _compute_resized_output_size(
image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
if len(size) == 1: # specified size only for the smallest edge
h, w = image_size
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
return [new_h, new_w]
def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
            The ``"warn"`` sentinel used in the signature is resolved by
            ``_check_antialias`` before the backend is called.

    Returns:
        PIL Image or Tensor: Resized image. The input object itself is returned when it
        already has the requested size.

    Raises:
        TypeError: If ``interpolation`` is neither an ``InterpolationMode`` nor an int.
        ValueError: If ``size`` is a sequence whose length is not 1 or 2, or if ``max_size``
            is given together with a two-element ``size``.
    """
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(size, (list, tuple)):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, image_height, image_width = get_dimensions(img)
    if isinstance(size, int):
        size = [size]
    output_size = _compute_resized_output_size((image_height, image_width), size, max_size)

    # ``output_size`` is a list, so the current size must be compared as a list as well:
    # a tuple never compares equal to a list, which silently disabled this no-op shortcut.
    if [image_height, image_width] == output_size:
        return img

    antialias = _check_antialias(img, antialias, interpolation)

    if not isinstance(img, torch.Tensor):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
    else:
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.crop(img, top, left, height, width)
    else:
        return F_pil.crop(img, top, left, height, width)
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
            it is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    # Normalise ``output_size`` to a (height, width) pair.
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = output_size

    too_narrow = crop_width > image_width
    too_short = crop_height > image_height
    if too_narrow or too_short:
        # Amount missing along each axis; zero where the image is already large enough.
        missing_w = crop_width - image_width if too_narrow else 0
        missing_h = crop_height - image_height if too_short else 0
        # Left/top take the floor half, right/bottom the ceil half of the missing amount.
        padding_ltrb = [
            missing_w // 2,
            missing_h // 2,
            (missing_w + 1) // 2,
            (missing_h + 1) // 2,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        _, image_height, image_width = get_dimensions(img)
        if crop_width == image_width and crop_height == image_height:
            # Padding alone produced the requested size; nothing left to crop.
            return img

    top = int(round((image_height - crop_height) / 2.0))
    left = int(round((image_width - crop_width) / 2.0))
    return crop(img, top, left, crop_height, crop_width)
def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> Tensor:
    """Crop the given image and resize it to desired size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    # Two-step pipeline: cut out the box first, then scale the result.
    cropped = crop(img, top, left, height, width)
    return resize(cropped, size, interpolation, antialias=antialias)
def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor: Horizontally flipped image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.hflip(img)
    else:
        return F_pil.hflip(img)
def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
(x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
Args:
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
Returns:
octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
"""
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution
output: List[float] = res.tolist()
return output
def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Perform perspective transform of the given image.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: transformed Image.

    Raises:
        TypeError: If ``interpolation`` is neither an ``InterpolationMode`` nor an int.
    """
    # The coefficients are solved first, before the interpolation argument is validated.
    coeffs = _get_perspective_coeffs(startpoints, endpoints)

    if isinstance(interpolation, int):
        # Pillow integer constant: translate it to the enum.
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(img, torch.Tensor):
        return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
    else:
        return F_pil.perspective(img, coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor: Vertically flipped image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.vflip(img)
    else:
        return F_pil.vflip(img)
def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Returns:
       tuple: tuple (tl, tr, bl, br, center)
       Corresponding top left, top right, bottom left, bottom right and center crop.

    Raises:
        ValueError: If ``size`` does not resolve to two dimensions, or if the requested crop
            is larger than the image along either edge.
    """
    # Normalise ``size`` to a (height, width) pair.
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        raise ValueError(
            "Requested crop size {} is bigger than input size {}".format(size, (image_height, image_width))
        )

    # Offsets of the crops that touch the bottom / right borders.
    bottom_offset = image_height - crop_height
    right_offset = image_width - crop_width

    top_left = crop(img, 0, 0, crop_height, crop_width)
    top_right = crop(img, 0, right_offset, crop_height, crop_width)
    bottom_left = crop(img, bottom_offset, 0, crop_height, crop_width)
    bottom_right = crop(img, bottom_offset, right_offset, crop_height, crop_width)
    middle = center_crop(img, [crop_height, crop_width])

    return top_left, top_right, bottom_left, bottom_right, middle
def ten_crop(
    img: Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Generate ten cropped images from the given image.
    Crop the given image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        Corresponding top left, top right, bottom left, bottom right and
        center crop and same for the flipped image.

    Raises:
        ValueError: If ``size`` does not resolve to two dimensions.
    """
    # Normalise ``size`` to a (height, width) pair.
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    # Five crops of the image as given ...
    plain_crops = five_crop(img, size)

    # ... followed by the same five crops taken from the flipped image.
    if vertical_flip:
        flipped = vflip(img)
    else:
        flipped = hflip(img)

    return plain_crops + five_crop(flipped, size)
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_brightness(img, brightness_factor)
    else:
        return F_pil.adjust_brightness(img, brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_contrast(img, contrast_factor)
    else:
        return F_pil.adjust_contrast(img, contrast_factor)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_saturation(img, saturation_factor)
    else:
        return F_pil.adjust_saturation(img, saturation_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
            Note: the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_hue(img, hue_factor)
    else:
        return F_pil.adjust_hue(img, hue_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): PIL Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    # Dispatch on the input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_gamma(img, gamma, gain)
    else:
        return F_pil.adjust_gamma(img, gamma, gain)
def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
# Pillow requires inverse affine transformation matrix:
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
#
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RotateScaleShear is rotation with scale and shear matrix
#
# RotateScaleShear(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
# [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
rot = math.radians(angle)
sx = math.radians(shear[0])
sy = math.radians(shear[1])
cx, cy = center
tx, ty = translate
# RSS without scaling
a = math.cos(rot - sy) / math.cos(sy)
b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
c = math.sin(rot - sy) / math.cos(sy)
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
if inverted:
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
matrix = [d, -b, 0.0, -c, a, 0.0]
matrix = [x / scale for x in matrix]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += cx
matrix[5] += cy
else:
matrix = [a, b, 0.0, c, d, 0.0]
matrix = [x * scale for x in matrix]
# Apply inverse of center translation: RSS * C^-1
matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
# Apply translation and center : T * C * RSS * C^-1
matrix[2] += cx + tx
matrix[5] += cy + ty
return matrix
def rotate(
    img: Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[int]] = None,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Rotate an image by ``angle`` degrees.

    Tensor inputs are expected to have ``[..., H, W]`` shape, where ``...`` stands for
    an arbitrary number of leading dimensions. Non-tensor inputs are forwarded to the
    PIL backend.

    Args:
        img (PIL Image or Tensor): image to rotate.
        angle (number): rotation angle in degrees, counter-clockwise.
        interpolation (InterpolationMode): interpolation enum from
            :class:`torchvision.transforms.InterpolationMode`. Defaults to
            ``InterpolationMode.NEAREST``. For Tensor inputs only ``NEAREST`` and
            ``BILINEAR`` are supported. The matching Pillow integer constants
            (e.g. ``PIL.Image.BILINEAR``) are accepted too.
        expand (bool, optional): if true, the output is enlarged so that it holds the
            whole rotated image; otherwise the output keeps the input size. The flag
            assumes rotation around the center and no translation.
        center (sequence, optional): center of rotation, with the origin in the upper
            left corner. Defaults to the image center.
        fill (sequence or number, optional): fill value for the area outside the
            rotated image. A single number is used for all bands.

            .. note::
                In torchscript mode a single int/float value is not supported, please
                use a sequence of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: Rotated image.

    Raises:
        TypeError: if ``interpolation``, ``angle`` or ``center`` has an unsupported type.
    """
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not (center is None or isinstance(center, (list, tuple))):
        raise TypeError("Argument center should be a sequence")

    if not isinstance(img, torch.Tensor):
        # PIL backend: the center is passed through untouched.
        return F_pil.rotate(
            img,
            angle=angle,
            interpolation=pil_modes_mapping[interpolation],
            expand=expand,
            center=center,
            fill=fill,
        )

    if center is None:
        center_f = [0.0, 0.0]
    else:
        _, height, width = get_dimensions(img)
        # Pixel coordinates are shifted so that (0, 0) corresponds to the image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    # The affine and rotate implementations disagree on the angle direction, hence -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
def affine(
    img: Tensor,
    angle: float,
    translate: List[int],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    center: Optional[List[int]] = None,
) -> Tensor:
    """Apply an affine transformation to the image, keeping the image center invariant.

    Tensor inputs are expected to have ``[..., H, W]`` shape, where ``...`` stands for
    an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to transform.
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations
            (post-rotation translation).
        scale (float): overall scale; must be positive.
        shear (float or sequence): shear angle in degrees between -180 and 180, clockwise
            direction. For a sequence, the first value is a shear parallel to the x-axis
            and the second a shear parallel to the y-axis. A single number ``s`` is
            treated as ``[s, 0.0]``; a sequence of length 1 is duplicated.
        interpolation (InterpolationMode): interpolation enum from
            :class:`torchvision.transforms.InterpolationMode`. Defaults to
            ``InterpolationMode.NEAREST``. For Tensor inputs only ``NEAREST`` and
            ``BILINEAR`` are supported. The matching Pillow integer constants
            (e.g. ``PIL.Image.BILINEAR``) are accepted too.
        fill (sequence or number, optional): fill value for the area outside the
            transformed image. A single number is used for all bands.

            .. note::
                In torchscript mode a single int/float value is not supported, please
                use a sequence of length 1: ``[value, ]``.
        center (sequence, optional): center of rotation, with the origin in the upper
            left corner. Defaults to the image center.

    Returns:
        PIL Image or Tensor: Transformed image.

    Raises:
        TypeError: if ``interpolation``, ``angle``, ``translate``, ``shear`` or ``center``
            has an unsupported type.
        ValueError: if ``translate`` or ``shear`` has the wrong length, or ``scale`` is
            not positive.
    """
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    # Normalise the argument types before building the matrix.
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    # Queried once and shared by both backends below (the original code repeated this
    # call a second time in the tensor branch with identical results).
    _, height, width = get_dimensions(img)

    if not isinstance(img, torch.Tensor):
        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [width * 0.5, height * 0.5]
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0)
        # corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
# Looks like to_grayscale() is a stand-alone functional that is never called
# from the transform classes. Perhaps it's still here for BC? I can't be
# bothered to dig. Anyway, this can be deprecated as we migrate to V2.
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert a PIL image of any mode (RGB, HSV, LAB, etc) to its grayscale version.

    torch Tensors are not supported by this function.

    Args:
        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image.
            Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not isinstance(img, Image.Image):
        raise TypeError("Input should be PIL Image")
    return F_pil.to_grayscale(img, num_output_channels)
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert an RGB image to its grayscale version.

    Tensor inputs are expected to have ``[..., 3, H, W]`` shape, where ``...`` stands
    for an arbitrary number of leading dimensions.

    Note:
        Only RGB images are supported as input. For inputs in other color spaces,
        consider :meth:`~torchvision.transforms.functional.to_grayscale` with a PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image.
            Value can be 1 or 3. Default, 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    # Tensors go to the tensor backend, everything else to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.rgb_to_grayscale(img, num_output_channels)
    return F_pil.to_grayscale(img, num_output_channels)
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase a region of the input Tensor image with the given value.

    PIL Images are not supported by this function.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased.
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace (bool, optional): whether to operate in-place. Defaults to False.

    Returns:
        Tensor Image: Erased image.

    Raises:
        TypeError: if ``img`` is not a Tensor.
    """
    if isinstance(img, torch.Tensor):
        return F_t.erase(img, i, j, h, w, v, inplace=inplace)
    raise TypeError(f"img should be Tensor Image. Got {type(img)}")
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Blur the image with a Gaussian kernel.

    Tensor inputs are expected to have ``[..., H, W]`` shape, where ``...`` stands for
    an arbitrary number of leading dimensions. PIL inputs are converted to a tensor,
    blurred, and converted back to a PIL Image with the original mode.

    Args:
        img (PIL Image or Tensor): Image to be blurred.
        kernel_size (sequence of ints or int): Gaussian kernel size, either a sequence
            like ``(kx, ky)`` or a single integer for square kernels. Each value must
            be odd and positive.

            .. note::
                In torchscript mode kernel_size as single int is not supported, use a
                sequence of length 1: ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard
            deviation, either a sequence like ``(sigma_x, sigma_y)`` or a single float
            used for both directions. If None, it is derived from ``kernel_size`` as
            ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``. Default, None.

            .. note::
                In torchscript mode sigma as single float is not supported, use a
                sequence of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.

    Raises:
        TypeError: if ``kernel_size``, ``sigma`` or ``img`` has an unsupported type.
        ValueError: if ``kernel_size`` or ``sigma`` has the wrong length or invalid values.
    """
    # --- kernel size ---
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for size in kernel_size:
        if size % 2 == 0 or size < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    # --- sigma ---
    if sigma is None:
        # Same value as 0.3 * ((k - 1) * 0.5 - 1) + 0.8, written in expanded form.
        sigma = [size * 0.15 + 0.35 for size in kernel_size]
    elif not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    elif isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    elif len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for value in sigma:
        if value <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    # --- dispatch ---
    if isinstance(img, torch.Tensor):
        return F_t.gaussian_blur(img, kernel_size, sigma)

    if not F_pil._is_pil_image(img):
        raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

    # PIL path: round-trip through the tensor implementation, keeping the image mode.
    blurred = F_t.gaussian_blur(pil_to_tensor(img), kernel_size, sigma)
    return to_pil_image(blurred, mode=img.mode)
def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

    Args:
        img (PIL Image or Tensor): Image whose colors are inverted.
            A torch Tensor is expected to be in [..., 1 or 3, H, W] format,
            where ... means an arbitrary number of leading dimensions.
            A PIL Image is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: Color inverted image.
    """
    # Tensors go to the tensor backend, everything else to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.invert(img)
    return F_pil.invert(img)
def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize an image by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image whose colors are posterized.
            A torch Tensor should be of type torch.uint8 and is expected to be in
            [..., 1 or 3, H, W] format, where ... means an arbitrary number of
            leading dimensions.
            A PIL Image is expected to be in mode "L" or "RGB".
        bits (int): The number of bits to keep for each channel (0-8).

    Returns:
        PIL Image or Tensor: Posterized image.

    Raises:
        ValueError: if ``bits`` is outside the inclusive range 0..8.
    """
    if not (0 <= bits <= 8):
        # The message previously read "number if bits"; the typo is corrected here.
        raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")

    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.

    Args:
        img (PIL Image or Tensor): Image whose colors are inverted.
            A torch Tensor is expected to be in [..., 1 or 3, H, W] format,
            where ... means an arbitrary number of leading dimensions.
            A PIL Image is expected to be in mode "L" or "RGB".
        threshold (float): All pixels equal or above this value are inverted.

    Returns:
        PIL Image or Tensor: Solarized image.
    """
    # Tensors go to the tensor backend, everything else to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.solarize(img, threshold)
    return F_pil.solarize(img, threshold)
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            A torch Tensor is expected to be in [..., 1 or 3, H, W] format,
            where ... means an arbitrary number of leading dimensions.
        sharpness_factor (float): How much to adjust the sharpness. Can be any
            non-negative number. 0 gives a blurred image, 1 gives the original
            image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
    """
    # Tensors go to the tensor backend, everything else to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_sharpness(img, sharpness_factor)
    return F_pil.adjust_sharpness(img, sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
    """Maximize the contrast of an image.

    Pixels are remapped per channel so that the lowest value becomes black and the
    lightest becomes white.

    Args:
        img (PIL Image or Tensor): Image on which autocontrast is applied.
            A torch Tensor is expected to be in [..., 1 or 3, H, W] format,
            where ... means an arbitrary number of leading dimensions.
            A PIL Image is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was autocontrasted.
    """
    # Tensors go to the tensor backend, everything else to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.autocontrast(img)
    return F_pil.autocontrast(img)
def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of an image.

    A non-linear mapping is applied to the input in order to create a uniform
    distribution of grayscale values in the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            A torch Tensor is expected to be in [..., 1 or 3, H, W] format,
            where ... means an arbitrary number of leading dimensions.
            The tensor dtype must be ``torch.uint8`` and values are expected to be
            in ``[0, 255]``.
            A PIL Image is expected to be in mode "P", "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was equalized.
    """
    # Tensors go to the tensor backend, everything else to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.equalize(img)
    return F_pil.equalize(img)
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Warp an image with a precomputed elastic displacement field.

    The displacement field is handed to the tensor backend together with the image.
    PIL inputs are converted to a tensor first and converted back to a PIL Image with
    the original mode afterwards.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        img (PIL Image or Tensor): Image on which elastic_transform is applied.
            A torch Tensor is expected to be in [..., 1 or 3, H, W] format,
            where ... means an arbitrary number of leading dimensions.
            A PIL Image is expected to be in mode "P", "L" or "RGB".
        displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
        interpolation (InterpolationMode): interpolation enum from
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``.
            The matching Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are
            accepted as well but trigger a warning.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL Image or Tensor: The transformed image, of the same kind as ``img``.

    Raises:
        TypeError: if ``displacement`` is not a Tensor, or ``img`` is neither a
            PIL Image nor a Tensor.
        ValueError: if ``displacement`` does not have shape ``(1, H, W, 2)``.
    """
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    is_tensor_input = isinstance(img, torch.Tensor)
    if is_tensor_input:
        t_img = img
    else:
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
        t_img = pil_to_tensor(img)

    # The field must cover exactly the spatial extent of the image.
    shape = (1,) + t_img.shape[-2:] + (2,)
    if shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")

    # TODO: if image shape is [N1, N2, ..., C, H, W] and
    # displacement is [1, H, W, 2] we need to reshape input image
    # such grid_sampler takes internal code for 4D input
    output = F_t.elastic_transform(t_img, displacement, interpolation=interpolation.value, fill=fill)

    if is_tensor_input:
        return output
    return to_pil_image(output, mode=img.mode)
# TODO in v0.17: remove this helper and change default of antialias to True everywhere
def _check_antialias(
    img: Tensor, antialias: Optional[Union[str, bool]], interpolation: InterpolationMode
) -> Optional[bool]:
    """Resolve the ``antialias`` argument, warning about the upcoming default change.

    Non-string values are returned unchanged. A string value (the "warn" sentinel;
    its content is not checked) is mapped to ``None``, and a warning is emitted first
    when ``img`` is a Tensor and the interpolation is bilinear or bicubic.

    Args:
        img: the image about to be resized.
        antialias: user supplied value, or a string sentinel.
        interpolation: interpolation mode of the resize.

    Returns:
        ``antialias`` unchanged when it is not a string, otherwise ``None``.
    """
    if not isinstance(antialias, str):
        return antialias

    # it should be "warn", but we don't bother checking against that
    if isinstance(img, Tensor) and interpolation in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC):
        warnings.warn(
            "The default value of the antialias parameter of all the resizing transforms "
            "(Resize(), RandomResizedCrop(), etc.) "
            "will change from None to True in v0.17, "
            "in order to be consistent across the PIL and Tensor backends. "
            "To suppress this warning, directly pass "
            "antialias=True (recommended, future default), antialias=None (current default, "
            "which means False for Tensors and True for PIL), "
            "or antialias=False (only works on Tensors - PIL will still use antialiasing). "
            "This also applies if you are using the inference transforms from the models weights: "
            "update the call to weights.transforms(antialias=True)."
        )
    return None
# Deprecated compatibility shim: re-exports every public name of
# ``transforms._functional_pil`` through a star import and emits a deprecation
# warning when this code is executed.
import warnings
from transforms._functional_pil import * # noqa
# No category is passed, so ``warnings.warn`` falls back to its default (UserWarning).
warnings.warn(
    "The torchvision.transforms.functional_pil module is deprecated "
    "in 0.15 and will be **removed in 0.17**. Please don't rely on it. "
    "You probably just need to use APIs in "
    "torchvision.transforms.functional or in "
    "torchvision.transforms.v2.functional."
)
# Deprecated compatibility shim: re-exports every public name of
# ``transforms._functional_tensor`` through a star import and emits a deprecation
# warning when this code is executed.
import warnings
from transforms._functional_tensor import * # noqa
# No category is passed, so ``warnings.warn`` falls back to its default (UserWarning).
warnings.warn(
    "The torchvision.transforms.functional_tensor module is deprecated "
    "in 0.15 and will be **removed in 0.17**. Please don't rely on it. "
    "You probably just need to use APIs in "
    "torchvision.transforms.functional or in "
    "torchvision.transforms.v2.functional."
)
import copy
import random
from typing import Any, List, Tuple
import albumentations as A
import numpy as np
import torch
from torch import Tensor, nn
from util import datapoints
from transforms import v2 as T
from util.misc import image_list_from_tensors
class BaseMixTransform(nn.Module):
    """Base class for transforms that combine the current sample with extra samples.

    The transform keeps a reference to a dataset (set through :meth:`update_dataset`)
    and lazily derives two pipelines from ``dataset._transforms``:

    * ``original_transform``: a deep copy of the dataset's transform pipeline;
    * ``pre_transform``: that copy truncated just before the first transform of this
      instance's own type, i.e. what runs *before* the mix step.

    Args:
        p (float): probability threshold stored on the instance; subclasses compare a
            uniform random draw against it. Default: 0.5.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.dataset = None
        # ``None`` means "not computed yet" for both cached pipelines.
        self._original_transform = None
        self._pre_transform = None
        self.p = p

    def update_dataset(self, dataset):
        """Attach the dataset whose ``_transforms`` pipeline is inspected later."""
        self.dataset = dataset

    @property
    def original_transform(self):
        """Deep copy of ``self.dataset._transforms``, created on first access."""
        # Compare against None explicitly: a pipeline object may be falsy (e.g. a
        # Compose exposing ``__len__`` with no transforms), and a truthiness test
        # would then create a fresh deep copy on every access.
        if self._original_transform is None:
            self._original_transform = copy.deepcopy(self.dataset._transforms)
        return self._original_transform

    @property
    def pre_transform(self):
        """Portion of ``original_transform`` that precedes the mix transform."""
        if self._pre_transform is None:
            self._pre_transform = self.remove_post_transforms(self.original_transform)
        return self._pre_transform

    def remove_post_transforms(self, transform):
        """Strip the mix transform and everything after it from ``transform``.

        Returns ``None`` when ``transform`` itself is an instance of this object's
        type. For ``T.Compose`` / ``A.Compose`` containers a new container of the same
        type is built from the leading transforms only; note that it is constructed
        from the transform list alone, so other constructor arguments of the original
        container are not carried over. Any other transform is returned unchanged.
        """
        if isinstance(transform, type(self)):
            return None

        if isinstance(transform, (T.Compose, A.Compose)):
            processed_transforms = []
            for trans in transform.transforms:
                trans = self.remove_post_transforms(trans)
                # A falsy result marks the point where the mix transform was reached.
                if not trans:
                    break
                processed_transforms.append(trans)
            return type(transform)(processed_transforms)

        return transform

    @staticmethod
    def get_images_boxes_labels_from_input(input: Any,):
        """Split a sample into ``(image_data, box_data, labels)``.

        ``input`` must contain exactly one ``datapoints.Image``, exactly one
        ``datapoints.BoundingBox`` and exactly one other element (the labels).

        Returns:
            tuple: ``.data`` of the image, ``.data`` of the boxes, and the labels
            element as-is.

        Raises:
            ValueError: if any of the three groups does not contain exactly one item.
        """
        images, boxes, labels = [], [], []
        # Single pass, so one-shot iterables are handled as well.
        for item in input:
            if isinstance(item, datapoints.Image):
                images.append(item)
            elif isinstance(item, datapoints.BoundingBox):
                boxes.append(item)
            else:
                labels.append(item)

        if len(labels) != 1 or len(images) != 1 or len(boxes) != 1:
            raise ValueError(
                "currently the input must be single datapoints.Image, datapoint.BoundingBox, labels"
            )
        return images[0].data, boxes[0].data, labels[0]
class MixUp(BaseMixTransform):
    """Blend the current sample with one extra sample drawn from the dataset.

    Args:
        p (float): the transform is skipped when a uniform draw in [0, 1] exceeds
            ``p``. Default: 0.5.
    """

    def __init__(self, p=0.5):
        super().__init__(p=p)

    def forward(self, inputs: Any) -> Any:
        """Return ``inputs`` unchanged or a mixed ``(image, boxes, labels)`` triple."""
        if random.uniform(0, 1) > self.p:
            return inputs

        # get images, labels and boxes from input
        images, boxes, labels = self.get_images_boxes_labels_from_input(inputs)

        # get single extra image
        index = random.randint(0, len(self.dataset) - 1)

        # Resolve both cached pipelines before touching the dataset so that the lazy
        # deep copy can never be taken from the temporarily swapped pipeline.
        pre_transform = self.pre_transform
        original_transform = self.original_transform

        # a hack implementation for pre_transform: the dataset temporarily runs only
        # the transforms that precede this one. ``finally`` guarantees the pipeline is
        # restored even when loading or augmenting the extra sample raises.
        self.dataset._transforms = pre_transform
        try:
            extra_images, extra_boxes, extra_labels = self.dataset.data_augmentation(
                *self.dataset.load_image_and_target(index)
            )
        finally:
            self.dataset._transforms = original_transform

        images = [images, extra_images]
        boxes = [boxes, extra_boxes]
        labels = [labels, extra_labels]
        images, boxes, labels = self.mix_transform(images, boxes, labels)

        images = datapoints.Image(images)
        boxes = datapoints.BoundingBox(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=images.shape[-2:],
        )
        return images, boxes, labels

    @staticmethod
    def mix_transform(images: List[Tensor], boxes: List[Tensor], labels: List[Tensor]):
        """Blend ``images`` with random weights and concatenate boxes and labels.

        One weight per image is drawn from ``Beta(32, 32)`` and the weights are
        normalised to sum to one. The images are batched with
        ``image_list_from_tensors`` (presumably bringing them to a common size -
        TODO confirm against that helper) and reduced to their weighted sum, which is
        cast back to the dtype of the first input image.

        Returns:
            tuple: ``(mixed_image, concatenated_boxes, concatenated_labels)``.
        """
        data_type = images[0].dtype
        images = image_list_from_tensors(images)
        ratios = torch.as_tensor(
            data=[np.random.beta(32.0, 32.0) for _ in range(len(images.tensors))],
            device=images.tensors.device,
            dtype=torch.float32,
        )
        ratios /= torch.sum(ratios)
        # Broadcast each per-image weight over the trailing three dimensions.
        image_final = torch.sum(images.tensors * ratios[:, None, None, None], dim=0)
        box_final = torch.cat(boxes)
        label_final = torch.cat(labels)
        return image_final.to(data_type), box_final, label_final
class CachedMixUp(MixUp):
    """MixUp variant that draws the extra sample from a cache of previous inputs.

    Every call stores a clone of its input; once the cache grows beyond
    ``max_cached_images`` a randomly chosen entry is evicted.

    Args:
        p (float): the mix is skipped when a uniform draw in [0, 1] exceeds ``p``.
        max_cached_images (int): maximum number of cached samples. Default: 40.
    """

    def __init__(self, p=0.5, max_cached_images=40):
        super().__init__(p)
        self.results_cache = []
        self.max_cached_images = max_cached_images

    def clone_datapoints(self, datapoint):
        """Return an independent copy of ``datapoint`` (recursing into lists/tuples)."""
        if isinstance(datapoint, (list, tuple)):
            return type(datapoint)(self.clone_datapoints(data) for data in datapoint)
        if isinstance(datapoint, datapoints.Image):
            return datapoints.Image(datapoint.detach().clone().requires_grad_(datapoint.requires_grad))
        if isinstance(datapoint, datapoints.BoundingBox):
            return datapoints.BoundingBox.wrap_like(
                datapoint,
                datapoint.detach().clone().requires_grad_(datapoint.requires_grad),
            )
        if isinstance(datapoint, torch.Tensor):
            return datapoint.clone()
        # Anything else used to fall through and silently become ``None``; copy it
        # instead so no part of the sample is lost.
        return copy.deepcopy(datapoint)

    def forward(self, inputs: Any) -> Any:
        """Cache ``inputs`` and return them unchanged or mixed with a cached sample."""
        self.results_cache.append(self.clone_datapoints(inputs))
        if len(self.results_cache) > self.max_cached_images:
            index = random.randint(0, len(self.results_cache) - 1)
            self.results_cache.pop(index)

        # Wait until the cache holds more than four samples before mixing.
        if len(self.results_cache) <= 4:
            return inputs

        if random.uniform(0, 1) > self.p:
            return inputs

        # get images, labels and boxes from input
        images, boxes, labels = self.get_images_boxes_labels_from_input(inputs)

        # get single extra image; it is split with the same helper as the current
        # input (as CachedMosaic does) rather than by positional unpacking, so the
        # element order inside the cached sample does not matter.
        index = random.randint(0, len(self.results_cache) - 1)
        extra_images, extra_boxes, extra_labels = self.get_images_boxes_labels_from_input(
            self.clone_datapoints(self.results_cache[index])
        )

        images = [images, extra_images]
        boxes = [boxes, extra_boxes]
        labels = [labels, extra_labels]
        images, boxes, labels = self.mix_transform(images, boxes, labels)

        images = datapoints.Image(images)
        boxes = datapoints.BoundingBox(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=images.shape[-2:],
        )
        return images, boxes, labels
class Mosaic(BaseMixTransform):
    """Combine the current sample with ``n - 1`` dataset samples into a 2x2 mosaic.

    Args:
        p (float): the transform is skipped when a uniform draw in [0, 1] exceeds
            ``p``. Default: 0.5.
        n (int): number of images per mosaic; only ``4`` is supported.
    """

    def __init__(self, p=0.5, n=4):
        super().__init__(p=p)
        assert n == 4, "Currently only mosaic for n=4 is supported."
        self.n = n

    def forward(self, inputs: Any) -> Any:
        """Return ``inputs`` unchanged or a mosaic ``(image, boxes, labels)`` triple."""
        if random.uniform(0, 1) > self.p:
            return inputs

        # get images, labels and boxes from input
        images, boxes, labels = self.get_images_boxes_labels_from_input(inputs)

        # Resolve both cached pipelines before touching the dataset so that the lazy
        # deep copy can never be taken from the temporarily swapped pipeline.
        pre_transform = self.pre_transform
        original_transform = self.original_transform

        # get extra images, boxes and labels; the dataset temporarily runs only the
        # transforms preceding this one, and ``finally`` restores the full pipeline
        # even when loading or augmenting an extra sample raises.
        self.dataset._transforms = pre_transform
        try:
            indices = self.get_indices()
            extra_data_metas = [
                self.dataset.data_augmentation(*self.dataset.load_image_and_target(index)) for index in indices
            ]
        finally:
            self.dataset._transforms = original_transform
        extra_images, extra_boxes, extra_labels = list(zip(*extra_data_metas))

        # concat datas and extra datas to perform mosaic
        images = [images, *extra_images]
        boxes = [boxes, *extra_boxes]
        labels = [labels, *extra_labels]
        images, boxes, labels = getattr(self, f"_mosaic{self.n}")(images, boxes, labels)

        images = datapoints.Image(images)
        boxes = datapoints.BoundingBox(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=images.shape[-2:],
        )
        return images, boxes, labels

    def get_indices(self):
        """Draw ``n - 1`` random dataset indices (with replacement)."""
        return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]

    @staticmethod
    def _mosaic4(images: List[Tensor], boxes: List[Tensor], labels: List[Tensor]):
        """Paste four ``(C, H, W)`` images around a random center and resize the result.

        The canvas is ``2 * image_size`` square, where ``image_size`` is the mean of the
        longer sides of the inputs. Each image is cropped so that it fits its quadrant.
        Only boxes lying entirely inside the kept crop survive; they are shifted into
        canvas coordinates. Boxes are indexed as ``(x1, y1, x2, y2)`` rows.

        Returns:
            tuple: ``(image, boxes, labels)`` after ``T.Resize(image_size)``; the boxes
            are returned as ``boxes.data``.
        """
        channel, _, _ = images[0].shape
        # get average size of the max border as the output image size
        image_size = int(sum(max(image.shape[-2:]) for image in images) / len(images))
        canvas_size = image_size * 2
        center_y, center_x = (int(random.uniform(0.5 * image_size, 1.5 * image_size)) for _ in range(2))
        image_final = images[0].new_full((channel, canvas_size, canvas_size), 0)
        boxes_final = []
        labels_final = []
        for i in range(4):
            _, h, w = images[i].shape
            # (x1a, y1a, x2a, y2a): destination window on the mosaic canvas.
            # (x1b, y1b, x2b, y2b): matching source window inside image ``i``.
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = (
                    max(center_x - w, 0),
                    max(center_y - h, 0),
                    center_x,
                    center_y,
                )
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = (
                    center_x,
                    max(center_y - h, 0),
                    min(center_x + w, canvas_size),
                    center_y,
                )
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = (
                    max(center_x - w, 0),
                    center_y,
                    center_x,
                    min(canvas_size, center_y + h),
                )
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            else:  # i == 3, bottom right
                x1a, y1a, x2a, y2a = (
                    center_x,
                    center_y,
                    min(center_x + w, canvas_size),
                    min(canvas_size, center_y + h),
                )
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            image_final[:, y1a:y2a, x1a:x2a] = images[i][:, y1b:y2b, x1b:x2b]

            # update boxes and labels: keep boxes fully inside the source window
            # (vectorised instead of one Python-level comparison per box).
            box = boxes[i]
            valid_flag = (box[:, 0] >= x1b) & (box[:, 1] >= y1b) & (box[:, 2] < x2b) & (box[:, 3] < y2b)
            offset = box.new_tensor([x1a - x1b, y1a - y1b, x1a - x1b, y1a - y1b])
            boxes_final.append(box[valid_flag] + offset)
            labels_final.append(labels[i][valid_flag])

        image_final = datapoints.Image(image_final)
        boxes_final = datapoints.BoundingBox(
            torch.cat(boxes_final).reshape(-1, 4),  # in case of empty boxes after mosaic
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=image_final.shape[-2:],
        )
        labels_final = torch.cat(labels_final)
        image_final, boxes_final, labels_final = T.Resize(image_size)(image_final, boxes_final, labels_final)
        return image_final, boxes_final.data, labels_final
class CachedMosaic(Mosaic):
    """Mosaic variant that draws the extra samples from a cache of previous inputs.

    Every call stores a clone of its input; once the cache grows beyond
    ``max_cached_images`` a randomly chosen entry is evicted.

    Args:
        p (float): the mosaic is skipped when a uniform draw in [0, 1] exceeds ``p``.
        n (int): number of images per mosaic; only ``4`` is supported.
        max_cached_images (int): maximum number of cached samples. Default: 40.
    """

    def __init__(self, p=0.5, n=4, max_cached_images=40):
        super().__init__(p, n)
        self.results_cache = []
        self.max_cached_images = max_cached_images

    def get_indices(self):
        """Draw ``n - 1`` random cache indices (with replacement)."""
        return [random.randint(0, len(self.results_cache) - 1) for _ in range(self.n - 1)]

    def clone_datapoints(self, datapoint):
        """Return an independent copy of ``datapoint`` (recursing into lists/tuples)."""
        if isinstance(datapoint, (list, tuple)):
            return type(datapoint)(self.clone_datapoints(data) for data in datapoint)
        if isinstance(datapoint, datapoints.Image):
            return datapoints.Image(datapoint.detach().clone().requires_grad_(datapoint.requires_grad))
        if isinstance(datapoint, datapoints.BoundingBox):
            return datapoints.BoundingBox.wrap_like(
                datapoint,
                datapoint.detach().clone().requires_grad_(datapoint.requires_grad),
            )
        if isinstance(datapoint, torch.Tensor):
            return datapoint.clone()
        # Anything else used to fall through and silently become ``None``; copy it
        # instead so no part of the sample is lost.
        return copy.deepcopy(datapoint)

    def forward(self, inputs: Any) -> Any:
        """Cache ``inputs`` and return them unchanged or as part of a mosaic."""
        self.results_cache.append(self.clone_datapoints(inputs))
        if len(self.results_cache) > self.max_cached_images:
            index = random.randint(0, len(self.results_cache) - 1)
            self.results_cache.pop(index)

        # Wait until the cache holds more than four samples before building mosaics.
        if len(self.results_cache) <= 4:
            return inputs

        if random.uniform(0, 1) > self.p:
            return inputs

        # get images, labels and boxes from input
        images, boxes, labels = self.get_images_boxes_labels_from_input(inputs)

        # get extra images, boxes and labels from cloned cache entries
        indices = self.get_indices()
        extra_results = [self.clone_datapoints(self.results_cache[index]) for index in indices]
        extra_images, extra_boxes, extra_labels = list(
            zip(*[self.get_images_boxes_labels_from_input(extra_inputs) for extra_inputs in extra_results])
        )

        images = [images, *extra_images]
        boxes = [boxes, *extra_boxes]
        labels = [labels, *extra_labels]
        images, boxes, labels = getattr(self, f"_mosaic{self.n}")(images, boxes, labels)

        images = datapoints.Image(images)
        boxes = datapoints.BoundingBox(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=images.shape[-2:],
        )
        return images, boxes, labels
import albumentations as A
import cv2
import torch
from transforms import v2 as T
from transforms.album_transform import RandomShortestSize, RandomSizeCrop
from transforms.albumentations_warpper import AlbumentationsWrapper
from transforms.crop import RandomSizeCrop
from transforms.mix_transform import CachedMixUp, CachedMosaic, MixUp, Mosaic
# Normalisation statistics shared by every preset in this module.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _tensor_tail(sanitize=False):
    """Build fresh closing steps: ``PILToTensor``, float ``ConvertImageDtype``, ``Normalize``.

    When ``sanitize`` is true, a ``SanitizeBoundingBox`` step that reads the
    labels from the last element of the sample is appended.
    """
    steps = [
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
        T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
    if sanitize:
        steps.append(T.SanitizeBoundingBox(labels_getter=lambda x: x[-1]))
    return steps


def _lsj(side):
    """Build the scale-jitter + random-crop pipeline for a ``side`` x ``side`` target."""
    return T.Compose([
        T.ScaleJitter(target_size=(side, side), antialias=True),
        T.RandomCrop(size=(side, side), pad_if_needed=True, fill=(123.0, 117.0, 104.0)),
        T.RandomHorizontalFlip(p=0.5),
        *_tensor_tail(sanitize=True),
    ])


def _resize_or_crop_choice():
    """Build the random choice between a plain multi-scale resize and resize-crop-resize."""
    return T.RandomChoice([
        T.RandomShortestSize(min_size=scales, max_size=1333, antialias=True),
        T.Compose([
            T.RandomShortestSize([400, 500, 600], antialias=True),
            RandomSizeCrop(384, 600),
            T.RandomShortestSize(min_size=scales, max_size=1333, antialias=True),
        ]),
    ])


basic = T.Compose(_tensor_tail())

# train transform
hflip = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    *_tensor_tail(),
])

lsj = _lsj(1024)

lsj_1536 = _lsj(1536)

# Candidate shortest-side sizes used by the multi-scale presets.
scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

multiscale = T.Compose([
    T.RandomShortestSize(min_size=scales, max_size=1333, antialias=True),
    T.RandomHorizontalFlip(p=0.5),
    *_tensor_tail(),
])

detr = T.Compose([
    T.RandomHorizontalFlip(),
    _resize_or_crop_choice(),
    *_tensor_tail(sanitize=True),
])

ssd = T.Compose([
    T.RandomPhotometricDistort(),
    T.RandomZoomOut(fill=[123.0, 117.0, 104.0]),
    T.RandomIoUCrop(),
    T.RandomHorizontalFlip(p=0.5),
    *_tensor_tail(sanitize=True),
])

ssdlite = T.Compose([
    T.RandomIoUCrop(),
    T.RandomHorizontalFlip(p=0.5),
    *_tensor_tail(sanitize=True),
])

strong_album = T.Compose([
    T.RandomHorizontalFlip(),
    _resize_or_crop_choice(),
    AlbumentationsWrapper(
        A.Compose(
            [
                A.ShiftScaleRotate(
                    shift_limit=0.0625,
                    scale_limit=0.0,
                    rotate_limit=0,
                    interpolation=1,
                    border_mode=cv2.BORDER_CONSTANT,
                    value=0,
                    p=0.5,
                ),
                A.RandomBrightnessContrast(
                    brightness_limit=(0.1, 0.3),
                    contrast_limit=(0.1, 0.3),
                    p=0.2,
                ),
                A.OneOf(
                    [
                        A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=1.0),
                        A.HueSaturationValue(
                            hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0
                        ),
                    ],
                    p=1.0,
                ),
                A.ImageCompression(quality_lower=85, quality_upper=95, p=0.2),
                A.ChannelShuffle(p=0.1),
                A.OneOf(
                    [
                        A.Blur(blur_limit=3, p=1.0),
                        A.MedianBlur(blur_limit=3, p=1.0),
                    ],
                    p=0.1,
                ),
            ],
            bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"], min_visibility=0.0),
        )
    ),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    *_tensor_tail(sanitize=True),
])

# This preset converts with ToImageTensor rather than PILToTensor, so it is spelled out in full.
rtdetr_transform = T.Compose([
    T.RandomPhotometricDistort(p=0.8),
    T.RandomZoomOut(p=0.5, fill=0, side_range=(1.0, 4.0)),
    T.RandomIoUCrop(),
    T.RandomHorizontalFlip(p=0.5),
    T.Resize(size=[640, 640], antialias=True),
    T.ToImageTensor(),
    T.ConvertImageDtype(torch.float),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    T.SanitizeBoundingBox(labels_getter=lambda x: x[-1]),
])

# some transform examples related to mosaic, mixup, cached_mosaic and cached_mixup
# you may want to add flip, crop, resize transforms into them for better performance
mosaic = T.Compose([
    T.RandomHorizontalFlip(),
    Mosaic(),
    *_tensor_tail(sanitize=True),
])

cached_mosaic = T.Compose([
    T.RandomHorizontalFlip(),
    CachedMosaic(),
    *_tensor_tail(sanitize=True),
])

mixup = T.Compose([
    T.RandomHorizontalFlip(),
    MixUp(),
    *_tensor_tail(),
])

cached_mixup = T.Compose([
    T.RandomHorizontalFlip(),
    CachedMixUp(),
    *_tensor_tail(),
])

mixup_mosaic = T.Compose([
    T.RandomHorizontalFlip(),
    MixUp(),
    Mosaic(),
    *_tensor_tail(sanitize=True),
])

cached_mixup_mosaic = T.Compose([
    T.RandomHorizontalFlip(),
    CachedMixUp(),
    CachedMosaic(),
    *_tensor_tail(sanitize=True),
])

mosaic_mixup = T.Compose([
    T.RandomHorizontalFlip(),
    Mosaic(),
    MixUp(),
    *_tensor_tail(sanitize=True),
])

# NOTE(review): unlike mosaic_mixup, this preset has no SanitizeBoundingBox step;
# confirm that this is intended.
cached_mosaic_mixup = T.Compose([
    T.RandomHorizontalFlip(),
    CachedMosaic(),
    CachedMixUp(),
    *_tensor_tail(),
])
from typing import List, Dict, Tuple
import torch
from torch import Tensor
from transforms import functional as F
from torchvision import ops
def _copy_paste(
    image: torch.Tensor,
    target: Dict[str, Tensor],
    paste_image: torch.Tensor,
    paste_target: Dict[str, Tensor],
    blending: bool = True,
    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
    """Paste a random subset of the instances of ``paste_image`` onto ``image``.

    Args:
        image: Destination image. Its last two dimensions define the working size.
        target: Annotations of ``image``. ``"masks"`` and ``"labels"`` are read;
            ``"area"`` and ``"iscrowd"`` are refreshed when present.
        paste_image: Image the pasted pixels are taken from. It is resized to the
            size of ``image`` when the two differ.
        paste_target: Annotations of ``paste_image``; ``"masks"``, ``"boxes"`` and
            ``"labels"`` are read, ``"iscrowd"`` optionally.
        blending: When true, the paste alpha mask goes through a 5x5 Gaussian
            blur (sigma 2.0) before compositing.
        resize_interpolation: Interpolation used when resizing ``paste_image``
            (masks are always resized with nearest interpolation).

    Returns:
        ``(image, out_target)`` where ``out_target`` is a shallow copy of
        ``target`` with updated entries. When ``paste_target`` has no masks,
        the inputs ``(image, target)`` are returned unchanged.
    """
    # Random paste targets selection:
    num_masks = len(paste_target["masks"])
    if num_masks < 1:
        # Such a degenerate case with num_masks=0 can happen with LSJ
        # Let's just return (image, target)
        return image, target
    # We have to please torch script by explicitly specifying dtype as torch.long
    random_selection = torch.randint(
        0, num_masks, (num_masks,), device=paste_image.device
    )
    # Sampling is with replacement; unique() removes the repeated indices.
    random_selection = torch.unique(random_selection).to(torch.long)
    paste_masks = paste_target["masks"][random_selection]
    paste_boxes = paste_target["boxes"][random_selection]
    paste_labels = paste_target["labels"][random_selection]
    masks = target["masks"]
    # We resize source and paste data if they have different sizes
    # This is something we introduced here as originally the algorithm works
    # on equal-sized data (for example, coming from LSJ data augmentations)
    size1 = image.shape[-2:]
    size2 = paste_image.shape[-2:]
    if size1 != size2:
        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
        paste_masks = F.resize(
            paste_masks, size1, interpolation=F.InterpolationMode.NEAREST
        )
        # resize bboxes:
        # ratios is (width ratio, height ratio); each box is viewed as two 2-D points
        # (presumably XYXY corners, confirm against the caller).
        ratios = torch.tensor(
            (size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device
        )
        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
    # Union of all selected paste masks: true wherever something is pasted.
    paste_alpha_mask = paste_masks.sum(dim=0) > 0
    if blending:
        paste_alpha_mask = F.gaussian_blur(
            paste_alpha_mask.unsqueeze(0),
            kernel_size=(5, 5),
            sigma=[
                2.0,
            ],
        )
    # Copy-paste images:
    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
    # Copy-paste masks:
    # Remove the pasted area from the original masks, then drop masks left empty.
    masks = masks * (~paste_alpha_mask)
    non_all_zero_masks = masks.sum((-1, -2)) > 0
    masks = masks[non_all_zero_masks]
    # Do a shallow copy of the target dict
    out_target = {k: v for k, v in target.items()}
    out_target["masks"] = torch.cat([masks, paste_masks])
    # Copy-paste boxes and labels
    # Boxes of the surviving original instances are recomputed from their trimmed masks.
    boxes = ops.masks_to_boxes(masks)
    out_target["boxes"] = torch.cat([boxes, paste_boxes])
    labels = target["labels"][non_all_zero_masks]
    out_target["labels"] = torch.cat([labels, paste_labels])
    # Update additional optional keys: area and iscrowd if exist
    if "area" in target:
        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
    if "iscrowd" in target and "iscrowd" in paste_target:
        # target['iscrowd'] size can differ from the mask size (non_all_zero_masks)
        # For example, if previous transforms geometrically modify masks/boxes/labels but
        # do not update "iscrowd"
        if len(target["iscrowd"]) == len(non_all_zero_masks):
            iscrowd = target["iscrowd"][non_all_zero_masks]
            paste_iscrowd = paste_target["iscrowd"][random_selection]
            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
    # Check for degenerated boxes and remove them
    boxes = out_target["boxes"]
    # A box is degenerate when either of its last two coordinates is not larger
    # than the matching one of its first two.
    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
    if degenerate_boxes.any():
        valid_targets = ~degenerate_boxes.any(dim=1)
        out_target["boxes"] = boxes[valid_targets]
        out_target["masks"] = out_target["masks"][valid_targets]
        out_target["labels"] = out_target["labels"][valid_targets]
        if "area" in out_target:
            out_target["area"] = out_target["area"][valid_targets]
        # "iscrowd" is filtered only when its length matches the filter mask.
        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
    return image, out_target
class SimpleCopyPaste(torch.nn.Module):
    """Batch-level copy-paste augmentation built on ``_copy_paste``.

    Every image of the batch receives instances pasted from another image of
    the same batch (the batch rotated by one position).

    Args:
        blending: Forwarded to ``_copy_paste`` as its ``blending`` flag.
        resize_interpolation: Forwarded to ``_copy_paste`` as the interpolation
            for resizing the pasted image.
    """

    def __init__(
        self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR
    ):
        super().__init__()
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def forward(
        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
        """Apply copy-paste to every ``(image, target)`` pair of the batch.

        Args:
            images: List (or tuple) of image tensors.
            targets: List (or tuple) of the same length as ``images``; every item
                must provide tensor values under ``"masks"``, ``"boxes"`` and
                ``"labels"``.

        Returns:
            ``(output_images, output_targets)``, two lists aligned with the inputs.
        """
        torch._assert(
            isinstance(images, (list, tuple))
            and all([isinstance(v, torch.Tensor) for v in images]),
            "images should be a list of tensors",
        )
        torch._assert(
            isinstance(targets, (list, tuple)) and len(images) == len(targets),
            "targets should be a list of the same size as images",
        )
        for target in targets:
            # Can not check for instance type dict with inside torch.jit.script
            # torch._assert(isinstance(target, dict), "targets item should be a dict")
            for k in ["masks", "boxes", "labels"]:
                torch._assert(k in target, f"Key {k} should be present in targets")
                torch._assert(
                    isinstance(target[k], torch.Tensor),
                    f"Value for the key {k} should be a tensor",
                )
        # images = [t1, t2, ..., tN]
        # Let's define paste_images as a shifted list of the input images:
        # the last element moves to the front, so
        # paste_images = [tN, t1, ..., tN-1]
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]
        output_images: List[torch.Tensor] = []
        output_targets: List[Dict[str, Tensor]] = []
        for image, target, paste_image, paste_target in zip(
            images, targets, images_rolled, targets_rolled
        ):
            output_image, output_data = _copy_paste(
                image,
                target,
                paste_image,
                paste_target,
                blending=self.blending,
                resize_interpolation=self.resize_interpolation,
            )
            output_images.append(output_image)
            output_targets.append(output_data)
        return output_images, output_targets

    def __repr__(self) -> str:
        """Return the class name followed by the two configuration values."""
        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
        return s
import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
try:
import accimage
except ImportError:
accimage = None
from . import functional as F
from .functional import _interpolation_modes_from_int, InterpolationMode
# Names exported by ``from <this module> import *``.
__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
    "ElasticTransform",
]
class Compose:
    """Chain several transforms, feeding each one the output of the previous.

    This transform does not support torchscript; see the note below.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        To script a pipeline, wrap the transforms in ``torch.nn.Sequential``:

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transforms may be used, i.e. ones that operate on
        ``torch.Tensor`` and need neither `lambda` functions nor ``PIL.Image``.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for step in self.transforms:
            result = step(result)
        return result

    def __repr__(self) -> str:
        # One line per transform, wrapped in "<ClassName>(" ... ")".
        listing = "".join(f"\n {step}" for step in self.transforms)
        return f"{self.__class__.__name__}({listing}\n)"
class ToTensor:
    """Turn a PIL Image or ``numpy.ndarray`` into a tensor, rescaling the values.

    This transform does not support torchscript.

    An input of shape (H x W x C) with values in [0, 255] becomes a
    ``torch.FloatTensor`` of shape (C x H x W) with values in [0.0, 1.0],
    provided the PIL Image has one of the modes
    (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or the ndarray has
    ``dtype = np.uint8``. Every other input is converted without scaling.

    .. note::
        Since values are rescaled to [0.0, 1.0], this transform should not be
        applied to target image masks. See the `references`_ for transforms
        suited to image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        # Nothing to configure.
        return None

    def __call__(self, pic):
        """Convert ``pic`` to a tensor.

        Args:
            pic (PIL Image or numpy.ndarray): image to convert.

        Returns:
            Tensor: the converted image.
        """
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
class PILToTensor:
    """Turn a PIL Image into a tensor of the same type, without rescaling values.

    This transform does not support torchscript.

    A PIL Image (H x W x C) becomes a Tensor of shape (C x H x W).
    """

    def __init__(self) -> None:
        # Nothing to configure.
        return None

    def __call__(self, pic):
        """Convert ``pic`` to a tensor.

        .. note::
            The underlying array is deep-copied.

        Args:
            pic (PIL Image): image to convert.

        Returns:
            Tensor: the converted image.
        """
        converted = F.pil_to_tensor(pic)
        return converted

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
class ConvertImageDtype(torch.nn.Module):
    """Cast a tensor image to ``dtype``, rescaling its values to match.

    PIL Images are not supported.

    Args:
        dtype (torch.dtype): data type of the output.

    .. note::
        Going from a smaller to a larger integer ``dtype`` does **not** map the
        maximum values exactly. A round trip is unaffected by this mismatch.

    Raises:
        RuntimeError: for casts from :class:`torch.float32` to
            :class:`torch.int32` or :class:`torch.int64`, and from
            :class:`torch.float64` to :class:`torch.int64`. These may overflow
            because the floating point ``dtype`` cannot represent every
            consecutive integer of the integer ``dtype``'s range.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(self, image):
        target_dtype = self.dtype
        return F.convert_image_dtype(image, target_dtype)
class ToPILImage:
    """Turn a tensor or an ndarray into a PIL Image, without rescaling values.

    This transform does not support torchscript.

    A torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C is
    converted to a PIL Image with its value range preserved.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of the input
            (optional). With ``None`` (default) the mode is assumed from the
            input:

            - 4 channels: ``RGBA``.
            - 3 channels: ``RGB``.
            - 2 channels: ``LA``.
            - 1 channel: decided by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """Convert ``pic`` to a PIL Image.

        Args:
            pic (Tensor or numpy.ndarray): image to convert.

        Returns:
            PIL Image: the converted image.
        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        # The mode is only shown when one was given explicitly.
        details = "" if self.mode is None else f"mode={self.mode}"
        return f"{self.__class__.__name__}({details})"
class Normalize(torch.nn.Module):
    """Standardise a tensor image channel by channel with a mean and a std.

    PIL Images are not supported.

    For ``n`` channels with mean ``(mean[1],...,mean[n])`` and std
    ``(std[1],..,std[n])``, every channel of the input ``torch.*Tensor`` is
    transformed as
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        With the default ``inplace=False`` the input tensor is not mutated.

    Args:
        mean (sequence): per-channel means.
        std (sequence): per-channel standard deviations.
        inplace(bool,optional): whether to perform the operation in-place.
    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean, self.std = mean, std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """Normalise ``tensor``.

        Args:
            tensor (Tensor): tensor image to normalise.

        Returns:
            Tensor: the normalised tensor image.
        """
        normalized = F.normalize(tensor, self.mean, self.std, self.inplace)
        return normalized

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(mean={self.mean}, std={self.std})"
class Resize(torch.nn.Module):
    """Rescale the input image to a given size.

    A torch Tensor input is expected to have shape [..., H, W], where ... stands
    for any number of leading dimensions.

    .. warning::
        The result can depend on the input type: when downsampling, PIL images
        and tensors are interpolated slightly differently because PIL applies
        antialiasing. This can noticeably change a network's performance, so
        train and serve a model with the same input type. The ``antialias``
        parameter below can bring the two outputs closer together.

    Args:
        size (sequence or int): Desired output size. A sequence like (h, w) is
            used as the output size directly. An int is matched to the smaller
            edge of the image, i.e, if height > width, the image is rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode a single int is not supported, use a
                sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is
            ``InterpolationMode.BILINEAR``. For Tensor inputs only
            ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported. The matching Pillow integer constants, e.g.
            ``PIL.Image.BILINEAR``, are accepted as well.
        max_size (int, optional): Upper bound for the longer edge of the
            result. If resizing according to ``size`` leaves the longer edge
            above ``max_size``, the image is resized again so that the longer
            edge equals ``max_size``; the smaller edge may then end up shorter
            than ``size``. Only supported when ``size`` is an int (or a
            sequence of length 1 in torchscript mode).
        antialias (bool, optional): Whether to apply antialiasing. It only
            affects **tensors** with bilinear or bicubic modes: PIL images are
            always antialiased on those modes, and on other modes antialiasing
            makes no sense and the parameter is ignored. Possible values:

            - ``True``: antialias for bilinear or bicubic modes; other modes
              aren't affected. This is probably what you want to use.
            - ``False``: no antialiasing for tensors on any mode. PIL images
              are still antialiased on bilinear or bicubic modes, because PIL
              doesn't support no antialias.
            - ``None``: same as ``False`` for tensors and ``True`` for PIL
              images. Kept for legacy reasons; you probably don't want it
              unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True``
            **in v0.17** for the PIL and Tensor backends to be consistent.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
        super().__init__()
        if isinstance(size, Sequence):
            if len(size) not in (1, 2):
                raise ValueError("If size is a sequence, it should have 1 or 2 values")
        elif not isinstance(size, int):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        self.size = size
        self.max_size = max_size

        # Integer (Pillow-style) interpolation codes are translated to the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """Rescale ``img``.

        Args:
            img (PIL Image or Tensor): image to rescale.

        Returns:
            PIL Image or Tensor: the rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
        return f"{name}{detail}"
class CenterCrop(torch.nn.Module):
    """Crop the centre of the given image.

    A torch Tensor input is expected to have shape [..., H, W], where ... stands
    for any number of leading dimensions. When the image is smaller than the
    output size along any edge, it is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. An int yields a
            square crop (size, size) instead of one given by a sequence like
            (h, w). A sequence of length 1 is interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        message = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=message)

    def forward(self, img):
        """Crop ``img`` at its centre.

        Args:
            img (PIL Image or Tensor): image to crop.

        Returns:
            PIL Image or Tensor: the cropped image.
        """
        cropped = F.center_crop(img, self.size)
        return cropped

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(size={self.size})"
class Pad(torch.nn.Module):
    """Pad every side of the given image with the given "pad" value.

    A torch Tensor input is expected to have shape [..., H, W], where ... stands
    for at most 2 leading dimensions for mode reflect and symmetric, at most 3
    leading dimensions for mode edge, and any number of leading dimensions for
    mode constant.

    Args:
        padding (int or sequence): Padding on each border. A single int pads all
            borders. A sequence of length 2 gives the padding on left/right and
            top/bottom respectively. A sequence of length 4 gives the padding
            for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode a single int is not supported, use a
                sequence of length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is
            0. A tuple of length 3 fills the R, G, B channels respectively.
            Only used when the padding_mode is constant. Only number is
            supported for torch Tensor. Only int or tuple value is supported
            for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect
            or symmetric. Default is constant.

            - constant: pads with a constant value, specified with fill.
            - edge: pads with the last value at the edge of the image. For a 5D
              torch Tensor input, the last 3 dimensions are padded instead of
              the last 2.
            - reflect: pads with a reflection of the image without repeating
              the last value on the edge. For example, padding [1, 2, 3, 4]
              with 2 elements on both sides in reflect mode results in
              [3, 2, 1, 2, 3, 4, 3, 2].
            - symmetric: pads with a reflection of the image repeating the last
              value on the edge. For example, padding [1, 2, 3, 4] with 2
              elements on both sides in symmetric mode results in
              [2, 1, 1, 2, 3, 4, 4, 3].
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        accepted_types = (numbers.Number, tuple, list)
        if not isinstance(padding, accepted_types):
            raise TypeError("Got inappropriate padding arg")
        if not isinstance(fill, accepted_types):
            raise TypeError("Got inappropriate fill arg")

        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

        if isinstance(padding, Sequence):
            count = len(padding)
            if count not in [1, 2, 4]:
                raise ValueError(
                    f"Padding must be an int or a 1, 2, or 4 element tuple, not a {count} element tuple"
                )

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """Pad ``img``.

        Args:
            img (PIL Image or Tensor): image to pad.

        Returns:
            PIL Image or Tensor: the padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
class Lambda:
    """Wrap a user-defined callable so it can be used as a transform.

    This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to use as the transform.
    """

    def __init__(self, lambd):
        if callable(lambd):
            self.lambd = lambd
            return
        raise TypeError(f"Argument lambd should be callable, got {type(lambd).__name__!r}")

    def __call__(self, img):
        func = self.lambd
        return func(img)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
class RandomTransforms:
    """Common base for transforms that apply a list of transforms randomly.

    Subclasses must provide ``__call__``; the base implementation raises
    ``NotImplementedError``.

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms):
        if isinstance(transforms, Sequence):
            self.transforms = transforms
        else:
            raise TypeError("Argument transforms should be a sequence")

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self) -> str:
        # One line per wrapped transform, wrapped in "<ClassName>(" ... ")".
        listing = "".join(f"\n {step}" for step in self.transforms)
        return f"{self.__class__.__name__}({listing}\n)"
class RandomApply(torch.nn.Module):
    """Run a list of transforms, all or none, with a given probability.

    .. note::
        To script this transform, pass a ``torch.nn.ModuleList`` rather than a
        list/tuple of transforms:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transforms may be used, i.e. ones that operate on
        ``torch.Tensor`` and need neither `lambda` functions nor ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        # A single draw decides whether the whole list is skipped.
        skip = self.p < torch.rand(1)
        if skip:
            return img
        result = img
        for step in self.transforms:
            result = step(result)
        return result

    def __repr__(self) -> str:
        header = f"{self.__class__.__name__}(\n p={self.p}"
        listing = "".join(f"\n {step}" for step in self.transforms)
        return f"{header}{listing}\n)"
class RandomOrder(RandomTransforms):
    """Run every transform of the list once, in a freshly shuffled order.

    This transform does not support torchscript.
    """

    def __call__(self, img):
        positions = list(range(len(self.transforms)))
        random.shuffle(positions)
        result = img
        for position in positions:
            step = self.transforms[position]
            result = step(result)
        return result
class RandomChoice(RandomTransforms):
    """Run exactly one transform, picked at random from the list.

    This transform does not support torchscript.

    Args:
        transforms (sequence): candidate transforms.
        p (sequence, optional): selection weights passed to ``random.choices``;
            ``None`` (default) is passed through unchanged.
    """

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if not (p is None or isinstance(p, Sequence)):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        (picked,) = random.choices(self.transforms, weights=self.p)
        return picked(*args)

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}(p={self.p})"
class RandomCrop(torch.nn.Module):
"""Crop the given image at a random location.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
but if non-constant padding is used, the input is expected to have at most 2 leading dimensions
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
padding (int or sequence, optional): Optional padding on each border
of the image. Default is None. If a single int is provided this
is used to pad all borders. If sequence of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a sequence of length 4 is provided
this is the padding for the left, top, right and bottom borders respectively.
.. note::
In torchscript mode padding as single int is not supported, use a sequence of
length 1: ``[padding, ]``.
pad_if_needed (boolean): It will pad the image if smaller than the
desired size to avoid raising an exception. Since cropping is done
after padding, the padding seems to be done at a random offset.
fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only int or tuple value is supported for PIL Image.
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
"""
@staticmethod
def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
_, h, w = F.get_dimensions(img)
th, tw = output_size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
    """Store the crop size and the padding options used by ``forward``.

    Args:
        size: Desired crop size; normalized by ``_setup_size`` and stored as a tuple.
        padding: Optional padding applied before cropping. ``None`` disables it.
        pad_if_needed (bool): Pad the image when it is smaller than ``size``.
        fill: Fill value forwarded to ``F.pad``.
        padding_mode (str): Padding mode forwarded to ``F.pad``.
    """
    super().__init__()
    normalized_size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
    self.size = tuple(normalized_size)
    # The padding options are kept exactly as given; they are only consumed in forward().
    self.padding, self.pad_if_needed = padding, pad_if_needed
    self.fill, self.padding_mode = fill, padding_mode
def forward(self, img):
    """Pad the image as configured, then cut a randomly placed crop out of it.

    Args:
        img (PIL Image or Tensor): Image to be cropped.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if self.padding is not None:
        img = F.pad(img, self.padding, self.fill, self.padding_mode)

    # Measured once, after the explicit padding; both "if needed" checks use these values.
    _, height, width = F.get_dimensions(img)

    if self.pad_if_needed:
        # pad the width if needed
        if width < self.size[1]:
            width_padding = [self.size[1] - width, 0]
            img = F.pad(img, width_padding, self.fill, self.padding_mode)
        # pad the height if needed
        if height < self.size[0]:
            height_padding = [0, self.size[0] - height]
            img = F.pad(img, height_padding, self.fill, self.padding_mode)

    top, left, crop_h, crop_w = self.get_params(img, self.size)
    return F.crop(img, top, left, crop_h, crop_w)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image with probability ``p``.

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Flip ``img`` horizontally or hand it back untouched.

        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        result = img
        # A single uniform draw from [0, 1) decides whether the flip happens.
        if torch.rand(1) < self.p:
            result = F.hflip(img)
        return result

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image with probability ``p``.

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Flip ``img`` vertically or hand it back untouched.

        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        result = img
        # A single uniform draw from [0, 1) decides whether the flip happens.
        if torch.rand(1) < self.p:
            result = F.vflip(img)
        return result

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
class RandomPerspective(torch.nn.Module):
    """Apply a random perspective transformation to the image with a given probability.

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        self.p = p

        # Integer (Pillow-style) constants are translated to the InterpolationMode enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        # ``None`` is accepted as an alias for the default fill of 0.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = fill

    def forward(self, img):
        """Randomly warp ``img`` or return it unchanged.

        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """
        channels, height, width = F.get_dimensions(img)

        fill = self.fill
        if isinstance(img, Tensor):
            # For tensors the fill is expanded/converted to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(component) for component in fill]

        result = img
        if torch.rand(1) < self.p:
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            result = F.perspective(img, startpoints, endpoints, self.interpolation, fill)
        return result

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        half_height = height // 2
        half_width = width // 2
        # Largest displacement (in pixels) a corner may receive along each axis.
        max_dx = int(distortion_scale * half_width)
        max_dy = int(distortion_scale * half_height)

        # Coordinates are drawn corner by corner, x before y, in the order
        # top-left, top-right, bottom-right, bottom-left.
        tl_x = int(torch.randint(0, max_dx + 1, size=(1,)).item())
        tl_y = int(torch.randint(0, max_dy + 1, size=(1,)).item())
        tr_x = int(torch.randint(width - max_dx - 1, width, size=(1,)).item())
        tr_y = int(torch.randint(0, max_dy + 1, size=(1,)).item())
        br_x = int(torch.randint(width - max_dx - 1, width, size=(1,)).item())
        br_y = int(torch.randint(height - max_dy - 1, height, size=(1,)).item())
        bl_x = int(torch.randint(0, max_dx + 1, size=(1,)).item())
        bl_y = int(torch.randint(height - max_dy - 1, height, size=(1,)).item())

        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [[tl_x, tl_y], [tr_x, tr_y], [br_x, br_y], [bl_x, bl_y]]
        return startpoints, endpoints

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of image and resize it to a given size.

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
            The signature default is the string ``"warn"``; it is stored and forwarded
            unchanged to ``F.resized_crop``.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")

        # Reversed bounds are only reported, not rejected.
        bounds_reversed = (scale[0] > scale[1]) or (ratio[0] > ratio[1])
        if bounds_reversed:
            warnings.warn("Scale and ratio should be of kind (min, max)")

        # Integer (Pillow-style) constants are translated to the InterpolationMode enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Up to ten candidate windows are sampled; the first one that fits inside the
        image is used. If none fits, a centered crop is returned instead.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        _, height, width = F.get_dimensions(img)
        area = height * width

        # The aspect ratio is sampled uniformly in log-space.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < crop_w <= width and 0 < crop_h <= height:
                top = torch.randint(0, height - crop_h + 1, size=(1,)).item()
                left = torch.randint(0, width - crop_w + 1, size=(1,)).item()
                return top, left, crop_h, crop_w

        # Fallback to central crop: start from the whole image and shrink one side
        # only when the image's own aspect ratio lies outside the allowed range.
        in_ratio = float(width) / float(height)
        lowest_ratio = min(ratio)
        highest_ratio = max(ratio)
        crop_w = width
        crop_h = height
        if in_ratio < lowest_ratio:
            crop_h = int(round(width / lowest_ratio))
        elif in_ratio > highest_ratio:
            crop_w = int(round(height * highest_ratio))

        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
        return top, left, crop_h, crop_w

    def forward(self, img):
        """Sample a crop window and resize the crop to ``self.size``.

        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, top, left, crop_h, crop_w, self.size, self.interpolation, antialias=self.antialias)

    def __repr__(self) -> str:
        # Scale and ratio are rounded to 4 decimals for display only.
        scale_shown = tuple(round(s, 4) for s in self.scale)
        ratio_shown = tuple(round(r, 4) for r in self.ratio)
        return (
            f"{self.__class__.__name__}(size={self.size}"
            f", scale={scale_shown}"
            f", ratio={ratio_shown}"
            f", interpolation={self.interpolation.value}"
            f", antialias={self.antialias})"
        )
class FiveCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop.

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        crop_size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.size = crop_size

    def forward(self, img):
        """Delegate to ``F.five_crop`` with the configured size.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        crops = F.five_crop(img, self.size)
        return crops

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(size={self.size})"
class TenCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        crop_size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.size = crop_size
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """Delegate to ``F.ten_crop`` with the configured size and flip direction.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        crops = F.ten_crop(img, self.size, self.vertical_flip)
        return crops

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(size={self.size}, vertical_flip={self.vertical_flip})"
class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.

    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W

    Raises:
        ValueError: If the matrix is not square, if ``mean_vector`` does not match its
            size, or if the two tensors differ in device or dtype.
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        if transformation_matrix.dtype != mean_vector.dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """Subtract the mean vector and multiply by the transformation matrix.

        Args:
            tensor (Tensor): Tensor image to be whitened. Its last three dimensions must
                multiply to the size of the transformation matrix. Non-contiguous
                tensors (e.g. permuted or sliced views) are accepted.

        Returns:
            Tensor: Transformed image, with the same shape as the input.

        Raises:
            ValueError: If the trailing three dimensions do not match the matrix size, or
                if the input lives on a different device type than ``mean_vector``.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        # Only the device *type* is compared (e.g. "cpu" vs "cuda"), not the device index.
        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        # reshape() instead of view(): view() raises on non-contiguous inputs, whereas
        # reshape() returns a view when possible and copies only when it has to.
        flat_tensor = tensor.reshape(-1, n) - self.mean_vector

        # The matrix follows the dtype of the centered data so torch.mm sees matching dtypes.
        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        return transformed_tensor.reshape(shape)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
        return s
class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    A torch Tensor input is expected to have [..., 1 or 3, H, W] shape, where ... means an
    arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        # Each attribute ends up as a (min, max) tuple, or None when that jitter is a no-op.
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Normalize a jitter argument into a validated ``(min, max)`` tuple.

        Args:
            value: A single non-negative number, or a list/tuple of two numbers.
            name (str): Parameter name, used in error messages.
            center: Value around which a single number is expanded to ``center -/+ value``.
            bound: Inclusive ``(low, high)`` limits the resulting range must respect.
            clip_first_on_zero (bool): Clamp the lower end at 0 when expanding a single number.

        Returns:
            The ``(min, max)`` tuple, or ``None`` when both ends equal ``center``.

        Raises:
            ValueError: For a negative single number or a range outside ``bound``.
            TypeError: When ``value`` is neither a number nor a 2-element list/tuple.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            spread = float(value)
            low = center - spread
            high = center + spread
            if clip_first_on_zero:
                low = max(low, 0.0)
            value = [low, high]
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        # The permutation is drawn first, then one factor per enabled adjustment.
        fn_idx = torch.randperm(4)

        b: Optional[float] = None
        if brightness is not None:
            b = float(torch.empty(1).uniform_(brightness[0], brightness[1]))

        c: Optional[float] = None
        if contrast is not None:
            c = float(torch.empty(1).uniform_(contrast[0], contrast[1]))

        s: Optional[float] = None
        if saturation is not None:
            s = float(torch.empty(1).uniform_(saturation[0], saturation[1]))

        h: Optional[float] = None
        if hue is not None:
            h = float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """Apply the enabled adjustments in a freshly sampled random order.

        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Index 0/1/2/3 selects brightness/contrast/saturation/hue; a None factor skips that step.
        for fn_id in fn_idx:
            if fn_id == 0:
                if brightness_factor is not None:
                    img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1:
                if contrast_factor is not None:
                    img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2:
                if saturation_factor is not None:
                    img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3:
                if hue_factor is not None:
                    img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """

    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__()

        # Integer (Pillow-style) constants are translated to the InterpolationMode enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))
        self.center = center

        self.interpolation = interpolation
        self.expand = expand

        # ``None`` is accepted as an alias for the default fill of 0.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Args:
            degrees (list): ``[min, max]`` range the angle is drawn from uniformly.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        low = float(degrees[0])
        high = float(degrees[1])
        sample = torch.empty(1).uniform_(low, high)
        return float(sample.item())

    def forward(self, img):
        """Rotate ``img`` by a freshly sampled angle.

        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        channels, _, _ = F.get_dimensions(img)

        fill = self.fill
        if isinstance(img, Tensor):
            # For tensors the fill is expanded/converted to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(component) for component in fill]

        angle = self.get_params(self.degrees)
        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        fields = [
            f"degrees={self.degrees}",
            f"interpolation={self.interpolation.value}",
            f"expand={self.expand}",
        ]
        if self.center is not None:
            fields.append(f"center={self.center}")
        if self.fill is not None:
            fields.append(f"fill={self.fill}")
        return self.__class__.__name__ + "(" + ", ".join(fields) + ")"
class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.

    A torch Tensor input is expected to have [..., H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()

        # Integer (Pillow-style) constants are translated to the InterpolationMode enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            if not all(0.0 <= fraction <= 1.0 for fraction in translate):
                raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            if any(factor <= 0 for factor in scale):
                raise ValueError("scale values should be positive")
        self.scale = scale

        # A shear of None is stored as-is; anything else is normalized to a 2- or 4-element range.
        self.shear = shear if shear is None else _setup_angle(shear, name="shear", req_sizes=(2, 4))

        self.interpolation = interpolation

        # ``None`` is accepted as an alias for the default fill of 0.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))
        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Args:
            degrees (list): ``[min, max]`` range for the rotation angle.
            translate (list, optional): Maximum shift fractions; multiplied by
                ``img_size[0]`` and ``img_size[1]`` respectively. ``None`` means no shift.
            scale_ranges (list, optional): ``[min, max]`` range for the scale. ``None`` means 1.0.
            shears (list, optional): 2 values (x shear range) or 4 values (x and y shear
                ranges). ``None`` means no shear.
            img_size (list): Image size; ``forward`` passes ``[width, height]``.

        Returns:
            params to be passed to the affine transformation
        """
        # Draw order: angle, x/y translation, scale, x shear, y shear.
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())

        translations = (0, 0)
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            # Shifts are rounded to whole pixels.
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)

        scale = 1.0
        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())

        shear_x = 0.0
        shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        return angle, translations, scale, (shear_x, shear_y)

    def forward(self, img):
        """Apply a freshly sampled affine transformation to ``img``.

        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        channels, height, width = F.get_dimensions(img)

        fill = self.fill
        if isinstance(img, Tensor):
            # For tensors the fill is expanded/converted to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(component) for component in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        angle, translations, scale, shear = self.get_params(
            self.degrees, self.translate, self.scale, self.shear, img_size
        )
        return F.affine(
            img, angle, translations, scale, shear, interpolation=self.interpolation, fill=fill, center=self.center
        )

    def __repr__(self) -> str:
        # Only settings that differ from "off"/default are listed after ``degrees``.
        fields = [f"degrees={self.degrees}"]
        if self.translate is not None:
            fields.append(f"translate={self.translate}")
        if self.scale is not None:
            fields.append(f"scale={self.scale}")
        if self.shear is not None:
            fields.append(f"shear={self.shear}")
        if self.interpolation != InterpolationMode.NEAREST:
            fields.append(f"interpolation={self.interpolation.value}")
        if self.fill != 0:
            fields.append(f"fill={self.fill}")
        if self.center is not None:
            fields.append(f"center={self.center}")
        return f"{self.__class__.__name__}(" + ", ".join(fields) + ")"
class Grayscale(torch.nn.Module):
    """Convert image to grayscale.

    A torch Tensor input is expected to have [..., 3, H, W] shape, where ... means an
    arbitrary number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image or Tensor: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """Delegate to ``F.rgb_to_grayscale`` with the configured channel count.

        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        n_channels = self.num_output_channels
        return F.rgb_to_grayscale(img, num_output_channels=n_channels)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(num_output_channels={self.num_output_channels})"
class RandomGrayscale(torch.nn.Module):
    """Convert an image to grayscale at random, with probability ``p`` (default 0.1).

    If the image is a torch Tensor, it is expected to have [..., 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability that the image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p
        and the unchanged input with probability (1-p).

        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b
    """

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Return either a grayscaled copy of ``img`` or ``img`` itself.

        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        # The output keeps the channel count of the input.
        channels, _height, _width = F.get_dimensions(img)
        if not (torch.rand(1) < self.p):
            return img
        return F.rgb_to_grayscale(img, num_output_channels=channels)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
class RandomErasing(torch.nn.Module):
    """Erase a randomly chosen rectangle of a torch.Tensor image.

    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p: probability that the random erasing operation will be performed.
        scale: range of proportion of erased area against input image.
        ratio: range of aspect ratio of erased area.
        value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
        inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()

        # --- argument validation -------------------------------------------------
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")

        scale_low, scale_high = scale[0], scale[1]
        # A reversed range only triggers a warning, it is not rejected.
        if scale_low > scale_high or ratio[0] > ratio[1]:
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale_low < 0 or scale_high > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        # --- stored configuration ------------------------------------------------
        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Sample the rectangle and fill values for ``erase``.

        Up to 10 candidate rectangles are drawn; the first one strictly smaller
        than the image in both dimensions is used.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is
                interpreted as a number, i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
            If no candidate fits, ``(0, 0, img_h, img_w, img)`` is returned, i.e. the
            whole image is "erased" with itself.
        """
        channels, height, width = img.shape[-3], img.shape[-2], img.shape[-1]
        total_area = height * width
        # The aspect ratio is sampled uniformly in log space.
        log_bounds = torch.log(torch.tensor(ratio))

        for _attempt in range(10):
            target_area = total_area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            log_aspect = torch.empty(1).uniform_(log_bounds[0], log_bounds[1])
            aspect = torch.exp(log_aspect).item()

            h = int(round(math.sqrt(target_area * aspect)))
            w = int(round(math.sqrt(target_area / aspect)))
            if h >= height or w >= width:
                # Candidate does not fit strictly inside the image; draw again.
                continue

            if value is None:
                v = torch.empty([channels, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, height - h + 1, size=(1,)).item()
            j = torch.randint(0, width - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, height, width, img

    def forward(self, img):
        """Erase a random rectangle of ``img`` with probability ``p``.

        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.

        Raises:
            ValueError: if the configured value is a sequence whose length is neither 1
                nor the number of input channels.
        """
        if not (torch.rand(1) < self.p):
            return img

        # cast self.value to script acceptable type
        configured = self.value
        if isinstance(configured, (int, float)):
            value = [float(configured)]
        elif isinstance(configured, str):
            value = None
        elif isinstance(configured, (list, tuple)):
            value = [float(entry) for entry in configured]
        else:
            value = configured

        if value is not None and len(value) not in (1, img.shape[-3]):
            raise ValueError(
                "If value is a sequence, it should have either a single value or "
                f"{img.shape[-3]} (number of input channels)"
            )

        x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
        return F.erase(img, x, y, h, w, v, self.inplace)

    def __repr__(self) -> str:
        settings = ", ".join(
            [
                f"p={self.p}",
                f"scale={self.scale}",
                f"ratio={self.ratio}",
                f"value={self.value}",
                f"inplace={self.inplace}",
            ]
        )
        return f"{self.__class__.__name__}({settings})"
class GaussianBlur(torch.nn.Module):
    """Blur an image with a Gaussian kernel whose sigma is drawn at random.

    If the image is a torch Tensor, it is expected to have [..., C, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.
    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()

        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        # Every kernel dimension must be a positive odd number.
        if any(ks <= 0 or ks % 2 == 0 for ks in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            # A fixed sigma is stored as a degenerate (min, max) range.
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            sigma_low, sigma_high = sigma[0], sigma[1]
            if not 0.0 < sigma_low <= sigma_high:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")
        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Draw a sigma uniformly between ``sigma_min`` and ``sigma_max``.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        sample = torch.empty(1).uniform_(sigma_min, sigma_max)
        return sample.item()

    def forward(self, img: Tensor) -> Tensor:
        """Blur ``img`` using one sampled sigma for both kernel axes.

        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        sigma_low, sigma_high = self.sigma[0], self.sigma[1]
        chosen_sigma = self.get_params(sigma_low, sigma_high)
        return F.gaussian_blur(img, self.kernel_size, [chosen_sigma, chosen_sigma])

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(kernel_size={self.kernel_size}, sigma={self.sigma})"
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
def _check_sequence_input(x, name, req_sizes):
msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
if not isinstance(x, Sequence):
raise TypeError(f"{name} should be a sequence of length {msg}.")
if len(x) not in req_sizes:
raise ValueError(f"{name} should be a sequence of length {msg}.")
def _setup_angle(x, name, req_sizes=(2,)):
if isinstance(x, numbers.Number):
if x < 0:
raise ValueError(f"If {name} is a single number, it must be positive.")
x = [-x, x]
else:
_check_sequence_input(x, name, req_sizes)
return [float(d) for d in x]
class RandomInvert(torch.nn.Module):
    """Invert the colors of an image at random with a given probability.

    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Return the inverted image or, otherwise, the input unchanged.

        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        draw = torch.rand(1).item()
        return F.invert(img) if draw < self.p else img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
class RandomPosterize(torch.nn.Module):
    """Posterize an image at random with a given probability.

    Posterizing reduces the number of bits for each color channel. If the image is a
    torch Tensor, it should be of type torch.uint8, and it is expected to have
    [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        self.bits = bits
        self.p = p

    def forward(self, img):
        """Return the posterized image or, otherwise, the input unchanged.

        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        draw = torch.rand(1).item()
        return F.posterize(img, self.bits) if draw < self.p else img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(bits={self.bits},p={self.p})"
class RandomSolarize(torch.nn.Module):
    """Solarize an image at random with a given probability.

    Solarizing inverts all pixel values above a threshold. If img is a Tensor, it is
    expected to be in [..., 1 or 3, H, W] format, where ... means it can have an
    arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """Return the solarized image or, otherwise, the input unchanged.

        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        draw = torch.rand(1).item()
        return F.solarize(img, self.threshold) if draw < self.p else img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(threshold={self.threshold},p={self.p})"
class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of an image at random with a given probability.

    If the image is a torch Tensor, it is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """Return the sharpness-adjusted image or, otherwise, the input unchanged.

        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Randomly sharpened image.
        """
        draw = torch.rand(1).item()
        return F.adjust_sharpness(img, self.sharpness_factor) if draw < self.p else img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(sharpness_factor={self.sharpness_factor},p={self.p})"
class RandomAutocontrast(torch.nn.Module):
    """Autocontrast the pixels of an image at random with a given probability.

    If the image is a torch Tensor, it is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Return the autocontrasted image or, otherwise, the input unchanged.

        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        draw = torch.rand(1).item()
        return F.autocontrast(img) if draw < self.p else img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of an image at random with a given probability.

    If the image is a torch Tensor, it is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Return the equalized image or, otherwise, the input unchanged.

        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        draw = torch.rand(1).item()
        return F.equalize(img) if draw < self.p else img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.

    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
            Must be a ``float`` or a sequence of exactly two ``float`` values; other
            types (including ``int``) raise ``TypeError``.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
            Validated the same way as ``alpha``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.

    Raises:
        TypeError: if ``alpha``/``sigma`` is neither a float nor a sequence of floats, or
            if ``fill`` is neither a number nor a list/tuple.
        ValueError: if ``alpha``/``sigma`` is a sequence whose length is not 2.
    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        # alpha and sigma share exactly the same validation rules.
        self.alpha = self._setup_float_pair(alpha, "alpha")
        self.sigma = self._setup_float_pair(sigma, "sigma")

        # Integer (Pillow-style) interpolation constants are mapped to the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        # fill is always stored as a list of floats.
        if isinstance(fill, (int, float)):
            fill = [float(fill)]
        elif isinstance(fill, (list, tuple)):
            fill = [float(f) for f in fill]
        else:
            raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
        self.fill = fill

    @staticmethod
    def _setup_float_pair(arg, name):
        """Validate ``arg`` as a float or a length-2 sequence of floats.

        A single float is expanded to ``[arg, arg]``; a valid sequence is returned
        unchanged. ``name`` is only used to build the error messages.
        """
        if not isinstance(arg, (float, Sequence)):
            raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
        if isinstance(arg, float):
            return [float(arg), float(arg)]
        # From here on ``arg`` is a sequence. Any length other than 2 is rejected, so
        # no separate handling of one-element sequences is needed.
        if len(arg) != 2:
            raise ValueError(f"If {name} is a sequence its length should be 2. Got {len(arg)}")
        for element in arg:
            if not isinstance(element, float):
                raise TypeError(f"{name} should be a sequence of floats. Got {type(element)}")
        return arg

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        """Sample a random displacement field.

        Args:
            alpha (list of float): Displacement magnitudes; ``alpha[0]`` scales the first
                component and ``alpha[1]`` the second.
            sigma (list of float): Smoothing strengths; a component is blurred only when
                its sigma is greater than 0.
            size (list of int): Spatial size of the field; ``forward`` passes
                ``[height, width]``.

        Returns:
            Tensor: displacement of shape 1 x H x W x 2.
        """
        # Uniform noise in [-1, 1) for the first displacement component.
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            # The full ``sigma`` list is handed to the blur, only the kernel size is per-component.
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        # NOTE(review): dx is normalised by size[0] and dy by size[1]; with size given as
        # [height, width] confirm this matches the axis convention of F.elastic_transform.
        dx = dx * alpha[0] / size[0]

        # Same procedure for the second displacement component.
        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """Apply a freshly sampled elastic displacement to ``tensor``.

        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        _, height, width = F.get_dimensions(tensor)
        displacement = self.get_params(self.alpha, self.sigma, [height, width])
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self):
        format_string = self.__class__.__name__
        format_string += f"(alpha={self.alpha}"
        format_string += f", sigma={self.sigma}"
        format_string += f", interpolation={self.interpolation}"
        format_string += f", fill={self.fill})"
        return format_string
from __future__ import annotations
from typing import Any, Callable, List, Tuple, Type, Union, Sequence
import PIL.Image
from util import datapoints
from transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    """Render the items of ``seq`` as a single-quoted, comma-separated string.

    Args:
        seq: Items to render.
        separate_last: Text inserted directly before the last item (e.g. ``"or "``).
            When it is non-empty and ``seq`` has exactly two items, the comma before
            the last item is dropped.

    Returns:
        The rendered string; empty when ``seq`` is empty.
    """
    if not seq:
        return ""

    count = len(seq)
    if count == 1:
        return f"'{seq[0]}'"

    quoted_head = ", ".join(f"'{item!s}'" for item in seq[:-1])
    comma = "" if separate_last and count == 2 else ","
    return f"{quoted_head}{comma} {separate_last}'{seq[-1]}'"
def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
    """Return the single ``datapoints.BoundingBox`` contained in ``flat_inputs``.

    Raises:
        TypeError: if no bounding box is present.
        ValueError: if more than one bounding box is present.
    """
    found = [candidate for candidate in flat_inputs if isinstance(candidate, datapoints.BoundingBox)]
    if len(found) > 1:
        raise ValueError("Found multiple bounding boxes in the sample")
    if not found:
        raise TypeError("No bounding box was found in the sample")
    return found[0]
def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
    """Return the (channels, height, width) shared by all images/videos in ``flat_inputs``.

    Only ``datapoints.Image``, ``PIL.Image.Image``, ``datapoints.Video`` instances and
    inputs accepted by ``is_simple_tensor`` are considered.

    Raises:
        TypeError: if no such input is present.
        ValueError: if the considered inputs do not all share the same dimensions.
    """
    image_like = (datapoints.Image, PIL.Image.Image, datapoints.Video)

    chws = set()
    for inpt in flat_inputs:
        if isinstance(inpt, image_like) or is_simple_tensor(inpt):
            chws.add(tuple(get_dimensions(inpt)))

    if len(chws) > 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
    if not chws:
        raise TypeError("No image or video was found in the sample")

    c, h, w = chws.pop()
    return c, h, w
def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
    """Return the (height, width) shared by all spatial inputs in ``flat_inputs``.

    Considered are ``datapoints.Image``, ``PIL.Image.Image``, ``datapoints.Video``,
    ``datapoints.Mask`` and ``datapoints.BoundingBox`` instances as well as inputs
    accepted by ``is_simple_tensor``.

    Raises:
        TypeError: if no such input is present.
        ValueError: if the considered inputs do not all share the same spatial size.
    """
    spatial_types = (
        datapoints.Image,
        PIL.Image.Image,
        datapoints.Video,
        datapoints.Mask,
        datapoints.BoundingBox,
    )

    sizes = set()
    for inpt in flat_inputs:
        if isinstance(inpt, spatial_types) or is_simple_tensor(inpt):
            sizes.add(tuple(get_spatial_size(inpt)))

    if len(sizes) > 1:
        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
    if not sizes:
        raise TypeError("No image, video, mask or bounding box was found in the sample")

    h, w = sizes.pop()
    return h, w
def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    """Tell whether ``obj`` satisfies at least one entry of ``types_or_checks``.

    An entry that is a type is tested with ``isinstance``; any other entry is called
    with ``obj`` and its result is interpreted as a boolean.
    """

    def _satisfies(type_or_check) -> bool:
        if isinstance(type_or_check, type):
            return isinstance(obj, type_or_check)
        return type_or_check(obj)

    return any(_satisfies(entry) for entry in types_or_checks)
def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    """Tell whether at least one input satisfies any of ``types_or_checks``.

    Each input is tested with ``check_type`` against the full tuple of types/checks.
    """
    return any(check_type(inpt, types_or_checks) for inpt in flat_inputs)
def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    """Tell whether every entry of ``types_or_checks`` is satisfied by some input.

    An entry that is a type is tested with ``isinstance``; any other entry is called
    with the input and its result is interpreted as a boolean. With no entries the
    result is ``True``.
    """

    def _satisfies(inpt, type_or_check) -> bool:
        if isinstance(type_or_check, type):
            return isinstance(inpt, type_or_check)
        return type_or_check(inpt)

    return all(
        any(_satisfies(inpt, type_or_check) for inpt in flat_inputs)
        for type_or_check in types_or_checks
    )
from . import functional, utils # usort: skip
from ._transform import Transform # usort: skip
from ._augment import RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomEqualize,
RandomGrayscale,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
RandomSolarize,
)
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
CenterCrop,
ElasticTransform,
FiveCrop,
Pad,
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
RandomIoUCrop,
RandomPerspective,
RandomResize,
RandomResizedCrop,
RandomRotation,
RandomShortestSize,
RandomVerticalFlip,
RandomZoomOut,
Resize,
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import ToTensor # usort: skip
import math
import numbers
import warnings
from typing import Any, Dict, List, Tuple, Union
import PIL.Image
import torch
from util import datapoints
import transforms as _transforms
from transforms.v2 import functional as F
from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw
class RandomErasing(_RandomApplyTransform):
    """[BETA] Randomly select a rectangle region in the input image or video and erase its pixels.

    .. v2betastatus:: RandomErasing transform

    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p (float, optional): probability that the random erasing operation will be performed.
        scale (tuple of float, optional): range of proportion of erased area against input image.
        ratio (tuple of float, optional): range of aspect ratio of erased area.
        value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
        inplace (bool, optional): boolean to make this transform inplace. Default set to False.

    Returns:
        Erased input.

    Example:
        >>> from torchvision.transforms import v2 as transforms
        >>>
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    _v1_transform_cls = _transforms.RandomErasing

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return the base-class parameters with ``value`` mapped back to its v1 form.

        Internally ``"random"`` is stored as ``None``; the v1 transform expects the string.
        """
        v1_params = dict(super()._extract_params_for_v1_transform())
        v1_params["value"] = "random" if self.value is None else self.value
        return v1_params

    _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

    def __init__(
        self,
        p: float = 0.5,
        scale: Tuple[float, float] = (0.02, 0.33),
        ratio: Tuple[float, float] = (0.3, 3.3),
        value: float = 0.0,
        inplace: bool = False,
    ):
        super().__init__(p=p)

        # --- argument validation -------------------------------------------------
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")

        scale_low, scale_high = scale[0], scale[1]
        # A reversed range only triggers a warning, it is not rejected.
        if scale_low > scale_high or ratio[0] > ratio[1]:
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale_low < 0 or scale_high > 1:
            raise ValueError("Scale should be between 0 and 1")

        self.scale = scale
        self.ratio = ratio

        # --- normalise the erase value ---------------------------------------------
        # Numbers become a one-element float list, "random" becomes None and
        # lists/tuples are converted element-wise to float.
        if isinstance(value, (int, float)):
            normalized = [float(value)]
        elif isinstance(value, str):
            normalized = None
        elif isinstance(value, (list, tuple)):
            normalized = [float(entry) for entry in value]
        else:
            normalized = value
        self.value = normalized

        self.inplace = inplace

        # Aspect ratios are sampled uniformly in log space, so the log bounds are cached.
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the erase rectangle and fill values for one call.

        Up to 10 candidate rectangles are drawn; the first one strictly smaller than
        the image in both dimensions is used. If none fits, ``v`` is ``None`` and the
        rectangle covers the whole image.

        Raises:
            ValueError: if the configured value is a sequence whose length is neither 1
                nor the number of input channels.
        """
        img_c, img_h, img_w = query_chw(flat_inputs)

        if self.value is not None and len(self.value) not in (1, img_c):
            raise ValueError(
                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
            )

        total_area = img_h * img_w
        log_low, log_high = self._log_ratio[0], self._log_ratio[1]

        for _attempt in range(10):
            target_area = total_area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            log_aspect = torch.empty(1).uniform_(
                log_low,  # type: ignore[arg-type]
                log_high,  # type: ignore[arg-type]
            )
            aspect = torch.exp(log_aspect).item()

            h = int(round(math.sqrt(target_area * aspect)))
            w = int(round(math.sqrt(target_area / aspect)))
            if h >= img_h or w >= img_w:
                # Candidate does not fit strictly inside the image; draw again.
                continue

            if self.value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(self.value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return dict(i=i, j=j, h=h, w=w, v=v)

        # No candidate fitted: signal "nothing to erase" through v=None.
        return dict(i=0, j=0, h=img_h, w=img_w, v=None)

    def _transform(
        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
        """Erase the sampled rectangle, or pass ``inpt`` through when ``params["v"]`` is ``None``."""
        if params["v"] is None:
            return inpt
        return F.erase(inpt, **params, inplace=self.inplace)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment