Commit cc26cd81 authored by panning

merge v0.16.0

parents f78f29f5 fbb4cc54
@@ -36,7 +36,7 @@ def distance_box_iou_loss(
         Tensor: Loss tensor with the reduction option applied.
     Reference:
-        Zhaohui Zheng et. al: Distance Intersection over Union Loss:
+        Zhaohui Zheng et al.: Distance Intersection over Union Loss:
         https://arxiv.org/abs/1911.08287
     """
@@ -50,10 +50,17 @@ def distance_box_iou_loss(
     loss, _ = _diou_iou_loss(boxes1, boxes2, eps)

-    if reduction == "mean":
+    # Check reduction option and return loss accordingly
+    if reduction == "none":
+        pass
+    elif reduction == "mean":
         loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
     elif reduction == "sum":
         loss = loss.sum()
+    else:
+        raise ValueError(
+            f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
+        )
     return loss
...
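For illustration only (not part of the diff), a minimal sketch of the behaviour the new reduction check introduces; the same pattern is applied to sigmoid_focal_loss and generalized_box_iou_loss below:

import torch
from torchvision.ops import distance_box_iou_loss

boxes1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
boxes2 = torch.tensor([[2.0, 2.0, 12.0, 12.0]])

distance_box_iou_loss(boxes1, boxes2, reduction="mean")  # reduced scalar, as before
distance_box_iou_loss(boxes1, boxes2, reduction="none")  # per-pair losses, now an explicit branch
distance_box_iou_loss(boxes1, boxes2, reduction="oops")  # now raises ValueError instead of silently returning the unreduced loss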
@@ -178,7 +178,7 @@ class FeaturePyramidNetwork(nn.Module):
         Returns:
             results (OrderedDict[Tensor]): feature maps after FPN layers.
-                They are ordered from highest resolution first.
+                They are ordered from the highest resolution first.
         """
         # unpack OrderedDict into two lists for easier handling
         names = list(x.keys())
@@ -206,7 +206,7 @@ class FeaturePyramidNetwork(nn.Module):
 class LastLevelMaxPool(ExtraFPNBlock):
     """
-    Applies a max_pool2d on top of the last feature map
+    Applies a max_pool2d (not actual max_pool2d, we just subsample) on top of the last feature map
     """
     def forward(
@@ -216,7 +216,8 @@ class LastLevelMaxPool(ExtraFPNBlock):
         names: List[str],
     ) -> Tuple[List[Tensor], List[str]]:
         names.append("pool")
-        x.append(F.max_pool2d(x[-1], 1, 2, 0))
+        # Use max pooling to simulate stride 2 subsampling
+        x.append(F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0))
         return x, names
...
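A small illustration (example tensor shape assumed, not from the diff) of why the docstring now says "not actual max_pool2d, we just subsample": with kernel_size=1 the max is taken over a single element, so stride-2 pooling reduces to plain subsampling:

import torch
import torch.nn.functional as F

x = torch.rand(1, 256, 13, 13)  # hypothetical last FPN feature map
pooled = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
assert torch.equal(pooled, x[..., ::2, ::2])  # identical to keeping every other row and column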
@@ -32,6 +32,7 @@ def sigmoid_focal_loss(
         Loss tensor with the reduction option applied.
     """
+    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(sigmoid_focal_loss)
     p = torch.sigmoid(inputs)
@@ -43,9 +44,15 @@ def sigmoid_focal_loss(
         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
         loss = alpha_t * loss

-    if reduction == "mean":
+    # Check reduction option and return loss accordingly
+    if reduction == "none":
+        pass
+    elif reduction == "mean":
         loss = loss.mean()
     elif reduction == "sum":
         loss = loss.sum()
+    else:
+        raise ValueError(
+            f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
+        )
     return loss
@@ -33,7 +33,7 @@ def generalized_box_iou_loss(
         Tensor: Loss tensor with the reduction option applied.
     Reference:
-        Hamid Rezatofighi et. al: Generalized Intersection over Union:
+        Hamid Rezatofighi et al.: Generalized Intersection over Union:
         A Metric and A Loss for Bounding Box Regression:
         https://arxiv.org/abs/1902.09630
     """
@@ -62,9 +62,15 @@ def generalized_box_iou_loss(
     loss = 1 - miouk

-    if reduction == "mean":
+    # Check reduction option and return loss accordingly
+    if reduction == "none":
+        pass
+    elif reduction == "mean":
         loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
     elif reduction == "sum":
         loss = loss.sum()
+    else:
+        raise ValueError(
+            f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
+        )
     return loss
@@ -131,10 +131,10 @@ class Conv2dNormActivation(ConvNormActivation):
         out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
         kernel_size: (int, optional): Size of the convolving kernel. Default: 3
         stride (int, optional): Stride of the convolution. Default: 1
-        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
+        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
-        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
-        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
         dilation (int): Spacing between kernel elements. Default: 1
         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -181,10 +181,10 @@ class Conv3dNormActivation(ConvNormActivation):
         out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
         kernel_size: (int, optional): Size of the convolving kernel. Default: 3
         stride (int, optional): Stride of the convolution. Default: 1
-        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
+        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
-        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
-        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
         dilation (int): Spacing between kernel elements. Default: 1
         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
@@ -266,9 +266,10 @@ class MLP(torch.nn.Sequential):
     Args:
         in_channels (int): Number of channels of the input
         hidden_channels (List[int]): List of the hidden channel dimensions
-        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``None``
-        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
-        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
+        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
+            Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
         bias (bool): Whether to use bias in the linear layer. Default ``True``
         dropout (float): The probability for the dropout layer. Default: 0.0
     """
@@ -279,7 +280,7 @@ class MLP(torch.nn.Sequential):
         hidden_channels: List[int],
         norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
-        inplace: Optional[bool] = True,
+        inplace: Optional[bool] = None,
         bias: bool = True,
         dropout: float = 0.0,
     ):
...
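As a usage sketch (shapes are assumptions, not from the diff): with the new default ``inplace=None`` the MLP block no longer forces in-place activations and simply uses the defaults of the activation and Dropout layers unless told otherwise:

import torch
from torchvision.ops import MLP

mlp = MLP(in_channels=64, hidden_channels=[128, 32])  # Linear -> ReLU -> Dropout per hidden layer, plain Linear -> Dropout for the last
out = mlp(torch.rand(8, 64))  # shape (8, 32)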
@@ -160,8 +160,8 @@ def _multiscale_roi_align(
             reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
         output_size (Union[List[Tuple[int, int]], List[int]]): size of the output
         sampling_ratio (int): sampling ratio for ROIAlign
-        scales (Optional[List[float]]): If None, scales will be automatically infered. Default value is None.
-        mapper (Optional[LevelMapper]): If none, mapper will be automatically infered. Default value is None.
+        scales (Optional[List[float]]): If None, scales will be automatically inferred. Default value is None.
+        mapper (Optional[LevelMapper]): If none, mapper will be automatically inferred. Default value is None.
     Returns:
         result (Tensor)
     """
...
 from typing import List, Union

 import torch
+import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
-from torchvision.extension import _assert_has_ops
+from torchvision.extension import _assert_has_ops, _has_ops

 from ..utils import _log_api_usage_once
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
# NB: all inputs are tensors
def _bilinear_interpolate(
input, # [N, C, H, W]
roi_batch_ind, # [K]
y, # [K, PH, IY]
x, # [K, PW, IX]
ymask, # [K, IY]
xmask, # [K, IX]
):
_, channels, height, width = input.size()
# deal with inverse element out of feature map boundary
y = y.clamp(min=0)
x = x.clamp(min=0)
y_low = y.int()
x_low = x.int()
y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = torch.where(y_low >= height - 1, height - 1, y_low)
y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
x_low = torch.where(x_low >= width - 1, width - 1, x_low)
x = torch.where(x_low >= width - 1, x.to(input.dtype), x)
ly = y - y_low
lx = x - x_low
hy = 1.0 - ly
hx = 1.0 - lx
# do bilinear interpolation, but respect the masking!
# TODO: It's possible the masking here is unnecessary if y and
# x were clamped appropriately; hard to tell
def masked_index(
y, # [K, PH, IY]
x, # [K, PW, IX]
):
if ymask is not None:
assert xmask is not None
y = torch.where(ymask[:, None, :], y, 0)
x = torch.where(xmask[:, None, :], x, 0)
return input[
roi_batch_ind[:, None, None, None, None, None],
torch.arange(channels, device=input.device)[None, :, None, None, None, None],
y[:, None, :, None, :, None], # prev [K, PH, IY]
x[:, None, None, :, None, :], # prev [K, PW, IX]
] # [K, C, PH, PW, IY, IX]
v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)
# all ws preemptively [K, C, PH, PW, IY, IX]
def outer_prod(y, x):
return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
w1 = outer_prod(hy, hx)
w2 = outer_prod(hy, lx)
w3 = outer_prod(ly, hx)
w4 = outer_prod(ly, lx)
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
return tensor.float()
else:
return tensor
# This is a slow but pure Python and differentiable implementation of
# roi_align. It potentially is a good basis for Inductor compilation
# (but I have not benchmarked it) but today it is solely used for the
# fact that its backwards can be implemented deterministically,
# which is needed for the PT2 benchmark suite.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
orig_dtype = input.dtype
input = maybe_cast(input)
rois = maybe_cast(rois)
_, _, height, width = input.size()
ph = torch.arange(pooled_height, device=input.device) # [PH]
pw = torch.arange(pooled_width, device=input.device) # [PW]
# input: [N, C, H, W]
# rois: [K, 5]
roi_batch_ind = rois[:, 0].int() # [K]
offset = 0.5 if aligned else 0.0
roi_start_w = rois[:, 1] * spatial_scale - offset # [K]
roi_start_h = rois[:, 2] * spatial_scale - offset # [K]
roi_end_w = rois[:, 3] * spatial_scale - offset # [K]
roi_end_h = rois[:, 4] * spatial_scale - offset # [K]
roi_width = roi_end_w - roi_start_w # [K]
roi_height = roi_end_h - roi_start_h # [K]
if not aligned:
roi_width = torch.clamp(roi_width, min=1.0) # [K]
roi_height = torch.clamp(roi_height, min=1.0) # [K]
bin_size_h = roi_height / pooled_height # [K]
bin_size_w = roi_width / pooled_width # [K]
exact_sampling = sampling_ratio > 0
roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height) # scalar or [K]
roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width) # scalar or [K]
"""
iy, ix = dims(2)
"""
if exact_sampling:
count = max(roi_bin_grid_h * roi_bin_grid_w, 1) # scalar
iy = torch.arange(roi_bin_grid_h, device=input.device) # [IY]
ix = torch.arange(roi_bin_grid_w, device=input.device) # [IX]
ymask = None
xmask = None
else:
count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1) # [K]
# When doing adaptive sampling, the number of samples we need to do
# is data-dependent based on how big the ROIs are. This is a bit
# awkward because first-class dims can't actually handle this.
# So instead, we inefficiently suppose that we needed to sample ALL
# the points and mask out things that turned out to be unnecessary
iy = torch.arange(height, device=input.device) # [IY]
ix = torch.arange(width, device=input.device) # [IX]
ymask = iy[None, :] < roi_bin_grid_h[:, None] # [K, IY]
xmask = ix[None, :] < roi_bin_grid_w[:, None] # [K, IX]
def from_K(t):
return t[:, None, None]
y = (
from_K(roi_start_h)
+ ph[None, :, None] * from_K(bin_size_h)
+ (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
) # [K, PH, IY]
x = (
from_K(roi_start_w)
+ pw[None, :, None] * from_K(bin_size_w)
+ (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
) # [K, PW, IX]
val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX]
# Mask out samples that weren't actually adaptively needed
if not exact_sampling:
val = torch.where(ymask[:, None, None, None, :, None], val, 0)
val = torch.where(xmask[:, None, None, None, None, :], val, 0)
output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW]
if isinstance(count, torch.Tensor):
output /= count[:, None, None, None]
else:
output /= count
output = output.to(orig_dtype)
return output
 @torch.fx.wrap
 def roi_align(
     input: Tensor,
@@ -54,12 +226,15 @@ def roi_align(
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(roi_align)
-    _assert_has_ops()
     check_roi_boxes_shape(boxes)
     rois = boxes
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
+    if not torch.jit.is_scripting():
+        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
+            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
+    _assert_has_ops()
     return torch.ops.torchvision.roi_align(
         input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
     )
...
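For context, a minimal sketch (example tensors and the CUDA device are assumptions, not from the diff) of when the new dispatch takes the pure-Python _roi_align path, whose backward pass is deterministic:

import torch
from torchvision.ops import roi_align

feat = torch.rand(1, 8, 32, 32, device="cuda")
# one box per row: (batch_index, x1, y1, x2, y2)
boxes = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device="cuda")

torch.use_deterministic_algorithms(True)
# deterministic algorithms + CUDA input -> roi_align now routes to _roi_align instead of the compiled kernel
out = roi_align(feat, boxes, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2, aligned=True)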
import numbers
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps
try:
import accimage
except ImportError:
accimage = None
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
if _is_pil_image(img):
if hasattr(img, "getbands"):
channels = len(img.getbands())
else:
channels = img.channels
width, height = img.size
return [channels, height, width]
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
if _is_pil_image(img):
return list(img.size)
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
if hasattr(img, "getbands"):
return len(img.getbands())
else:
return img.channels
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.transpose(Image.FLIP_LEFT_RIGHT)
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.transpose(Image.FLIP_TOP_BOTTOM)
@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
input_mode = img.mode
if input_mode in {"L", "1", "I", "F"}:
return img
h, s, v = img.convert("HSV").split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition takes care of rotation across boundaries
with np.errstate(over="ignore"):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, "L")
img = Image.merge("HSV", (h, s, v)).convert(input_mode)
return img
@torch.jit.unused
def adjust_gamma(
img: Image.Image,
gamma: float,
gain: float = 1.0,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
input_mode = img.mode
img = img.convert("RGB")
gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
return img
@torch.jit.unused
def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, list):
padding = tuple(padding)
if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
if isinstance(padding, tuple) and len(padding) == 1:
# Compatibility with `functional_tensor.pad`
padding = padding[0]
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if padding_mode == "constant":
opts = _parse_fill(fill, img, name="fill")
if img.mode == "P":
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, **opts)
image.putpalette(palette)
return image
return ImageOps.expand(img, border=padding, **opts)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, tuple) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, tuple) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_top, pad_right, pad_bottom]
cropping = -np.minimum(p, 0)
if cropping.any():
crop_left, crop_top, crop_right, crop_bottom = cropping
img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))
pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
if img.mode == "P":
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img
img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
return Image.fromarray(img)
@torch.jit.unused
def crop(
img: Image.Image,
top: int,
left: int,
height: int,
width: int,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.crop((left, top, left + width, top + height))
@torch.jit.unused
def resize(
img: Image.Image,
size: Union[List[int], int],
interpolation: int = Image.BILINEAR,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not (isinstance(size, list) and len(size) == 2):
raise TypeError(f"Got inappropriate size arg: {size}")
return img.resize(tuple(size[::-1]), interpolation)
@torch.jit.unused
def _parse_fill(
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
img: Image.Image,
name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
# Process fill color for affine transforms
num_channels = get_image_num_channels(img)
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_channels > 1:
fill = tuple([fill] * num_channels)
if isinstance(fill, (list, tuple)):
if len(fill) == 1:
fill = fill * num_channels
elif len(fill) != num_channels:
msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
raise ValueError(msg.format(len(fill), num_channels))
fill = tuple(fill) # type: ignore[arg-type]
if img.mode != "F":
if isinstance(fill, (list, tuple)):
fill = tuple(int(x) for x in fill)
else:
fill = int(fill)
return {name: fill}
@torch.jit.unused
def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = Image.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
output_size = img.size
opts = _parse_fill(fill, img)
return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)
@torch.jit.unused
def rotate(
img: Image.Image,
angle: float,
interpolation: int = Image.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
opts = _parse_fill(fill, img)
return img.rotate(angle, interpolation, expand, center, **opts)
@torch.jit.unused
def perspective(
img: Image.Image,
perspective_coeffs: List[float],
interpolation: int = Image.BICUBIC,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
opts = _parse_fill(fill, img)
return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if num_output_channels == 1:
img = img.convert("L")
elif num_output_channels == 3:
img = img.convert("L")
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, "RGB")
else:
raise ValueError("num_output_channels should be either 1 or 3")
return img
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.invert(img)
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
return img
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.equalize(img)
import warnings
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad
def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2
def _assert_image_tensor(img: Tensor) -> None:
if not _is_tensor_a_torch_image(img):
raise TypeError("Tensor is not a torch image.")
def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
height, width = img.shape[-2:]
return [channels, height, width]
def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]
def get_image_num_channels(img: Tensor) -> int:
_assert_image_tensor(img)
if img.ndim == 2:
return 1
elif img.ndim > 2:
return img.shape[-3]
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
def _max_value(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 255
elif dtype == torch.int8:
return 127
elif dtype == torch.int16:
return 32767
elif dtype == torch.int32:
return 2147483647
elif dtype == torch.int64:
return 9223372036854775807
else:
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
# easy.
return 1
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = get_dimensions(img)[0]
if c not in permitted:
raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image
if image.is_floating_point():
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
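# For example, for torch.uint8: max_val = 255 and eps = 1e-3, so 1.0 maps to int(255.999) = 255
# and 0.5 maps to int(127.9995) = 127, making every output value reachable.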
eps = 1e-3
max_val = float(_max_value(dtype))
result = image.mul(max_val + 1.0 - eps)
return result.to(dtype)
else:
input_max = float(_max_value(image.dtype))
# int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
image = image.to(dtype)
return image / input_max
output_max = float(_max_value(dtype))
# int to int
if input_max > output_max:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode="floor")
return image.to(dtype)
else:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image * factor can produce different results
factor = int((output_max + 1) // (input_max + 1))
image = image.to(dtype)
return image * factor
def vflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-2)
def hflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-1)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)
_, h, w = get_dimensions(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [
max(-left + min(0, right), 0),
max(-top + min(0, bottom), 0),
max(right - max(w, left), 0),
max(bottom - max(h, top), 0),
]
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return img[..., top:bottom, left:right]
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
if num_output_channels not in (1, 3):
raise ValueError("num_output_channels should be either 1 or 3")
if img.shape[-3] == 3:
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
else:
l_img = img.clone()
if num_output_channels == 3:
return l_img.expand(img.shape)
return l_img
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
return _blend(img, torch.zeros_like(img), brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [3, 1])
c = get_dimensions(img)[0]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
orig_dtype = img.dtype
img = convert_image_dtype(img, torch.float32)
img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
return convert_image_dtype(img_hue_adj, orig_dtype)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError("Input img should be a Tensor.")
_assert_channels(img, [1, 3])
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = convert_image_dtype(result, torch.float32)
result = (gain * result**gamma).clamp(0, 1)
result = convert_image_dtype(result, dtype)
return result
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = _max_value(img1.dtype)
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img: Tensor) -> Tensor:
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
# + H channel has division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occurring, so
# we don't need to deal with it in case we save the NaN in a buffer in
# backprop, if it is ever supported, but it doesn't hurt to do so.
eqc = maxc == minc
cr = maxc - minc
# Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` holds is fine.
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing denominator with 1 when `eqc` is fine.
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img: Tensor) -> Tensor:
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
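# mask one-hot encodes the hue sector i in [0, 6); the einsum below sums over that sector axis,
# picking the matching (R, G, B) triple out of a4 for every pixel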
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
neg_min_padding = [-min(x, 0) for x in padding]
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
_x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
_y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
if torch.jit.is_scripting():
# This may be unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
return [pad_left, pad_right, pad_top, pad_bottom]
def pad(
img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
_assert_image_tensor(img)
if fill is None:
fill = 0
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list):
# TODO: Jit is failing on loading this op when scripted and saved
# https://github.com/pytorch/pytorch/issues/81100
if len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
p = _parse_pad_padding(padding)
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporarily cast input tensor to float
# until pytorch issue is resolved :
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
if padding_mode in ("reflect", "replicate"):
img = torch_pad(img, p, mode=padding_mode)
else:
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img
def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
# TODO: in v0.17, change the default to True. This will be a private function
# by then, so we don't care about warning here.
antialias: Optional[bool] = None,
) -> Tensor:
_assert_image_tensor(img)
if isinstance(size, tuple):
size = list(size)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]:
# We manually set it to False to avoid an error downstream in interpolate()
# This behaviour is documented: the parameter is irrelevant for modes
# that are not bilinear or bicubic. We used to raise an error here, but
# now we don't as True is the default.
antialias = False
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
return img
def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[Union[int, float, List[float]]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor")
_assert_image_tensor(img)
if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
if matrix is not None and len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill should be either int, float, tuple or list")
# Check fill
num_channels = get_dimensions(img)[0]
if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# it is better to round before cast
img = torch.round(img)
img = img.to(out_dtype)
return img
def _apply_grid_transform(
img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
if img.shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, mask), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == "nearest":
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
img = img * mask + (1.0 - mask) * fill_img
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def _gen_affine_grid(
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Difference with AffineGridGenerator is that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
# Inspired of PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# Points are shifted due to affine matrix torch convention about
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
pts = torch.tensor(
[
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
min_vals += torch.tensor((w * 0.5, h * 0.5))
max_vals += torch.tensor((w * 0.5, h * 0.5))
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1]) # w, h
def rotate(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.view(1, oh, ow, 2)
def perspective(
img: Tensor,
perspective_coeffs: List[float],
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor.")
_assert_image_tensor(img)
_assert_grid_transform_inputs(
img,
matrix=None,
interpolation=interpolation,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs,
)
ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
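# the 2D Gaussian kernel is separable: build it as the outer product of the two 1D kernels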
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
_assert_image_tensor(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
img = torch_pad(img, padding, mode="reflect")
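# depthwise convolution: groups equals the number of channels, so each channel is blurred independently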
img = conv2d(img, kernel, groups=img.shape[-3])
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def invert(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
return _max_value(img.dtype) - img
def posterize(img: Tensor, bits: int) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
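# Worked example: bits=3 gives mask = -32, whose two's-complement bit pattern is
# 0b11100000, so the AND below keeps only the 3 most significant bits of every
# pixel (e.g. 201 == 0b11001001 -> 192 == 0b11000000).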
return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
if threshold > _max_value(img.dtype):
raise TypeError("Threshold should be less than bound of img.")
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
def _blurred_degenerate_image(img: Tensor) -> Tensor:
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel[1, 1] = 5.0
kernel /= kernel.sum()
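# The normalized 3x3 kernel (ones with a center weight of 5, total weight 13)
# mirrors PIL's ImageFilter.SMOOTH; only the interior is overwritten below, so
# the one-pixel border keeps its original values.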
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
result = img.clone()
result[..., 1:-1, 1:-1] = result_tmp
return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if img.size(-1) <= 2 or img.size(-2) <= 2:
return img
return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
bound = _max_value(img.dtype)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
scale = bound / (maximum - minimum)
eq_idxs = torch.isfinite(scale).logical_not()
minimum[eq_idxs] = 0
scale[eq_idxs] = 1
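# Example: for a uint8 channel spanning [50, 200], scale = 255 / 150, so 50 maps
# to 0 and 200 maps to 255; channels where min == max are left unchanged via the
# eq_idxs bookkeeping above.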
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan: Tensor) -> Tensor:
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.reshape(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
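# `step` approximates the number of pixels per output gray level (the last
# occupied bin is excluded, as in PIL's ImageOps.equalize); dividing the
# cumulative histogram by it stretches the occupied levels over [0, 255].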
if step == 0:
return img_chan
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
def equalize(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if not (3 <= img.ndim <= 4):
raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
if img.ndim == 3:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
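# Usage sketch (illustrative only): standard ImageNet-style normalization of a
# float CHW image; each channel c becomes (tensor[c] - mean[c]) / std[c].
#   >>> img = torch.rand(3, 224, 224)
#   >>> out = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])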
_assert_image_tensor(tensor)
if not tensor.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
if tensor.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
)
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
return tensor.sub_(mean).div_(std)
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
_assert_image_tensor(img)
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img
def _create_identity_grid(size: List[int]) -> Tensor:
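# Builds the normalized (x, y) sampling grid expected by grid_sample: along an
# axis of length s the pixel centers sit at (2 * k + 1 - s) / s for k = 0..s-1,
# i.e. evenly spaced in (-1, 1) as with align_corners=False.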
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
def elastic_transform(
img: Tensor,
displacement: Tensor,
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
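# Usage sketch (illustrative only): `displacement` has shape (1, H, W, 2) and is
# expressed in the same normalized coordinates as the identity grid, so small
# magnitudes such as 0.05 * torch.randn(1, H, W, 2) give mild local warps.
#   >>> img = torch.rand(3, 32, 32)
#   >>> out = elastic_transform(img, 0.05 * torch.randn(1, 32, 32, 2))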
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
size = list(img.shape[-2:])
displacement = displacement.to(img.device)
identity_grid = _create_identity_grid(size)
grid = identity_grid.to(img.device) + displacement
return _apply_grid_transform(img, grid, interpolation, fill)
from PIL import Image
# See https://pillow.readthedocs.io/en/stable/releasenotes/9.1.0.html#deprecations
# TODO: Remove this file once PIL minimal version is >= 9.1
if hasattr(Image, "Resampling"):
BICUBIC = Image.Resampling.BICUBIC
BILINEAR = Image.Resampling.BILINEAR
LINEAR = Image.Resampling.BILINEAR
NEAREST = Image.Resampling.NEAREST
AFFINE = Image.Transform.AFFINE
FLIP_LEFT_RIGHT = Image.Transpose.FLIP_LEFT_RIGHT
FLIP_TOP_BOTTOM = Image.Transpose.FLIP_TOP_BOTTOM
PERSPECTIVE = Image.Transform.PERSPECTIVE
else:
BICUBIC = Image.BICUBIC
BILINEAR = Image.BILINEAR
NEAREST = Image.NEAREST
LINEAR = Image.LINEAR
AFFINE = Image.AFFINE
FLIP_LEFT_RIGHT = Image.FLIP_LEFT_RIGHT
FLIP_TOP_BOTTOM = Image.FLIP_TOP_BOTTOM
PERSPECTIVE = Image.PERSPECTIVE
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This file is part of the private API. Please do not use directly these classes as they will be modified on This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights. future versions without warning. The classes should be accessed only via the transforms argument of Weights.
""" """
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
...@@ -44,6 +44,7 @@ class ImageClassification(nn.Module): ...@@ -44,6 +44,7 @@ class ImageClassification(nn.Module):
mean: Tuple[float, ...] = (0.485, 0.456, 0.406), mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
std: Tuple[float, ...] = (0.229, 0.224, 0.225), std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None: ) -> None:
super().__init__() super().__init__()
self.crop_size = [crop_size] self.crop_size = [crop_size]
...@@ -51,9 +52,10 @@ class ImageClassification(nn.Module): ...@@ -51,9 +52,10 @@ class ImageClassification(nn.Module):
self.mean = list(mean) self.mean = list(mean)
self.std = list(std) self.std = list(std)
self.interpolation = interpolation self.interpolation = interpolation
self.antialias = antialias
def forward(self, img: Tensor) -> Tensor: def forward(self, img: Tensor) -> Tensor:
img = F.resize(img, self.resize_size, interpolation=self.interpolation) img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
img = F.center_crop(img, self.crop_size) img = F.center_crop(img, self.crop_size)
if not isinstance(img, Tensor): if not isinstance(img, Tensor):
img = F.pil_to_tensor(img) img = F.pil_to_tensor(img)
...@@ -105,7 +107,11 @@ class VideoClassification(nn.Module): ...@@ -105,7 +107,11 @@ class VideoClassification(nn.Module):
N, T, C, H, W = vid.shape N, T, C, H, W = vid.shape
vid = vid.view(-1, C, H, W) vid = vid.view(-1, C, H, W)
vid = F.resize(vid, self.resize_size, interpolation=self.interpolation) # We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the video models with antialias=True?
vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False)
vid = F.center_crop(vid, self.crop_size) vid = F.center_crop(vid, self.crop_size)
vid = F.convert_image_dtype(vid, torch.float) vid = F.convert_image_dtype(vid, torch.float)
vid = F.normalize(vid, mean=self.mean, std=self.std) vid = F.normalize(vid, mean=self.mean, std=self.std)
...@@ -145,16 +151,18 @@ class SemanticSegmentation(nn.Module): ...@@ -145,16 +151,18 @@ class SemanticSegmentation(nn.Module):
mean: Tuple[float, ...] = (0.485, 0.456, 0.406), mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
std: Tuple[float, ...] = (0.229, 0.224, 0.225), std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None: ) -> None:
super().__init__() super().__init__()
self.resize_size = [resize_size] if resize_size is not None else None self.resize_size = [resize_size] if resize_size is not None else None
self.mean = list(mean) self.mean = list(mean)
self.std = list(std) self.std = list(std)
self.interpolation = interpolation self.interpolation = interpolation
self.antialias = antialias
def forward(self, img: Tensor) -> Tensor: def forward(self, img: Tensor) -> Tensor:
if isinstance(self.resize_size, list): if isinstance(self.resize_size, list):
img = F.resize(img, self.resize_size, interpolation=self.interpolation) img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
if not isinstance(img, Tensor): if not isinstance(img, Tensor):
img = F.pil_to_tensor(img) img = F.pil_to_tensor(img)
img = F.convert_image_dtype(img, torch.float) img = F.convert_image_dtype(img, torch.float)
......
...@@ -151,7 +151,7 @@ class ToTensorVideo: ...@@ -151,7 +151,7 @@ class ToTensorVideo:
class RandomHorizontalFlipVideo: class RandomHorizontalFlipVideo:
""" """
Flip the video clip along the horizonal direction with a given probability Flip the video clip along the horizontal direction with a given probability
Args: Args:
p (float): probability of the clip being flipped. Default value is 0.5 p (float): probability of the clip being flipped. Default value is 0.5
""" """
......
...@@ -15,15 +15,17 @@ except ImportError: ...@@ -15,15 +15,17 @@ except ImportError:
accimage = None accimage = None
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from . import functional_pil as F_pil, functional_tensor as F_t from . import _functional_pil as F_pil, _functional_tensor as F_t
class InterpolationMode(Enum): class InterpolationMode(Enum):
"""Interpolation modes """Interpolation modes
Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``. Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
and ``lanczos``.
""" """
NEAREST = "nearest" NEAREST = "nearest"
NEAREST_EXACT = "nearest-exact"
BILINEAR = "bilinear" BILINEAR = "bilinear"
BICUBIC = "bicubic" BICUBIC = "bicubic"
# For PIL compatibility # For PIL compatibility
...@@ -50,6 +52,7 @@ pil_modes_mapping = { ...@@ -50,6 +52,7 @@ pil_modes_mapping = {
InterpolationMode.NEAREST: 0, InterpolationMode.NEAREST: 0,
InterpolationMode.BILINEAR: 2, InterpolationMode.BILINEAR: 2,
InterpolationMode.BICUBIC: 3, InterpolationMode.BICUBIC: 3,
InterpolationMode.NEAREST_EXACT: 0,
InterpolationMode.BOX: 4, InterpolationMode.BOX: 4,
InterpolationMode.HAMMING: 5, InterpolationMode.HAMMING: 5,
InterpolationMode.LANCZOS: 1, InterpolationMode.LANCZOS: 1,
...@@ -164,7 +167,7 @@ def to_tensor(pic) -> Tensor: ...@@ -164,7 +167,7 @@ def to_tensor(pic) -> Tensor:
if pic.mode == "1": if pic.mode == "1":
img = 255 * img img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
# put it from HWC to CHW format # put it from HWC to CHW format
img = img.permute((2, 0, 1)).contiguous() img = img.permute((2, 0, 1)).contiguous()
if isinstance(img, torch.ByteTensor): if isinstance(img, torch.ByteTensor):
...@@ -202,7 +205,7 @@ def pil_to_tensor(pic: Any) -> Tensor: ...@@ -202,7 +205,7 @@ def pil_to_tensor(pic: Any) -> Tensor:
# handle PIL Image # handle PIL Image
img = torch.as_tensor(np.array(pic, copy=True)) img = torch.as_tensor(np.array(pic, copy=True))
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
# put it from HWC to CHW format # put it from HWC to CHW format
img = img.permute((2, 0, 1)) img = img.permute((2, 0, 1))
return img return img
...@@ -390,7 +393,7 @@ def resize( ...@@ -390,7 +393,7 @@ def resize(
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[bool] = None, antialias: Optional[Union[str, bool]] = "warn",
) -> Tensor: ) -> Tensor:
r"""Resize the input image to the given size. r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected If the image is torch Tensor, it is expected
...@@ -416,37 +419,48 @@ def resize( ...@@ -416,37 +419,48 @@ def resize(
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. :class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, supported.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
max_size (int, optional): The maximum allowed for the longer edge of max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater the resized image. If the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then than ``max_size`` after being resized according to ``size``,
the image is resized again so that the longer edge is equal to ``size`` will be overruled so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e the ``max_size``.
smaller edge may be shorter than ``size``. This is only supported As a result, the smaller edge may be shorter than ``size``. This
if ``size`` is an int (or a sequence of length 1 in torchscript is only supported if ``size`` is an int (or a sequence of length
mode). 1 in torchscript mode).
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias antialias (bool, optional): Whether to apply antialiasing.
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for It only affects **tensors** with bilinear or bicubic modes and it is
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes. ignored otherwise: on PIL images, antialiasing is always applied on
This can help making the output for PIL images and tensors closer. bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
Returns: Returns:
PIL Image or Tensor: Resized image. PIL Image or Tensor: Resized image.
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(resize) _log_api_usage_once(resize)
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
if not isinstance(interpolation, InterpolationMode): raise TypeError(
raise TypeError("Argument interpolation should be a InterpolationMode") "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if isinstance(size, (list, tuple)): if isinstance(size, (list, tuple)):
if len(size) not in [1, 2]: if len(size) not in [1, 2]:
...@@ -464,11 +478,13 @@ def resize( ...@@ -464,11 +478,13 @@ def resize(
size = [size] size = [size]
output_size = _compute_resized_output_size((image_height, image_width), size, max_size) output_size = _compute_resized_output_size((image_height, image_width), size, max_size)
if (image_height, image_width) == output_size: if [image_height, image_width] == output_size:
return img return img
antialias = _check_antialias(img, antialias, interpolation)
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
if antialias is not None and not antialias: if antialias is False:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
pil_interpolation = pil_modes_mapping[interpolation] pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=output_size, interpolation=pil_interpolation) return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)
...@@ -599,7 +615,7 @@ def resized_crop( ...@@ -599,7 +615,7 @@ def resized_crop(
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None, antialias: Optional[Union[str, bool]] = "warn",
) -> Tensor: ) -> Tensor:
"""Crop the given image and resize it to desired size. """Crop the given image and resize it to desired size.
If the image is torch Tensor, it is expected If the image is torch Tensor, it is expected
...@@ -617,13 +633,27 @@ def resized_crop( ...@@ -617,13 +633,27 @@ def resized_crop(
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. :class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, supported.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias antialias (bool, optional): Whether to apply antialiasing.
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for It only affects **tensors** with bilinear or bicubic modes and it is
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes. ignored otherwise: on PIL images, antialiasing is always applied on
This can help making the output for PIL images and tensors closer. bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
Returns: Returns:
PIL Image or Tensor: Cropped image. PIL Image or Tensor: Cropped image.
""" """
...@@ -702,8 +732,7 @@ def perspective( ...@@ -702,8 +732,7 @@ def perspective(
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively. image. If given a number, the value is used for all bands respectively.
...@@ -719,16 +748,12 @@ def perspective( ...@@ -719,16 +748,12 @@ def perspective(
coeffs = _get_perspective_coeffs(startpoints, endpoints) coeffs = _get_perspective_coeffs(startpoints, endpoints)
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
if not isinstance(interpolation, InterpolationMode): raise TypeError(
raise TypeError("Argument interpolation should be a InterpolationMode") "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation] pil_interpolation = pil_modes_mapping[interpolation]
...@@ -802,7 +827,9 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten ...@@ -802,7 +827,9 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
return tl, tr, bl, br, center return tl, tr, bl, br, center
def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]: def ten_crop(
img: Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Generate ten cropped images from the given image. """Generate ten cropped images from the given image.
Crop the given image into four corners and the central crop plus the Crop the given image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default). flipped version of these (horizontal flipping is used by default).
...@@ -854,7 +881,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: ...@@ -854,7 +881,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions. where ... means it can have an arbitrary number of leading dimensions.
brightness_factor (float): How much to adjust the brightness. Can be brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the any non-negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2. original image while 2 increases the brightness by a factor of 2.
Returns: Returns:
...@@ -876,7 +903,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: ...@@ -876,7 +903,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions. where ... means it can have an arbitrary number of leading dimensions.
contrast_factor (float): How much to adjust the contrast. Can be any contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the non-negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2. original image while 2 increases the contrast by a factor of 2.
Returns: Returns:
...@@ -999,7 +1026,7 @@ def _get_inverse_affine_matrix( ...@@ -999,7 +1026,7 @@ def _get_inverse_affine_matrix(
# RotateScaleShear(a, s, (sx, sy)) = # RotateScaleShear(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx) # = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ] # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ] # [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
# [ 0 , 0 , 1 ] # [ 0 , 0 , 1 ]
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
...@@ -1062,8 +1089,7 @@ def rotate( ...@@ -1062,8 +1089,7 @@ def rotate(
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
expand (bool, optional): Optional expansion flag. expand (bool, optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image. If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image. If false or omitted, make the output image the same size as the input image.
...@@ -1085,13 +1111,12 @@ def rotate( ...@@ -1085,13 +1111,12 @@ def rotate(
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(rotate) _log_api_usage_once(rotate)
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise TypeError(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if not isinstance(angle, (int, float)): if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float") raise TypeError("Argument angle should be int or float")
...@@ -1099,9 +1124,6 @@ def rotate( ...@@ -1099,9 +1124,6 @@ def rotate(
if center is not None and not isinstance(center, (list, tuple)): if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence") raise TypeError("Argument center should be a sequence")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation] pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
...@@ -1138,13 +1160,12 @@ def affine( ...@@ -1138,13 +1160,12 @@ def affine(
translate (sequence of integers): horizontal and vertical translations (post-rotation translation) translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
scale (float): overall scale scale (float): overall scale
shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction. shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while If a sequence is specified, the first value corresponds to a shear parallel to the x-axis, while
the second value corresponds to a shear parallel to the y axis. the second value corresponds to a shear parallel to the y-axis.
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively. image. If given a number, the value is used for all bands respectively.
...@@ -1160,13 +1181,12 @@ def affine( ...@@ -1160,13 +1181,12 @@ def affine(
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(affine) _log_api_usage_once(affine)
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise TypeError(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if not isinstance(angle, (int, float)): if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float") raise TypeError("Argument angle should be int or float")
...@@ -1183,9 +1203,6 @@ def affine( ...@@ -1183,9 +1203,6 @@ def affine(
if not isinstance(shear, (numbers.Number, (list, tuple))): if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError("Shear should be either a single value or a sequence of two values") raise TypeError("Shear should be either a single value or a sequence of two values")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(angle, int): if isinstance(angle, int):
angle = float(angle) angle = float(angle)
...@@ -1229,6 +1246,9 @@ def affine( ...@@ -1229,6 +1246,9 @@ def affine(
return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
# Looks like to_grayscale() is a stand-alone functional that is never called
# from the transform classes. Perhaps it's still here for BC? I can't be
# bothered to dig.
@torch.jit.unused @torch.jit.unused
def to_grayscale(img, num_output_channels=1): def to_grayscale(img, num_output_channels=1):
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
...@@ -1290,7 +1310,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool ...@@ -1290,7 +1310,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
h (int): Height of the erased region. h (int): Height of the erased region.
w (int): Width of the erased region. w (int): Width of the erased region.
v: Erasing value. v: Erasing value.
inplace(bool, optional): For in-place operations. By default is set False. inplace(bool, optional): For in-place operations. By default, is set False.
Returns: Returns:
Tensor Image: Erased image. Tensor Image: Erased image.
...@@ -1395,7 +1415,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: ...@@ -1395,7 +1415,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
Args: Args:
img (PIL Image or Tensor): Image to have its colors posterized. img (PIL Image or Tensor): Image to have its colors posterized.
If img is torch Tensor, it should be of type torch.uint8 and If img is torch Tensor, it should be of type torch.uint8, and
it is expected to be in [..., 1 or 3, H, W] format, where ... means it is expected to be in [..., 1 or 3, H, W] format, where ... means
it can have an arbitrary number of leading dimensions. it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB". If img is PIL Image, it is expected to be in mode "L" or "RGB".
...@@ -1442,7 +1462,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: ...@@ -1442,7 +1462,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions. where ... means it can have an arbitrary number of leading dimensions.
sharpness_factor (float): How much to adjust the sharpness. Can be sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the any non-negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2. original image while 2 increases the sharpness by a factor of 2.
Returns: Returns:
...@@ -1527,12 +1547,10 @@ def elastic_transform( ...@@ -1527,12 +1547,10 @@ def elastic_transform(
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. :class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. Default is ``InterpolationMode.BILINEAR``.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
If a tuple of length 3, it is used to fill R, G, B channels respectively. If a tuple of length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant. This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only int or str or tuple value is supported for PIL Image.
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(elastic_transform) _log_api_usage_once(elastic_transform)
...@@ -1572,3 +1590,28 @@ def elastic_transform( ...@@ -1572,3 +1590,28 @@ def elastic_transform(
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
output = to_pil_image(output, mode=img.mode) output = to_pil_image(output, mode=img.mode)
return output return output
# TODO in v0.17: remove this helper and change default of antialias to True everywhere
def _check_antialias(
img: Tensor, antialias: Optional[Union[str, bool]], interpolation: InterpolationMode
) -> Optional[bool]:
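# "warn" is only a sentinel: it triggers the deprecation warning for tensor
# inputs with bilinear/bicubic interpolation and is then mapped to None (the
# current default, i.e. False for tensors) in every case.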
if isinstance(antialias, str): # it should be "warn", but we don't bother checking against that
if isinstance(img, Tensor) and (
interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC
):
warnings.warn(
"The default value of the antialias parameter of all the resizing transforms "
"(Resize(), RandomResizedCrop(), etc.) "
"will change from None to True in v0.17, "
"in order to be consistent across the PIL and Tensor backends. "
"To suppress this warning, directly pass "
"antialias=True (recommended, future default), antialias=None (current default, "
"which means False for Tensors and True for PIL), "
"or antialias=False (only works on Tensors - PIL will still use antialiasing). "
"This also applies if you are using the inference transforms from the models weights: "
"update the call to weights.transforms(antialias=True)."
)
antialias = None
return antialias
import numbers import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np from torchvision.transforms._functional_pil import * # noqa
import torch
from PIL import Image, ImageEnhance, ImageOps
from typing_extensions import Literal
try: warnings.warn(
import accimage "The torchvision.transforms.functional_pil module is deprecated "
except ImportError: "in 0.15 and will be **removed in 0.17**. Please don't rely on it. "
accimage = None "You probably just need to use APIs in "
from . import _pil_constants "torchvision.transforms.functional or in "
"torchvision.transforms.v2.functional."
)
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
if _is_pil_image(img):
if hasattr(img, "getbands"):
channels = len(img.getbands())
else:
channels = img.channels
width, height = img.size
return [channels, height, width]
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
if _is_pil_image(img):
return list(img.size)
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
if hasattr(img, "getbands"):
return len(img.getbands())
else:
return img.channels
raise TypeError(f"Unexpected type {type(img)}")
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.transpose(_pil_constants.FLIP_LEFT_RIGHT)
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.transpose(_pil_constants.FLIP_TOP_BOTTOM)
@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
input_mode = img.mode
if input_mode in {"L", "1", "I", "F"}:
return img
h, s, v = img.convert("HSV").split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition takes care of rotation across boundaries
with np.errstate(over="ignore"):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, "L")
img = Image.merge("HSV", (h, s, v)).convert(input_mode)
return img
@torch.jit.unused
def adjust_gamma(
img: Image.Image,
gamma: float,
gain: float = 1.0,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
input_mode = img.mode
img = img.convert("RGB")
gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
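# e.g. with gain=1 and gamma=2.0, input value 128 maps to
# int(255.999 * (128 / 255) ** 2) == 64, darkening midtones while leaving 0 and
# 255 (almost) fixed.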
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
return img
@torch.jit.unused
def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, list):
padding = tuple(padding)
if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
if isinstance(padding, tuple) and len(padding) == 1:
# Compatibility with `functional_tensor.pad`
padding = padding[0]
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if padding_mode == "constant":
opts = _parse_fill(fill, img, name="fill")
if img.mode == "P":
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, **opts)
image.putpalette(palette)
return image
return ImageOps.expand(img, border=padding, **opts)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, tuple) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, tuple) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_top, pad_right, pad_bottom]
cropping = -np.minimum(p, 0)
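# Negative padding values are interpreted as cropping in these non-constant
# modes, e.g. padding=[-2, 0, -2, 0] removes two columns from the left and right
# edges before any padding is applied.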
if cropping.any():
crop_left, crop_top, crop_right, crop_bottom = cropping
img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))
pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
if img.mode == "P":
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img
img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
return Image.fromarray(img)
@torch.jit.unused
def crop(
img: Image.Image,
top: int,
left: int,
height: int,
width: int,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return img.crop((left, top, left + width, top + height))
@torch.jit.unused
def resize(
img: Image.Image,
size: Union[List[int], int],
interpolation: int = _pil_constants.BILINEAR,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not (isinstance(size, list) and len(size) == 2):
raise TypeError(f"Got inappropriate size arg: {size}")
return img.resize(tuple(size[::-1]), interpolation)
@torch.jit.unused
def _parse_fill(
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
img: Image.Image,
name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
# Process fill color for affine transforms
num_channels = get_image_num_channels(img)
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_channels > 1:
fill = tuple([fill] * num_channels)
if isinstance(fill, (list, tuple)):
if len(fill) != num_channels:
msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
raise ValueError(msg.format(len(fill), num_channels))
fill = tuple(fill)
if img.mode != "F":
if isinstance(fill, (list, tuple)):
fill = tuple(int(x) for x in fill)
else:
fill = int(fill)
return {name: fill}
@torch.jit.unused
def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = _pil_constants.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
output_size = img.size
opts = _parse_fill(fill, img)
return img.transform(output_size, _pil_constants.AFFINE, matrix, interpolation, **opts)
@torch.jit.unused
def rotate(
img: Image.Image,
angle: float,
interpolation: int = _pil_constants.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
opts = _parse_fill(fill, img)
return img.rotate(angle, interpolation, expand, center, **opts)
@torch.jit.unused
def perspective(
img: Image.Image,
perspective_coeffs: List[float],
interpolation: int = _pil_constants.BICUBIC,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
opts = _parse_fill(fill, img)
return img.transform(img.size, _pil_constants.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if num_output_channels == 1:
img = img.convert("L")
elif num_output_channels == 3:
img = img.convert("L")
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, "RGB")
else:
raise ValueError("num_output_channels should be either 1 or 3")
return img
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.invert(img)
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
return img
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
return ImageOps.equalize(img)
import warnings import warnings
from typing import List, Optional, Tuple, Union
import torch from torchvision.transforms._functional_tensor import * # noqa
from torch import Tensor
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad
warnings.warn(
def _is_tensor_a_torch_image(x: Tensor) -> bool: "The torchvision.transforms.functional_tensor module is deprecated "
return x.ndim >= 2 "in 0.15 and will be **removed in 0.17**. Please don't rely on it. "
"You probably just need to use APIs in "
"torchvision.transforms.functional or in "
def _assert_image_tensor(img: Tensor) -> None: "torchvision.transforms.v2.functional."
if not _is_tensor_a_torch_image(img): )
raise TypeError("Tensor is not a torch image.")
def _assert_threshold(img: Tensor, threshold: float) -> None:
bound = 1 if img.is_floating_point() else 255
if threshold > bound:
raise TypeError("Threshold should be less than bound of img.")
def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
height, width = img.shape[-2:]
return [channels, height, width]
def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]
def get_image_num_channels(img: Tensor) -> int:
_assert_image_tensor(img)
if img.ndim == 2:
return 1
elif img.ndim > 2:
return img.shape[-3]
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
def _max_value(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 255
elif dtype == torch.int8:
return 127
elif dtype == torch.int16:
return 32767
elif dtype == torch.int32:
return 2147483647
elif dtype == torch.int64:
return 9223372036854775807
else:
return 1
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = get_dimensions(img)[0]
if c not in permitted:
raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
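# Worked examples (sketch): uint8 -> float32 divides by 255, so 255 maps to 1.0;
# float32 -> uint8 multiplies by 255 + 1 - 1e-3, so values just below 1.0 also
# land on 255; uint8 -> int16 multiplies by (32767 + 1) // (255 + 1) == 128.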
if image.dtype == dtype:
return image
if image.is_floating_point():
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
max_val = float(_max_value(dtype))
result = image.mul(max_val + 1.0 - eps)
return result.to(dtype)
else:
input_max = float(_max_value(image.dtype))
# int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
image = image.to(dtype)
return image / input_max
output_max = float(_max_value(dtype))
# int to int
if input_max > output_max:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode="floor")
return image.to(dtype)
else:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image * factor can produce different results
factor = int((output_max + 1) // (input_max + 1))
image = image.to(dtype)
return image * factor
def vflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-2)
def hflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-1)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
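# Out-of-bounds regions are zero-padded rather than raising, e.g.
# crop(img, -2, -2, h + 4, w + 4) returns the image surrounded by a 2-pixel
# black border.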
_assert_image_tensor(img)
_, h, w = get_dimensions(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [
max(-left + min(0, right), 0),
max(-top + min(0, bottom), 0),
max(right - max(w, left), 0),
max(bottom - max(h, top), 0),
]
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return img[..., top:bottom, left:right]
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
if num_output_channels not in (1, 3):
raise ValueError("num_output_channels should be either 1 or 3")
if img.shape[-3] == 3:
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
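# The 0.2989 / 0.587 / 0.114 weights are the ITU-R BT.601 luma coefficients,
# the same ones PIL uses when converting to mode "L".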
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
else:
l_img = img.clone()
if num_output_channels == 3:
return l_img.expand(img.shape)
return l_img
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
return _blend(img, torch.zeros_like(img), brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [3, 1])
c = get_dimensions(img)[0]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
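# Illustrative sketch (not from the original file): adjust_brightness blends the image
# with an all-zeros image, so factor 0.0 gives black and 1.0 returns the input;
# adjust_contrast blends with the grayscale mean, so factor 1.0 is likewise a no-op.
def _example_adjust_brightness_contrast() -> None:
    img = torch.rand(3, 4, 4)
    assert torch.equal(adjust_brightness(img, 0.0), torch.zeros_like(img))
    assert torch.allclose(adjust_brightness(img, 1.0), img)
    assert torch.allclose(adjust_contrast(img, 1.0), img)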
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
return img_hue_adj
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError("Input img should be a Tensor.")
_assert_channels(img, [1, 3])
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = convert_image_dtype(result, torch.float32)
result = (gain * result**gamma).clamp(0, 1)
result = convert_image_dtype(result, dtype)
return result
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = 1.0 if img1.is_floating_point() else 255.0
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img: Tensor) -> Tensor:
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases the S and H channels where `maxc = minc`. This avoids NaN
# in the results, because
# + the S channel has a division by `maxc`, which is zero only if `maxc = minc`
# + the H channel has a division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occurring, so
# we don't need to deal with it in case the NaN ever gets saved in a buffer during
# backprop (if that is ever supported); either way, it doesn't hurt to do so.
eqc = maxc == minc
cr = maxc - minc
# Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` holds is fine.
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`, so the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6`; it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing the denominator with 1 when `eqc` holds is fine.
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img: Tensor) -> Tensor:
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
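# Illustrative sketch (not from the original file): _rgb2hsv and _hsv2rgb are
# (numerically) inverse to each other on float images in [0, 1], which is what
# adjust_hue relies on when it shifts the hue channel.
def _example_hsv_round_trip() -> None:
    img = torch.rand(3, 6, 6)
    recon = _hsv2rgb(_rgb2hsv(img))
    assert torch.allclose(recon, img, atol=1e-5)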
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
neg_min_padding = [-min(x, 0) for x in padding]
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
_x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
_y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
if torch.jit.is_scripting():
# This may be unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
return [pad_left, pad_right, pad_top, pad_bottom]
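# Illustrative sketch (not from the original file): the padding argument is expanded to
# the (left, right, top, bottom) order expected by torch.nn.functional.pad.
def _example_parse_pad_padding() -> None:
    assert _parse_pad_padding(2) == [2, 2, 2, 2]
    assert _parse_pad_padding([1, 3]) == [1, 1, 3, 3]        # (lr, tb)
    assert _parse_pad_padding([1, 2, 3, 4]) == [1, 3, 2, 4]  # (l, t, r, b) -> (l, r, t, b)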
def pad(
img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
_assert_image_tensor(img)
if fill is None:
fill = 0
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list):
# TODO: Jit is failing on loading this op when scripted and saved
# https://github.com/pytorch/pytorch/issues/81100
if len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
p = _parse_pad_padding(padding)
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporarily cast the input tensor to float32
# until the PyTorch issue is resolved:
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
if padding_mode in ("reflect", "replicate"):
img = torch_pad(img, p, mode=padding_mode)
else:
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img
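# Illustrative sketch (not from the original file): constant padding grows the width by
# left+right and the height by top+bottom, while "symmetric" mode is routed to the
# index-based _pad_symmetric helper above.
def _example_pad() -> None:
    img = torch.rand(3, 4, 4)
    out = pad(img, [1, 2], fill=0, padding_mode="constant")
    assert out.shape == (3, 8, 6)  # H: 4 + 2 + 2, W: 4 + 1 + 1
    sym = pad(img, [1, 1, 1, 1], padding_mode="symmetric")
    assert sym.shape == (3, 6, 6)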
def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
antialias: Optional[bool] = None,
) -> Tensor:
_assert_image_tensor(img)
if isinstance(size, tuple):
size = list(size)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
return img
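# Illustrative sketch (not from the original file): resize interpolates to the requested
# (H, W); a uint8 input is cast to float internally and cast back afterwards.
def _example_resize() -> None:
    img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
    out = resize(img, [8, 12], interpolation="bilinear", antialias=True)
    assert out.shape == (3, 8, 12) and out.dtype == torch.uint8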
def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[Union[int, float, List[float]]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor")
_assert_image_tensor(img)
if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
if matrix is not None and len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill should be either int, float, tuple or list")
# Check fill
num_channels = get_dimensions(img)[0]
if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# it is better to round before casting to an integer dtype
img = torch.round(img)
img = img.to(out_dtype)
return img
def _apply_grid_transform(
img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
if img.shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
# Append a dummy mask for customized fill colors; this should be faster than calling grid_sample() twice
if fill is not None:
mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, mask), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == "nearest":
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
img = img * mask + (1.0 - mask) * fill_img
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def _gen_affine_grid(
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Differences from AffineGridGenerator are that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by another image size, so that it covers the "expand" option as in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
return _apply_grid_transform(img, grid, interpolation, fill=fill)
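# Illustrative sketch (not from the original file): with the identity 2x3 matrix the
# generated grid samples every pixel at its own center, so the output matches the input
# up to bilinear resampling error.
def _example_affine_identity() -> None:
    img = torch.rand(1, 3, 8, 8)
    out = affine(img, matrix=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], interpolation="bilinear")
    assert out.shape == img.shape
    assert torch.allclose(out, img, atol=1e-4)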
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
# Inspired by the PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are the four corner points: Top-Left, Bottom-Left, Bottom-Right and Top-Right.
# Points are shifted because the torch affine-matrix convention places the origin at the
# image center, i.e. (0, 0) corresponds to the pivot point (w * 0.5, h * 0.5).
pts = torch.tensor(
[
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
min_vals += torch.tensor((w * 0.5, h * 0.5))
max_vals += torch.tensor((w * 0.5, h * 0.5))
# Truncate precision to 1e-4 to avoid ceil() pushing tiny values like 1e-15 up to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1]) # w, h
def rotate(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.view(1, oh, ow, 2)
def perspective(
img: Tensor,
perspective_coeffs: List[float],
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor.")
_assert_image_tensor(img)
_assert_grid_transform_inputs(
img,
matrix=None,
interpolation=interpolation,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs,
)
ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
_assert_image_tensor(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
img = torch_pad(img, padding, mode="reflect")
img = conv2d(img, kernel, groups=img.shape[-3])
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
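# Illustrative sketch (not from the original file): the Gaussian kernel is separable and
# normalized to sum to 1, so blurring a constant image leaves it (numerically) unchanged
# and the reflect padding keeps the output shape equal to the input shape.
def _example_gaussian_blur() -> None:
    img = torch.full((3, 9, 9), 0.25)
    out = gaussian_blur(img, kernel_size=[3, 3], sigma=[1.0, 1.0])
    assert out.shape == img.shape
    assert torch.allclose(out, img, atol=1e-6)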
def invert(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
return bound - img
def posterize(img: Tensor, bits: int) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
_assert_threshold(img, threshold)
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
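# Illustrative sketch (not from the original file): pixels at or above the threshold are
# inverted (255 - v for uint8), pixels below it are left untouched.
def _example_solarize() -> None:
    img = torch.tensor([[[10, 100, 200, 250]]], dtype=torch.uint8)
    assert solarize(img, threshold=128).tolist() == [[[10, 100, 55, 5]]]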
def _blurred_degenerate_image(img: Tensor) -> Tensor:
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel[1, 1] = 5.0
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
result = img.clone()
result[..., 1:-1, 1:-1] = result_tmp
return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if img.size(-1) <= 2 or img.size(-2) <= 2:
return img
return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
bound = 1.0 if img.is_floating_point() else 255.0
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
scale = bound / (maximum - minimum)
eq_idxs = torch.isfinite(scale).logical_not()
minimum[eq_idxs] = 0
scale[eq_idxs] = 1
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan: Tensor) -> Tensor:
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.view(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
if step == 0:
return img_chan
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
def equalize(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if not (3 <= img.ndim <= 4):
raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
if img.ndim == 3:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
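# Illustrative sketch (not from the original file): equalize builds a per-channel lookup
# table from the histogram; a constant channel hits the `step == 0` short-circuit above
# and is returned unchanged.
def _example_equalize() -> None:
    img = torch.full((3, 8, 8), 7, dtype=torch.uint8)
    assert torch.equal(equalize(img), img)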
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
_assert_image_tensor(tensor)
if not tensor.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
if tensor.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
)
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
return tensor.sub_(mean).div_(std)
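# Illustrative sketch (not from the original file): normalize subtracts the per-channel
# mean and divides by the per-channel std, broadcasting over the spatial dimensions; by
# default the input is cloned rather than modified in place.
def _example_normalize() -> None:
    img = torch.full((3, 2, 2), 0.5)
    out = normalize(img, mean=[0.5, 0.5, 0.5], std=[0.25, 0.5, 1.0])
    assert torch.allclose(out, torch.zeros_like(out))
    assert float(img.max()) == 0.5  # input left untouched with inplace=False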
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
_assert_image_tensor(img)
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img
def _create_identity_grid(size: List[int]) -> Tensor:
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
def elastic_transform(
img: Tensor,
displacement: Tensor,
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
size = list(img.shape[-2:])
displacement = displacement.to(img.device)
identity_grid = _create_identity_grid(size)
grid = identity_grid.to(img.device) + displacement
return _apply_grid_transform(img, grid, interpolation, fill)
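# Illustrative sketch (not from the original file): a zero displacement field makes
# elastic_transform an (approximate) identity, because the sampling grid then coincides
# with the identity grid built by _create_identity_grid.
def _example_elastic_transform() -> None:
    img = torch.rand(1, 3, 8, 8)
    displacement = torch.zeros(1, 8, 8, 2)
    out = elastic_transform(img, displacement, interpolation="bilinear")
    assert out.shape == img.shape
    assert torch.allclose(out, img, atol=1e-4)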
...@@ -3,7 +3,7 @@ import numbers ...@@ -3,7 +3,7 @@ import numbers
import random import random
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -105,7 +105,9 @@ class Compose: ...@@ -105,7 +105,9 @@ class Compose:
class ToTensor: class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. """Convert a PIL Image or ndarray to tensor and scale the values accordingly.
This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
...@@ -139,7 +141,9 @@ class ToTensor: ...@@ -139,7 +141,9 @@ class ToTensor:
class PILToTensor: class PILToTensor:
"""Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript. """Convert a PIL Image to a tensor of the same type - this does not scale values.
This transform does not support torchscript.
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
""" """
...@@ -166,7 +170,8 @@ class PILToTensor: ...@@ -166,7 +170,8 @@ class PILToTensor:
class ConvertImageDtype(torch.nn.Module): class ConvertImageDtype(torch.nn.Module):
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly """Convert a tensor image to the given ``dtype`` and scale the values accordingly.
This function does not support PIL Image. This function does not support PIL Image.
Args: Args:
...@@ -194,19 +199,21 @@ class ConvertImageDtype(torch.nn.Module): ...@@ -194,19 +199,21 @@ class ConvertImageDtype(torch.nn.Module):
class ToPILImage: class ToPILImage:
"""Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript. """Convert a tensor or an ndarray to PIL Image
This transform does not support torchscript.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL Image while preserving the value range. H x W x C to a PIL Image while adjusting the value range depending on the ``mode``.
Args: Args:
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
If ``mode`` is ``None`` (default) there are some assumptions made about the input data: If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
- If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
- If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
- If the input has 2 channels, the ``mode`` is assumed to be ``LA``. - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
- If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, ``short``).
``short``).
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
""" """
...@@ -276,7 +283,7 @@ class Normalize(torch.nn.Module): ...@@ -276,7 +283,7 @@ class Normalize(torch.nn.Module):
class Resize(torch.nn.Module): class Resize(torch.nn.Module):
"""Resize the input image to the given size. """Resize the input image to the given size.
If the image is torch Tensor, it is expected If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions to have [..., H, W] shape, where ... means a maximum of two leading dimensions
.. warning:: .. warning::
The output image might be different depending on its type: when downsampling, the interpolation of PIL images The output image might be different depending on its type: when downsampling, the interpolation of PIL images
...@@ -296,25 +303,38 @@ class Resize(torch.nn.Module): ...@@ -296,25 +303,38 @@ class Resize(torch.nn.Module):
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
max_size (int, optional): The maximum allowed for the longer edge of max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater the resized image. If the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then than ``max_size`` after being resized according to ``size``,
the image is resized again so that the longer edge is equal to ``size`` will be overruled so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e the ``max_size``.
smaller edge may be shorter than ``size``. This is only supported As a result, the smaller edge may be shorter than ``size``. This
if ``size`` is an int (or a sequence of length 1 in torchscript is only supported if ``size`` is an int (or a sequence of length
mode). 1 in torchscript mode).
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias antialias (bool, optional): Whether to apply antialiasing.
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for It only affects **tensors** with bilinear or bicubic modes and it is
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes. ignored otherwise: on PIL images, antialiasing is always applied on
This can help making the output for PIL images and tensors closer. bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
""" """
def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None): def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
if not isinstance(size, (int, Sequence)): if not isinstance(size, (int, Sequence)):
...@@ -324,12 +344,7 @@ class Resize(torch.nn.Module): ...@@ -324,12 +344,7 @@ class Resize(torch.nn.Module):
self.size = size self.size = size
self.max_size = max_size self.max_size = max_size
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
...@@ -752,8 +767,7 @@ class RandomPerspective(torch.nn.Module): ...@@ -752,8 +767,7 @@ class RandomPerspective(torch.nn.Module):
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
fill (sequence or number): Pixel fill value for the area outside the transformed fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively. image. Default is ``0``. If given a number, the value is used for all bands respectively.
""" """
...@@ -763,12 +777,7 @@ class RandomPerspective(torch.nn.Module): ...@@ -763,12 +777,7 @@ class RandomPerspective(torch.nn.Module):
_log_api_usage_once(self) _log_api_usage_once(self)
self.p = p self.p = p
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
...@@ -865,14 +874,27 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -865,14 +874,27 @@ class RandomResizedCrop(torch.nn.Module):
resizing. resizing.
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. antialias (bool, optional): Whether to apply antialiasing.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias It only affects **tensors** with bilinear or bicubic modes and it is
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for ignored otherwise: on PIL images, antialiasing is always applied on
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes. bilinear or bicubic modes; on other modes (for PIL images and
This can help making the output for PIL images and tensors closer. tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
""" """
def __init__( def __init__(
...@@ -881,7 +903,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -881,7 +903,7 @@ class RandomResizedCrop(torch.nn.Module):
scale=(0.08, 1.0), scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0), ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation=InterpolationMode.BILINEAR, interpolation=InterpolationMode.BILINEAR,
antialias: Optional[bool] = None, antialias: Optional[Union[str, bool]] = "warn",
): ):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
...@@ -894,12 +916,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -894,12 +916,7 @@ class RandomResizedCrop(torch.nn.Module):
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)") warnings.warn("Scale and ratio should be of kind (min, max)")
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
...@@ -967,7 +984,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -967,7 +984,7 @@ class RandomResizedCrop(torch.nn.Module):
format_string = self.__class__.__name__ + f"(size={self.size}" format_string = self.__class__.__name__ + f"(size={self.size}"
format_string += f", scale={tuple(round(s, 4) for s in self.scale)}" format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}" format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
format_string += f", interpolation={interpolate_str})" format_string += f", interpolation={interpolate_str}"
format_string += f", antialias={self.antialias})" format_string += f", antialias={self.antialias})"
return format_string return format_string
...@@ -1039,7 +1056,7 @@ class TenCrop(torch.nn.Module): ...@@ -1039,7 +1056,7 @@ class TenCrop(torch.nn.Module):
Example: Example:
>>> transform = Compose([ >>> transform = Compose([
>>> TenCrop(size), # this is a list of PIL Images >>> TenCrop(size), # this is a tuple of PIL Images
>>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ]) >>> ])
>>> #In your test loop you can do the following: >>> #In your test loop you can do the following:
...@@ -1108,6 +1125,11 @@ class LinearTransformation(torch.nn.Module): ...@@ -1108,6 +1125,11 @@ class LinearTransformation(torch.nn.Module):
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
) )
if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)
self.transformation_matrix = transformation_matrix self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector self.mean_vector = mean_vector
...@@ -1135,7 +1157,8 @@ class LinearTransformation(torch.nn.Module): ...@@ -1135,7 +1157,8 @@ class LinearTransformation(torch.nn.Module):
) )
flat_tensor = tensor.view(-1, n) - self.mean_vector flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
tensor = transformed_tensor.view(shape) tensor = transformed_tensor.view(shape)
return tensor return tensor
...@@ -1160,7 +1183,7 @@ class ColorJitter(torch.nn.Module): ...@@ -1160,7 +1183,7 @@ class ColorJitter(torch.nn.Module):
or the given [min, max]. Should be non negative numbers. or the given [min, max]. Should be non negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter contrast. contrast (float or tuple of float (min, max)): How much to jitter contrast.
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non negative numbers. or the given [min, max]. Should be non-negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter saturation. saturation (float or tuple of float (min, max)): How much to jitter saturation.
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non negative numbers. or the given [min, max]. Should be non negative numbers.
...@@ -1172,7 +1195,13 @@ class ColorJitter(torch.nn.Module): ...@@ -1172,7 +1195,13 @@ class ColorJitter(torch.nn.Module):
or use an interpolation that generates negative values before using this function. or use an interpolation that generates negative values before using this function.
""" """
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): def __init__(
self,
brightness: Union[float, Tuple[float, float]] = 0,
contrast: Union[float, Tuple[float, float]] = 0,
saturation: Union[float, Tuple[float, float]] = 0,
hue: Union[float, Tuple[float, float]] = 0,
) -> None:
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
self.brightness = self._check_input(brightness, "brightness") self.brightness = self._check_input(brightness, "brightness")
...@@ -1189,16 +1218,19 @@ class ColorJitter(torch.nn.Module): ...@@ -1189,16 +1218,19 @@ class ColorJitter(torch.nn.Module):
if clip_first_on_zero: if clip_first_on_zero:
value[0] = max(value[0], 0.0) value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2: elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]: value = [float(value[0]), float(value[1])]
raise ValueError(f"{name} values should be between {bound}")
else: else:
raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
# if value is 0 or (1., 1.) for brightness/contrast/saturation # if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing # or (0., 0.) for hue, do nothing
if value[0] == value[1] == center: if value[0] == value[1] == center:
value = None return None
return value else:
return tuple(value)
@staticmethod @staticmethod
def get_params( def get_params(
...@@ -1279,8 +1311,7 @@ class RandomRotation(torch.nn.Module): ...@@ -1279,8 +1311,7 @@ class RandomRotation(torch.nn.Module):
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
expand (bool, optional): Optional expansion flag. expand (bool, optional): Optional expansion flag.
If true, expands the output to make it large enough to hold the entire rotated image. If true, expands the output to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image. If false or omitted, make the output image the same size as the input image.
...@@ -1298,12 +1329,7 @@ class RandomRotation(torch.nn.Module): ...@@ -1298,12 +1329,7 @@ class RandomRotation(torch.nn.Module):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
...@@ -1381,16 +1407,15 @@ class RandomAffine(torch.nn.Module): ...@@ -1381,16 +1407,15 @@ class RandomAffine(torch.nn.Module):
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
randomly sampled from the range a <= scale <= b. Will keep original scale by default. randomly sampled from the range a <= scale <= b. Will keep original scale by default.
shear (sequence or number, optional): Range of degrees to select from. shear (sequence or number, optional): Range of degrees to select from.
If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
Will not apply shear by default. Will not apply shear by default.
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
fill (sequence or number): Pixel fill value for the area outside the transformed fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively. image. Default is ``0``. If given a number, the value is used for all bands respectively.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
...@@ -1413,12 +1438,7 @@ class RandomAffine(torch.nn.Module): ...@@ -1413,12 +1438,7 @@ class RandomAffine(torch.nn.Module):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
...@@ -1602,7 +1622,7 @@ class RandomGrayscale(torch.nn.Module): ...@@ -1602,7 +1622,7 @@ class RandomGrayscale(torch.nn.Module):
class RandomErasing(torch.nn.Module): class RandomErasing(torch.nn.Module):
"""Randomly selects a rectangle region in an torch Tensor image and erases its pixels. """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
This transform does not support PIL Image. This transform does not support PIL Image.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
...@@ -1707,11 +1727,11 @@ class RandomErasing(torch.nn.Module): ...@@ -1707,11 +1727,11 @@ class RandomErasing(torch.nn.Module):
# cast self.value to script acceptable type # cast self.value to script acceptable type
if isinstance(self.value, (int, float)): if isinstance(self.value, (int, float)):
value = [self.value] value = [float(self.value)]
elif isinstance(self.value, str): elif isinstance(self.value, str):
value = None value = None
elif isinstance(self.value, tuple): elif isinstance(self.value, (list, tuple)):
value = list(self.value) value = [float(v) for v in self.value]
else: else:
value = self.value value = self.value
...@@ -1938,7 +1958,7 @@ class RandomAdjustSharpness(torch.nn.Module): ...@@ -1938,7 +1958,7 @@ class RandomAdjustSharpness(torch.nn.Module):
Args: Args:
sharpness_factor (float): How much to adjust the sharpness. Can be sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the any non-negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2. original image while 2 increases the sharpness by a factor of 2.
p (float): probability of the image being sharpened. Default value is 0.5 p (float): probability of the image being sharpened. Default value is 0.5
""" """
...@@ -2045,7 +2065,7 @@ class ElasticTransform(torch.nn.Module): ...@@ -2045,7 +2065,7 @@ class ElasticTransform(torch.nn.Module):
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (sequence or number): Pixel fill value for the area outside the transformed fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively. image. Default is ``0``. If given a number, the value is used for all bands respectively.
...@@ -2086,17 +2106,16 @@ class ElasticTransform(torch.nn.Module): ...@@ -2086,17 +2106,16 @@ class ElasticTransform(torch.nn.Module):
self.sigma = sigma self.sigma = sigma
# Backward compatibility with integer value
if isinstance(interpolation, int): if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
if not isinstance(fill, (int, float)): if isinstance(fill, (int, float)):
raise TypeError(f"fill should be int or float. Got {type(fill)}") fill = [float(fill)]
elif isinstance(fill, (list, tuple)):
fill = [float(f) for f in fill]
else:
raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
self.fill = fill self.fill = fill
@staticmethod @staticmethod
...@@ -2123,7 +2142,7 @@ class ElasticTransform(torch.nn.Module): ...@@ -2123,7 +2142,7 @@ class ElasticTransform(torch.nn.Module):
def forward(self, tensor: Tensor) -> Tensor: def forward(self, tensor: Tensor) -> Tensor:
""" """
Args: Args:
img (PIL Image or Tensor): Image to be transformed. tensor (PIL Image or Tensor): Image to be transformed.
Returns: Returns:
PIL Image or Tensor: Transformed image. PIL Image or Tensor: Transformed image.
...@@ -2133,9 +2152,9 @@ class ElasticTransform(torch.nn.Module): ...@@ -2133,9 +2152,9 @@ class ElasticTransform(torch.nn.Module):
return F.elastic_transform(tensor, displacement, self.interpolation, self.fill) return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + "(alpha=" format_string = self.__class__.__name__
format_string += str(self.alpha) + ")" format_string += f"(alpha={self.alpha}"
format_string += ", (sigma=" + str(self.sigma) + ")" format_string += f", sigma={self.sigma}"
format_string += ", interpolation={self.interpolation}" format_string += f", interpolation={self.interpolation}"
format_string += ", fill={self.fill})" format_string += f", fill={self.fill})"
return format_string return format_string
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import functional # usort: skip
from ._transform import Transform # usort: skip
from ._augment import CutMix, MixUp, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomChannelPermutation,
RandomEqualize,
RandomGrayscale,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
RandomSolarize,
)
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
CenterCrop,
ElasticTransform,
FiveCrop,
Pad,
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
RandomIoUCrop,
RandomPerspective,
RandomResize,
RandomResizedCrop,
RandomRotation,
RandomShortestSize,
RandomVerticalFlip,
RandomZoomOut,
Resize,
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat
from ._misc import (
ConvertImageDtype,
GaussianBlur,
Identity,
Lambda,
LinearTransformation,
Normalize,
SanitizeBoundingBoxes,
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
from ._deprecated import ToTensor # usort: skip
import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Tuple
import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
class RandomErasing(_RandomApplyTransform):
"""[BETA] Randomly select a rectangle region in the input image or video and erase its pixels.
.. v2betastatus:: RandomErasing transform
This transform does not support PIL Image.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
Args:
p (float, optional): probability that the random erasing operation will be performed.
scale (tuple of float, optional): range of proportion of erased area against input image.
ratio (tuple of float, optional): range of aspect ratio of erased area.
value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If the str 'random' is given, each pixel is erased with random values.
inplace (bool, optional): whether to make this transform in-place. Default is False.
Returns:
Erased input.
Example:
>>> from torchvision.transforms import v2 as transforms
>>>
>>> transform = transforms.Compose([
>>> transforms.RandomHorizontalFlip(),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> transforms.RandomErasing(),
>>> ])
"""
_v1_transform_cls = _transforms.RandomErasing
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return dict(
super()._extract_params_for_v1_transform(),
value="random" if self.value is None else self.value,
)
def __init__(
self,
p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0.0,
inplace: bool = False,
):
super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
self.scale = scale
self.ratio = ratio
if isinstance(value, (int, float)):
self.value = [float(value)]
elif isinstance(value, str):
self.value = None
elif isinstance(value, (list, tuple)):
self.value = [float(v) for v in value]
else:
self.value = value
self.inplace = inplace
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(functional, inpt, *args, **kwargs)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
if self.value is not None and not (len(self.value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
)
area = img_h * img_w
log_ratio = self._log_ratio
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(self.value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, None
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["v"] is not None:
inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)
return inpt
class _BaseMixUpCutMix(Transform):
def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
super().__init__()
self.alpha = float(alpha)
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
self.num_classes = num_classes
self._labels_getter = _parse_labels_getter(labels_getter)
def forward(self, *inputs):
inputs = inputs if len(inputs) > 1 else inputs[0]
flat_inputs, spec = tree_flatten(inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)
if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
labels = self._labels_getter(inputs)
if not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
elif labels.ndim != 1:
raise ValueError(
f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
)
params = {
"labels": labels,
"batch_size": labels.shape[0],
**self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
),
}
# By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
# after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]
return tree_unflatten(flat_outputs, spec)
def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
if inpt.ndim != expected_num_dims:
raise ValueError(
f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
)
if inpt.shape[0] != batch_size:
raise ValueError(
f"The batch size of the image or video does not match the batch size of the labels: "
f"{inpt.shape[0]} != {batch_size}."
)
def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
label = one_hot(label, num_classes=self.num_classes)
if not label.dtype.is_floating_point:
label = label.float()
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
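# Illustrative sketch (not library code): the label mixing above is a convex
# combination of each one-hot row with the row of the previous sample in the
# batch (roll by 1 along dim 0). Assuming three classes and lam = 0.7:
import torch
from torch.nn.functional import one_hot
_labels = torch.tensor([0, 1, 2])
_lam = 0.7
_oh = one_hot(_labels, num_classes=3).float()
_mixed = _oh.roll(1, 0).mul(1.0 - _lam).add(_oh.mul(_lam))
# _mixed[0] == [0.7, 0.0, 0.3]: 70% class 0, 30% class 2 (the rolled-in row).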
class MixUp(_BaseMixUpCutMix):
"""[BETA] Apply MixUp to the provided batch of images and labels.
.. v2betastatus:: MixUp transform
Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
.. note::
This transform is meant to be used on **batches** of samples, not
individual images. See
:ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
examples.
The sample pairing is deterministic and done by matching consecutive
samples in the batch, so the batch needs to be shuffled (this is an
implementation detail, not a guaranteed convention.)
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
into a tensor of shape ``(batch_size, num_classes)``.
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=lam)
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
return inpt
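# Hedged usage sketch for MixUp: batched float images plus integer labels; the
# default labels_getter picks the second argument (shapes are illustrative).
import torch
from torchvision.transforms import v2
mixup = v2.MixUp(alpha=1.0, num_classes=10)
imgs = torch.rand(4, 3, 32, 32)              # (B, C, H, W)
targets = torch.randint(0, 10, (4,))         # (B,)
imgs, targets = mixup(imgs, targets)         # targets becomes (B, num_classes)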
class CutMix(_BaseMixUpCutMix):
"""[BETA] Apply CutMix to the provided batch of images and labels.
.. v2betastatus:: CutMix transform
Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
<https://arxiv.org/abs/1905.04899>`_.
.. note::
This transform is meant to be used on **batches** of samples, not
individual images. See
:ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
examples.
The sample pairing is deterministic and done by matching consecutive
samples in the batch, so the batch needs to be shuffled (this is an
implementation detail, not a guaranteed convention.)
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
into a tensor of shape ``(batch_size, num_classes)``.
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
lam = float(self._dist.sample(())) # type: ignore[arg-type]
H, W = query_size(flat_inputs)
r_x = torch.randint(W, size=(1,))
r_y = torch.randint(H, size=(1,))
r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
r_h_half = int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
box = (x1, y1, x2, y2)
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=params["lam_adjusted"])
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
x1, y1, x2, y2 = params["box"]
rolled = inpt.roll(1, 0)
output = inpt.clone()
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
return inpt
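# Hedged sketch: a common pattern is to pick CutMix or MixUp at random for each
# batch via RandomChoice (values illustrative).
import torch
from torchvision.transforms import v2
cutmix_or_mixup = v2.RandomChoice([v2.CutMix(num_classes=10), v2.MixUp(num_classes=10)])
imgs = torch.rand(8, 3, 224, 224)
targets = torch.randint(0, 10, (8,))
imgs, targets = cutmix_or_mixup(imgs, targets)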
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
class _AutoAugmentBase(Transform):
def __init__(
self,
*,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
self.fill = fill
self._fill = _setup_fill_arg(fill)
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if isinstance(params["fill"], dict):
raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")
return params
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _flatten_and_extract_image_or_video(
self,
inputs: Any,
unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
needs_transform_list = self._needs_transform_list(flat_inputs)
image_or_videos = []
for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
if needs_transform and check_type(
inpt,
(
tv_tensors.Image,
PIL.Image.Image,
is_pure_tensor,
tv_tensors.Video,
),
):
image_or_videos.append((idx, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
if not image_or_videos:
raise TypeError("Found no image in the sample.")
if len(image_or_videos) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image or video, "
f"but found {len(image_or_videos)}."
)
idx, image_or_video = image_or_videos[0]
return (flat_inputs, spec, idx), image_or_video
def _unflatten_and_insert_image_or_video(
self,
flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
image_or_video: ImageOrVideo,
) -> Any:
flat_inputs, spec, idx = flat_inputs_with_spec
flat_inputs[idx] = image_or_video
return tree_unflatten(flat_inputs, spec)
def _apply_image_or_video_transform(
self,
image: ImageOrVideo,
transform_id: str,
magnitude: float,
interpolation: Union[InterpolationMode, int],
fill: Dict[Union[Type, str], _FillTypeJIT],
) -> ImageOrVideo:
fill_ = _get_fill(fill, type(image))
if transform_id == "Identity":
return image
elif transform_id == "ShearX":
# magnitude should be arctan(magnitude)
# official autoaug: (1, level, 0, 0, 1, 0)
# https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
# compared to
# torchvision: (1, tan(level), 0, 0, 1, 0)
# https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(math.atan(magnitude)), 0.0],
interpolation=interpolation,
fill=fill_,
center=[0, 0],
)
elif transform_id == "ShearY":
# magnitude should be arctan(magnitude)
# See above
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(math.atan(magnitude))],
interpolation=interpolation,
fill=fill_,
center=[0, 0],
)
elif transform_id == "TranslateX":
return F.affine(
image,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_,
)
elif transform_id == "TranslateY":
return F.affine(
image,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_,
)
elif transform_id == "Rotate":
return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
elif transform_id == "Brightness":
return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
elif transform_id == "Color":
return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
elif transform_id == "Contrast":
return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
elif transform_id == "Sharpness":
return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
elif transform_id == "Posterize":
return F.posterize(image, bits=int(magnitude))
elif transform_id == "Solarize":
bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
return F.solarize(image, threshold=bound * magnitude)
elif transform_id == "AutoContrast":
return F.autocontrast(image)
elif transform_id == "Equalize":
return F.equalize(image)
elif transform_id == "Invert":
return F.invert(image)
else:
raise ValueError(f"No transform available for {transform_id}")
class AutoAugment(_AutoAugmentBase):
r"""[BETA] AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
.. v2betastatus:: AutoAugment transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy, optional): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
_v1_transform_cls = _transforms.AutoAugment
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
"Invert": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
self._policies = self._get_policies(policy)
def _get_policies(
self, policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
else:
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability:
continue
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
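# Hedged usage sketch for AutoAugment: a single uint8 image (or video) per call;
# the IMAGENET policy is the default shown in __init__ above.
import torch
from torchvision.transforms import v2
autoaug = v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET)
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
out = autoaug(img)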
class RandAugment(_AutoAugmentBase):
r"""[BETA] RandAugment data augmentation method based on
`"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`_.
.. v2betastatus:: RandAugment transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_ops (int, optional): Number of augmentation transformations to apply sequentially.
magnitude (int, optional): Magnitude for all the transformations.
num_magnitude_bins (int, optional): The number of different magnitude values.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
_v1_transform_cls = _transforms.RandAugment
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[self.magnitude])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
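# Hedged sketch for RandAugment: num_ops randomly chosen ops, each applied at the
# fixed magnitude bin (the defaults are shown; values illustrative).
import torch
from torchvision.transforms import v2
randaug = v2.RandAugment(num_ops=2, magnitude=9)
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
out = randaug(img)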
class TrivialAugmentWide(_AutoAugmentBase):
r"""[BETA] Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
.. v2betastatus:: TrivialAugmentWide transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_magnitude_bins (int, optional): The number of different magnitude values.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
_v1_transform_cls = _transforms.TrivialAugmentWide
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
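# Hedged sketch for TrivialAugmentWide: one op and one uniformly sampled magnitude
# bin per call, so there is nothing to tune beyond num_magnitude_bins.
import torch
from torchvision.transforms import v2
trivial = v2.TrivialAugmentWide()
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
out = trivial(img)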
class AugMix(_AutoAugmentBase):
r"""[BETA] AugMix data augmentation method based on
`"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
.. v2betastatus:: AugMix transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
severity (int, optional): The severity of base augmentation operators. Default is ``3``.
mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
Default is ``-1``.
alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
_v1_transform_cls = _transforms.AugMix
_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
**_PARTIAL_AUGMENTATION_SPACE,
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
}
def __init__(
self,
severity: int = 3,
mixture_width: int = 3,
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
self.severity = severity
self.mixture_width = mixture_width
self.chain_depth = chain_depth
self.alpha = alpha
self.all_ops = all_ops
def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
# Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params)
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(orig_image_or_video)
if isinstance(orig_image_or_video, torch.Tensor):
image_or_video = orig_image_or_video
else: # isinstance(inpt, PIL.Image.Image):
image_or_video = F.pil_to_tensor(orig_image_or_video)
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image_or_video.shape)
expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
# Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
# augmented image or video.
m = self._sample_dirichlet(
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
)
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
) * m[:, 1].reshape([batch_dims[0], -1])
mix = m[:, 0].reshape(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
mix = tv_tensors.wrap(mix, like=orig_image_or_video)
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_pil_image(mix)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
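# Hedged sketch for AugMix: mixture_width augmented chains are blended with the
# original image using the Dirichlet/Beta weights sampled in forward() above
# (the arguments below are just the defaults).
import torch
from torchvision.transforms import v2
augmix = v2.AugMix(severity=3, mixture_width=3, chain_depth=-1, alpha=1.0)
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
out = augmix(img)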
import collections.abc
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from ._transform import _RandomApplyTransform
from ._utils import query_chw
class Grayscale(Transform):
"""[BETA] Convert images or videos to grayscale.
.. v2betastatus:: Grayscale transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 3 or 1, H, W] shape, where ... means an arbitrary number of leading dimensions
Args:
num_output_channels (int): (1 or 3) number of channels desired for output image
"""
_v1_transform_cls = _transforms.Grayscale
def __init__(self, num_output_channels: int = 1):
super().__init__()
self.num_output_channels = num_output_channels
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)
class RandomGrayscale(_RandomApplyTransform):
"""[BETA] Randomly convert image or videos to grayscale with a probability of p (default 0.1).
.. v2betastatus:: RandomGrayscale transform
If the input is a :class:`torch.Tensor`, it is expected to have [..., 3 or 1, H, W] shape,
where ... means an arbitrary number of leading dimensions
The output has the same number of channels as the input.
Args:
p (float): probability that image should be converted to grayscale.
"""
_v1_transform_cls = _transforms.RandomGrayscale
def __init__(self, p: float = 0.1) -> None:
super().__init__(p=p)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(flat_inputs)
return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
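# Hedged sketch: Grayscale always converts and can change the channel count,
# while RandomGrayscale converts with probability p and keeps the channel count.
import torch
from torchvision.transforms import v2
img = torch.rand(3, 64, 64)
gray = v2.Grayscale(num_output_channels=1)(img)   # -> (1, 64, 64)
maybe_gray = v2.RandomGrayscale(p=0.1)(img)       # -> (3, 64, 64)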
class ColorJitter(Transform):
"""[BETA] Randomly change the brightness, contrast, saturation and hue of an image or video.
.. v2betastatus:: ColorJitter transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non-negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter contrast.
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non-negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter saturation.
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non-negative numbers.
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
To jitter hue, the pixel values of the input image have to be non-negative for conversion to HSV space;
thus it does not work if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using this function.
"""
_v1_transform_cls = _transforms.ColorJitter
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}
def __init__(
self,
brightness: Optional[Union[float, Sequence[float]]] = None,
contrast: Optional[Union[float, Sequence[float]]] = None,
saturation: Optional[Union[float, Sequence[float]]] = None,
hue: Optional[Union[float, Sequence[float]]] = None,
) -> None:
super().__init__()
self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, "saturation")
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
def _check_input(
self,
value: Optional[Union[float, Sequence[float]]],
name: str,
center: float = 1.0,
bound: Tuple[float, float] = (0, float("inf")),
clip_first_on_zero: bool = True,
) -> Optional[Tuple[float, float]]:
if value is None:
return None
if isinstance(value, (int, float)):
if value < 0:
raise ValueError(f"If {name} is a single number, it must be non negative.")
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
value = [float(v) for v in value]
else:
raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
@staticmethod
def _generate_value(left: float, right: float) -> float:
return torch.empty(1).uniform_(left, right).item()
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4)
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = inpt
brightness_factor = params["brightness_factor"]
contrast_factor = params["contrast_factor"]
saturation_factor = params["saturation_factor"]
hue_factor = params["hue_factor"]
for fn_id in params["fn_idx"]:
if fn_id == 0 and brightness_factor is not None:
output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
elif fn_id == 3 and hue_factor is not None:
output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
return output
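# Hedged usage sketch for ColorJitter: a scalar argument f expands to the range
# [max(0, 1 - f), 1 + f] (or [-f, f] for hue), as built by _check_input above.
import torch
from torchvision.transforms import v2
jitter = v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
img = torch.rand(3, 64, 64)
out = jitter(img)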
class RandomChannelPermutation(Transform):
"""[BETA] Randomly permute the channels of an image or video
.. v2betastatus:: RandomChannelPermutation transform
"""
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
return dict(permutation=torch.randperm(num_channels))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.permute_channels, inpt, params["permutation"])
class RandomPhotometricDistort(Transform):
"""[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.
.. v2betastatus:: RandomPhotometricDistort transform
This transform relies on :class:`~torchvision.transforms.v2.ColorJitter`
under the hood to adjust the contrast, saturation, hue, and brightness, and it also
randomly permutes channels.
Args:
brightness (tuple of float (min, max), optional): How much to jitter brightness.
brightness_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
contrast (tuple of float (min, max), optional): How much to jitter contrast.
contrast_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
saturation (tuple of float (min, max), optional): How much to jitter saturation.
saturation_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
hue (tuple of float (min, max), optional): How much to jitter hue.
hue_factor is chosen uniformly from [min, max]. Should have -0.5 <= min <= max <= 0.5.
To jitter hue, the pixel values of the input image have to be non-negative for conversion to HSV space;
thus it does not work if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using this function.
p (float, optional): probability that each distortion operation (contrast, saturation, ...) is applied.
Default is 0.5.
"""
def __init__(
self,
brightness: Tuple[float, float] = (0.875, 1.125),
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
p: float = 0.5,
):
super().__init__()
self.brightness = brightness
self.contrast = contrast
self.hue = hue
self.saturation = saturation
self.p = p
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
params: Dict[str, Any] = {
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
for key, range in [
("brightness_factor", self.brightness),
("contrast_factor", self.contrast),
("saturation_factor", self.saturation),
("hue_factor", self.hue),
]
}
params["contrast_before"] = bool(torch.rand(()) < 0.5)
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
return params
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["brightness_factor"] is not None:
inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
if params["contrast_factor"] is not None and params["contrast_before"]:
inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
if params["saturation_factor"] is not None:
inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
if params["hue_factor"] is not None:
inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
if params["channel_permutation"] is not None:
inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"])
return inpt
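# Hedged sketch for RandomPhotometricDistort: each jitter op is applied
# independently with probability p, plus an optional channel permutation
# (the ranges below are just the defaults).
import torch
from torchvision.transforms import v2
distort = v2.RandomPhotometricDistort(
    brightness=(0.875, 1.125),
    contrast=(0.5, 1.5),
    saturation=(0.5, 1.5),
    hue=(-0.05, 0.05),
    p=0.5,
)
img = torch.rand(3, 300, 300)
out = distort(img)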
class RandomEqualize(_RandomApplyTransform):
"""[BETA] Equalize the histogram of the given image or video with a given probability.
.. v2betastatus:: RandomEqualize transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
Args:
p (float): probability of the image being equalized. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomEqualize
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.equalize, inpt)
class RandomInvert(_RandomApplyTransform):
"""[BETA] Inverts the colors of the given image or video with a given probability.
.. v2betastatus:: RandomInvert transform
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
p (float): probability of the image being color inverted. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomInvert
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.invert, inpt)
class RandomPosterize(_RandomApplyTransform):
"""[BETA] Posterize the image or video with a given probability by reducing the
number of bits for each color channel.
.. v2betastatus:: RandomPosterize transform
If the input is a :class:`torch.Tensor`, it should be of type torch.uint8,
and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
bits (int): number of bits to keep for each channel (0-8)
p (float): probability of the image being posterized. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomPosterize
def __init__(self, bits: int, p: float = 0.5) -> None:
super().__init__(p=p)
self.bits = bits
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.posterize, inpt, bits=self.bits)
class RandomSolarize(_RandomApplyTransform):
"""[BETA] Solarize the image or video with a given probability by inverting all pixel
values above a threshold.
.. v2betastatus:: RandomSolarize transform
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
threshold (float): all pixels equal or above this value are inverted.
p (float): probability of the image being solarized. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomSolarize
def __init__(self, threshold: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.threshold = threshold
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.solarize, inpt, threshold=self.threshold)
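# Hedged sketch: the solarize threshold lives in the input's value range, e.g.
# around 128 for uint8 images or around 0.5 for floats in [0, 1] (illustrative).
import torch
from torchvision.transforms import v2
solarize = v2.RandomSolarize(threshold=128, p=0.5)
out = solarize(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))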
class RandomAutocontrast(_RandomApplyTransform):
"""[BETA] Autocontrast the pixels of the given image or video with a given probability.
.. v2betastatus:: RandomAutocontrast transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
p (float): probability of the image being autocontrasted. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomAutocontrast
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.autocontrast, inpt)
class RandomAdjustSharpness(_RandomApplyTransform):
"""[BETA] Adjust the sharpness of the image or video with a given probability.
.. v2betastatus:: RandomAdjustSharpness transform
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness_factor (float): How much to adjust the sharpness. Can be
any non-negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
p (float): probability of the image being sharpened. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomAdjustSharpness
def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.sharpness_factor = sharpness_factor
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)