Unverified Commit 673838f5 authored by YosuaMichael, committed by GitHub

Removing prototype related things from release/0.14 branch (#6687)

* Remove test related to prototype

* Remove torchvision/prototype dir

* Remove references/depth/stereo because it depends on prototype

* Remove prototype related entries on mypy.ini

* Remove things related to prototype in pytest.ini

* clean setup.py from prototype

* Clean CI from prototype

* Remove unused expect file
parent 07ae61bf
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
from torchvision.ops import Conv2dNormActivation
from torchvision.utils import _log_api_usage_once
__all__ = (
"RaftStereo",
"raft_stereo_base",
"raft_stereo_realtime",
"Raft_Stereo_Base_Weights",
"Raft_Stereo_Realtime_Weights",
)
class BaseEncoder(raft.FeatureEncoder):
"""Base encoder for FeatureEncoder and ContextEncoder in which weight may be shared.
See the Raft-Stereo paper section 4.6 on backbone part.
"""
def __init__(
self,
*,
block: Callable[..., nn.Module] = ResidualBlock,
layers: Tuple[int, int, int, int] = (64, 64, 96, 128),
strides: Tuple[int, int, int, int] = (2, 1, 2, 2),
norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d,
):
# We use layers + (256,) because raft.FeatureEncoder requires 5 layers,
# but here we set the last conv layer to identity
super().__init__(block=block, layers=layers + (256,), strides=strides, norm_layer=norm_layer)
# The base encoder doesn't have the last conv layer of the feature encoder
self.conv = nn.Identity()
self.output_dim = layers[3]
num_downsampling = sum([x - 1 for x in strides])
self.downsampling_ratio = 2 ** (num_downsampling)
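# Worked example of the ratio above (illustrative, using the default strides of this class, which are all 1 or 2):
# strides = (2, 1, 2, 2) -> num_downsampling = 1 + 0 + 1 + 1 = 3 -> downsampling_ratio = 2 ** 3 = 8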
class FeatureEncoder(nn.Module):
"""Feature Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Context Encoder.
The FeatureEncoder takes concatination of left and right image as input, it produce feature embedding that later
will be used to construct correlation volume.
"""
def __init__(
self,
base_encoder: BaseEncoder,
output_dim: int = 256,
shared_base: bool = False,
block: Callable[..., nn.Module] = ResidualBlock,
):
super().__init__()
self.base_encoder = base_encoder
self.base_downsampling_ratio = base_encoder.downsampling_ratio
base_dim = base_encoder.output_dim
if not shared_base:
self.residual_block: nn.Module = nn.Identity()
self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=1)
else:
# If we share the base encoder weights between the Feature and Context Encoders,
# we need to add a residual block with InstanceNorm2d and change the kernel size of the conv layer,
# see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L35-L37
self.residual_block = block(base_dim, base_dim, norm_layer=nn.InstanceNorm2d, stride=1)
self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=3, padding=1)
def forward(self, x: Tensor) -> Tensor:
x = self.base_encoder(x)
x = self.residual_block(x)
x = self.conv(x)
return x
class MultiLevelContextEncoder(nn.Module):
"""Context Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Feature Encoder.
The ContextEncoder takes left image as input and it outputs concatenated hidden_states and contexts.
In Raft-Stereo we have multi level GRUs and this context encoder will also multi outputs (list of Tensor)
that correspond to each GRUs.
Take note that the length of "out_with_blocks" parameter represent the number of GRU's level.
args:
base_encoder (nn.Module): The base encoder part that can have a shared weight with feature_encoder's
base_encoder because they have same architecture.
out_with_blocks (List[bool]): The length represent the number of GRU's level (length of output), and
if the element is True then the output layer on that position will have additional block
output_dim (int): The dimension of output on each level (default: 256)
block (Callable[..., nn.Module]): The type of basic block used for downsampling and output layer
(default: ResidualBlock)
"""
def __init__(
self,
base_encoder: nn.Module,
out_with_blocks: List[bool],
output_dim: int = 256,
block: Callable[..., nn.Module] = ResidualBlock,
):
super().__init__()
self.num_level = len(out_with_blocks)
self.base_encoder = base_encoder
self.base_downsampling_ratio = base_encoder.downsampling_ratio
base_dim = base_encoder.output_dim
self.downsample_and_out_layers = nn.ModuleList(
[
nn.ModuleDict(
{
"downsampler": self._make_downsampler(block, base_dim, base_dim) if i > 0 else nn.Identity(),
"out_hidden_state": self._make_out_layer(
base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
),
"out_context": self._make_out_layer(
base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
),
}
)
for i in range(self.num_level)
]
)
def _make_out_layer(self, in_channels, out_channels, with_block=True, block=ResidualBlock):
layers = []
if with_block:
layers.append(block(in_channels, in_channels, norm_layer=nn.BatchNorm2d, stride=1))
layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
return nn.Sequential(*layers)
def _make_downsampler(self, block, in_channels, out_channels):
block1 = block(in_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=2)
block2 = block(out_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=1)
return nn.Sequential(block1, block2)
def forward(self, x: Tensor) -> List[Tensor]:
x = self.base_encoder(x)
outs = []
for layer_dict in self.downsample_and_out_layers:
x = layer_dict["downsampler"](x)
outs.append(torch.cat([layer_dict["out_hidden_state"](x), layer_dict["out_context"](x)], dim=1))
return outs
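# Shape sketch of the outputs above (hypothetical sizes, assuming output_dim=256 and a base encoder that
# downsamples by 4, as in the raft_stereo_base configuration):
#   left_image: (B, 3, 256, 512)
#   outs[0]: (B, 256, 64, 128)   # 128 hidden-state channels + 128 context channels, base resolution
#   outs[1]: (B, 256, 32, 64)    # downsampled once more by the stride-2 block
#   outs[2]: (B, 256, 16, 32)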
class ConvGRU(raft.ConvGRU):
"""Convolutional Gru unit."""
# Modified from raft.ConvGRU to accept pre-convolved contexts,
# see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/update.py#L23
def forward(self, h: Tensor, x: Tensor, context: List[Tensor]) -> Tensor: # type: ignore[override]
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx) + context[0])
r = torch.sigmoid(self.convr(hx) + context[1])
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)) + context[2])
h = (1 - z) * h + z * q
return h
class MultiLevelUpdateBlock(nn.Module):
"""The update block which contains the motion encoder and grus
It must expose a ``hidden_dims`` attribute which is the hidden dimension size of its gru blocks
"""
def __init__(self, *, motion_encoder: MotionEncoder, hidden_dims: List[int]):
super().__init__()
self.motion_encoder = motion_encoder
# The GRU input size is the previous level's hidden_dim plus the next level's hidden_dim.
# For the first GRU, the previous level is replaced by the motion_encoder output channels;
# for the last GRU, we don't add the next level's hidden_dim.
gru_input_dims = []
for i in range(len(hidden_dims)):
input_dim = hidden_dims[i - 1] if i > 0 else motion_encoder.out_channels
if i < len(hidden_dims) - 1:
input_dim += hidden_dims[i + 1]
gru_input_dims.append(input_dim)
self.grus = nn.ModuleList(
[
ConvGRU(input_size=gru_input_dims[i], hidden_size=hidden_dims[i], kernel_size=3, padding=1)
# Ideally we would reverse the iteration direction in forward so that the GRU with the smallest resolution is used first;
# however, there is currently no jit-script-compatible way to reverse a ModuleList,
# hence we reverse the ordering of self.grus in the constructor instead
# see: https://github.com/pytorch/pytorch/issues/31772
for i in reversed(list(range(len(hidden_dims))))
]
)
self.hidden_dims = hidden_dims
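# Worked example of gru_input_dims (illustrative, using the raft_stereo_base defaults where
# hidden_dims = [128, 128, 128] and motion_encoder.out_channels = 128):
#   i=0: 128 (motion encoder) + 128 (next level)  -> 256
#   i=1: 128 (previous level) + 128 (next level)  -> 256
#   i=2: 128 (previous level), no next level      -> 128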
def forward(
self,
hidden_states: List[Tensor],
contexts: List[List[Tensor]],
corr_features: Tensor,
disparity: Tensor,
level_processed: List[bool],
) -> List[Tensor]:
# We call it reverse_i because it has a reversed ordering compared to hidden_states,
# see self.grus in the constructor for more details
for reverse_i, gru in enumerate(self.grus):
i = len(self.grus) - 1 - reverse_i
if level_processed[i]:
# The GRU input is the concatenation of the 2x-downsampled hidden state from the larger level
# (or the motion features if there is no larger level) with the 2x-upsampled hidden state from
# the smaller level (or nothing if that level does not exist).
if i == 0:
features = self.motion_encoder(disparity, corr_features)
else:
# 2x downsampled features from larger hidden states
features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1)
if i < len(self.grus) - 1:
# Concat with 2x upsampled features from smaller hidden states
_, _, h, w = hidden_states[i + 1].shape
features = torch.cat(
[
features,
F.interpolate(
hidden_states[i + 1], size=(2 * h, 2 * w), mode="bilinear", align_corners=True
),
],
dim=1,
)
hidden_states[i] = gru(hidden_states[i], features, contexts[i])
# NOTE: For the slow-fast GRU we don't always want to compute the delta disparity on every UpdateBlock call,
# hence the delta disparity computation is moved to the RaftStereo main forward
return hidden_states
class MaskPredictor(raft.MaskPredictor):
"""Mask predictor to be used when upsampling the predicted disparity."""
# Compared to raft.MaskPredictor, we add an out_channels parameter
def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25):
super(raft.MaskPredictor, self).__init__()
self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
self.conv = nn.Conv2d(hidden_size, out_channels, kernel_size=1, padding=0)
self.multiplier = multiplier
class CorrPyramid1d(nn.Module):
"""Row-wise correlation pyramid.
Creates a row-wise correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder.
This correlation pyramid is later indexed by CorrBlock1d to create correlation features.
"""
def __init__(self, num_levels: int = 4):
super().__init__()
self.num_levels = num_levels
def forward(self, fmap1: Tensor, fmap2: Tensor) -> List[Tensor]:
"""Build the correlation pyramid from two feature maps.
The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2) on the same row.
The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions
to build the correlation pyramid.
"""
torch._assert(
fmap1.shape == fmap2.shape,
f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)",
)
batch_size, num_channels, h, w = fmap1.shape
fmap1 = fmap1.view(batch_size, num_channels, h, w)
fmap2 = fmap2.view(batch_size, num_channels, h, w)
corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
corr = corr.view(batch_size, h, w, 1, w)
corr_volume = corr / torch.sqrt(torch.tensor(num_channels, device=corr.device))
corr_volume = corr_volume.reshape(batch_size * h * w, 1, 1, w)
corr_pyramid = [corr_volume]
for _ in range(self.num_levels - 1):
corr_volume = F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2))
corr_pyramid.append(corr_volume)
return corr_pyramid
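# Shape sketch of the pyramid above (hypothetical sizes; fmap1 and fmap2 of shape (B, C, H, W), num_levels=4):
#   corr_pyramid[0]: (B * H * W, 1, 1, W)
#   corr_pyramid[1]: (B * H * W, 1, 1, W // 2)
#   corr_pyramid[2]: (B * H * W, 1, 1, W // 4)
#   corr_pyramid[3]: (B * H * W, 1, 1, W // 8)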
class CorrBlock1d(nn.Module):
"""The row-wise correlation block.
Uses indexes from the correlation pyramid to create correlation features.
The "indexing" of a given centroid pixel x' is done by concatenating its surrounding row neighbours
within ``radius``.
"""
def __init__(self, *, num_levels: int = 4, radius: int = 4):
super().__init__()
self.radius = radius
self.out_channels = num_levels * (2 * radius + 1)
def forward(self, centroids_coords: Tensor, corr_pyramid: List[Tensor]) -> Tensor:
"""Return correlation features by indexing from the pyramid."""
neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels
di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, device=centroids_coords.device)
di = di.view(1, 1, neighborhood_side_len, 1).to(centroids_coords.device)
batch_size, _, h, w = centroids_coords.shape # _ = 2 but we only use the first one
# We only consider 1d and take the first dim only
centroids_coords = centroids_coords[:, :1].permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 1)
indexed_pyramid = []
for corr_volume in corr_pyramid:
x0 = centroids_coords + di # end shape is (batch_size * h * w, 1, side_len, 1)
y0 = torch.zeros_like(x0)
sampling_coords = torch.cat([x0, y0], dim=-1)
indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
batch_size, h, w, -1
)
indexed_pyramid.append(indexed_corr_volume)
centroids_coords = centroids_coords / 2
corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()
expected_output_shape = (batch_size, self.out_channels, h, w)
torch._assert(
corr_features.shape == expected_output_shape,
f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}",
)
return corr_features
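# Worked example of out_channels (illustrative, using the defaults num_levels=4 and radius=4):
# each pyramid level contributes 2 * 4 + 1 = 9 neighbouring correlation values per pixel,
# so out_channels = 4 * 9 = 36.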
class RaftStereo(nn.Module):
def __init__(
self,
*,
feature_encoder: FeatureEncoder,
context_encoder: MultiLevelContextEncoder,
corr_pyramid: CorrPyramid1d,
corr_block: CorrBlock1d,
update_block: MultiLevelUpdateBlock,
disparity_head: nn.Module,
mask_predictor: Optional[nn.Module] = None,
slow_fast: bool = False,
):
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
Args:
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``.
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``.
It has multi-level output and each level will have 2 parts:
- one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
- one part will be used to initialize the hidden state of the recurrent unit of
the ``update_block``
corr_pyramid (CorrPyramid1d): Module to build the correlation pyramid from the feature encoder output
corr_block (CorrBlock1d): The correlation block, which uses the correlation pyramid indexes
to create correlation features. It takes the coordinate of the centroid pixel and correlation pyramid
as input and returns the correlation features.
It must expose an ``out_channels`` attribute.
update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit.
It takes as input the hidden state of its recurrent unit, the context, the correlation
features, and the current predicted disparity. It outputs an updated hidden state.
disparity_head (nn.Module): The disparity head block converts the hidden state into changes in disparity.
mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
If ``None`` (default), the flow is upsampled using interpolation.
slow_fast (bool): A boolean that specifies whether to use the slow-fast GRU. See section 3.4 of the
RAFT-Stereo paper for more details.
"""
super().__init__()
_log_api_usage_once(self)
# This indicates that the disparity output will only have 1 channel (representing the horizontal axis).
# We need this because some stereo matching models like CREStereo might have 2 channels in their output
self.output_channels = 1
self.feature_encoder = feature_encoder
self.context_encoder = context_encoder
self.base_downsampling_ratio = feature_encoder.base_downsampling_ratio
self.num_level = self.context_encoder.num_level
self.corr_pyramid = corr_pyramid
self.corr_block = corr_block
self.update_block = update_block
self.disparity_head = disparity_head
self.mask_predictor = mask_predictor
hidden_dims = self.update_block.hidden_dims
# Following the original implementation, we pre-convolve the contexts
# See: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L32
self.context_convs = nn.ModuleList(
[nn.Conv2d(hidden_dims[i], hidden_dims[i] * 3, kernel_size=3, padding=1) for i in range(self.num_level)]
)
self.slow_fast = slow_fast
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12
) -> List[Tensor]:
"""
Return the disparity predictions from every iteration as a list of Tensors.
Args:
left_image (Tensor): The input left image with layout B, C, H, W
right_image (Tensor): The input right image with layout B, C, H, W
flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
num_iters (int): Number of update block iterations at the largest resolution. Default: 12
"""
batch_size, _, h, w = left_image.shape
torch._assert(
(h, w) == right_image.shape[-2:],
f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}",
)
torch._assert(
(h % self.base_downsampling_ratio == 0 and w % self.base_downsampling_ratio == 0),
f"input image H and W should be divisible by {self.base_downsampling_ratio}, insted got H={h} and W={w}",
)
fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0))
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
torch._assert(
fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio),
f"The feature encoder should downsample H and W by {self.base_downsampling_ratio}",
)
corr_pyramid = self.corr_pyramid(fmap1, fmap2)
# Multi level contexts
context_outs = self.context_encoder(left_image)
hidden_dims = self.update_block.hidden_dims
context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))]
hidden_states: List[Tensor] = []
contexts: List[List[Tensor]] = []
for i, context_conv in enumerate(self.context_convs):
# As in the original paper, the actual output of the context encoder is split in 2 parts:
# - one part is used to initialize the hidden state of the recurrent units of the update block
# - the rest is the "actual" context.
hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1)
hidden_states.append(torch.tanh(hidden_state))
contexts.append(
torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1)
)
_, Cf, Hf, Wf = fmap1.shape
coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
# We use flow_init for cascade inference
if flow_init is not None:
coords1 = coords1 + flow_init
disparity_predictions = []
for _ in range(num_iters):
coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid)
disparity = coords1 - coords0
if self.slow_fast:
# Using the slow-fast GRU (see paper section 3.4). The lower resolutions are processed more often
for i in range(1, self.num_level):
# We only process the i smallest levels
level_processed = [False] * (self.num_level - i) + [True] * i
hidden_states = self.update_block(
hidden_states, contexts, corr_features, disparity, level_processed=level_processed
)
hidden_states = self.update_block(
hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level
)
# Use the hidden state at the largest resolution to predict the disparity
hidden_state = hidden_states[0]
delta_disparity = self.disparity_head(hidden_state)
# In stereo mode, project the disparity onto the epipolar line by zeroing its vertical component
delta_disparity[:, 1] = 0.0
coords1 = coords1 + delta_disparity
up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
upsampled_disparity = upsample_flow(
(coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio
)
disparity_predictions.append(upsampled_disparity[:, :1])
return disparity_predictions
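# Note on using the output (an observation, not part of the original code): disparity_predictions[-1]
# is the most refined estimate and is what is typically used at inference time, while training schemes
# in the RAFT family usually supervise all iterations with exponentially decaying weights.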
def _raft_stereo(
*,
weights: Optional[WeightsEnum],
progress: bool,
shared_encoder_weight: bool,
# Feature encoder
feature_encoder_layers: Tuple[int, int, int, int, int],
feature_encoder_strides: Tuple[int, int, int, int],
feature_encoder_block: Callable[..., nn.Module],
# Context encoder
context_encoder_layers: Tuple[int, int, int, int, int],
context_encoder_strides: Tuple[int, int, int, int],
# if an element of the context_encoder's `out_with_blocks` param is True, then the
# output at that level gets an additional `context_encoder_block` layer
context_encoder_out_with_blocks: List[bool],
context_encoder_block: Callable[..., nn.Module],
# Correlation block
corr_num_levels: int,
corr_radius: int,
# Motion encoder
motion_encoder_corr_layers: Tuple[int, int],
motion_encoder_flow_layers: Tuple[int, int],
motion_encoder_out_channels: int,
# Update block
update_block_hidden_dims: List[int],
# Flow Head
flow_head_hidden_size: int,
# Mask predictor
mask_predictor_hidden_size: int,
use_mask_predictor: bool,
slow_fast: bool,
**kwargs,
):
if len(context_encoder_out_with_blocks) != len(update_block_hidden_dims):
raise ValueError(
"Length of context_encoder_out_with_blocks and update_block_hidden_dims must be the same"
+ "because both of them represent the number of GRUs level"
)
if shared_encoder_weight:
if (
feature_encoder_layers[:-1] != context_encoder_layers[:-1]
or feature_encoder_strides != context_encoder_strides
):
raise ValueError(
"If shared_encoder_weight is True, then the feature_encoder_layers[:-1]"
+ " and feature_encoder_strides must be the same with context_encoder_layers[:-1] and context_encoder_strides!"
)
base_encoder = kwargs.pop("base_encoder", None) or BaseEncoder(
block=context_encoder_block,
layers=context_encoder_layers[:-1],
strides=context_encoder_strides,
norm_layer=nn.BatchNorm2d,
)
feature_base_encoder = base_encoder
context_base_encoder = base_encoder
else:
feature_base_encoder = BaseEncoder(
block=feature_encoder_block,
layers=feature_encoder_layers[:-1],
strides=feature_encoder_strides,
norm_layer=nn.InstanceNorm2d,
)
context_base_encoder = BaseEncoder(
block=context_encoder_block,
layers=context_encoder_layers[:-1],
strides=context_encoder_strides,
norm_layer=nn.BatchNorm2d,
)
feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
feature_base_encoder,
output_dim=feature_encoder_layers[-1],
shared_base=shared_encoder_weight,
block=feature_encoder_block,
)
context_encoder = kwargs.pop("context_encoder", None) or MultiLevelContextEncoder(
context_base_encoder,
out_with_blocks=context_encoder_out_with_blocks,
output_dim=context_encoder_layers[-1],
block=context_encoder_block,
)
feature_downsampling_ratio = feature_encoder.base_downsampling_ratio
corr_pyramid = kwargs.pop("corr_pyramid", None) or CorrPyramid1d(num_levels=corr_num_levels)
corr_block = kwargs.pop("corr_block", None) or CorrBlock1d(num_levels=corr_num_levels, radius=corr_radius)
motion_encoder = kwargs.pop("motion_encoder", None) or MotionEncoder(
in_channels_corr=corr_block.out_channels,
corr_layers=motion_encoder_corr_layers,
flow_layers=motion_encoder_flow_layers,
out_channels=motion_encoder_out_channels,
)
update_block = kwargs.pop("update_block", None) or MultiLevelUpdateBlock(
motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims
)
# We use the update_block hidden_dims at the largest scale to predict the disparity
disparity_head = kwargs.pop("disparity_head", None) or FlowHead(
in_channels=update_block_hidden_dims[0],
hidden_size=flow_head_hidden_size,
)
mask_predictor = kwargs.pop("mask_predictor", None)
if use_mask_predictor:
mask_predictor = MaskPredictor(
in_channels=update_block.hidden_dims[0],
hidden_size=mask_predictor_hidden_size,
out_channels=9 * feature_downsampling_ratio * feature_downsampling_ratio,
)
else:
mask_predictor = None
model = RaftStereo(
feature_encoder=feature_encoder,
context_encoder=context_encoder,
corr_pyramid=corr_pyramid,
corr_block=corr_block,
update_block=update_block,
disparity_head=disparity_head,
mask_predictor=mask_predictor,
slow_fast=slow_fast,
**kwargs, # not really needed, all params should be consumed by now
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
class Raft_Stereo_Realtime_Weights(WeightsEnum):
pass
class Raft_Stereo_Base_Weights(WeightsEnum):
pass
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_realtime(
*, weights: Optional[Raft_Stereo_Realtime_Weights] = None, progress=True, **kwargs
) -> RaftStereo:
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
This is the realtime variant of the RAFT-Stereo model, described in section 4.7 of the paper.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights
:members:
"""
weights = Raft_Stereo_Realtime_Weights.verify(weights)
return _raft_stereo(
weights=weights,
progress=progress,
shared_encoder_weight=True,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(2, 1, 2, 2),
feature_encoder_block=ResidualBlock,
# Context encoder
context_encoder_layers=(64, 64, 96, 128, 256),
context_encoder_strides=(2, 1, 2, 2),
context_encoder_out_with_blocks=[True, True],
context_encoder_block=ResidualBlock,
# Correlation block
corr_num_levels=4,
corr_radius=4,
# Motion encoder
motion_encoder_corr_layers=(64, 64),
motion_encoder_flow_layers=(64, 64),
motion_encoder_out_channels=128,
# Update block
update_block_hidden_dims=[128, 128],
# Flow head
flow_head_hidden_size=256,
# Mask predictor
mask_predictor_hidden_size=256,
use_mask_predictor=True,
slow_fast=True,
**kwargs,
)
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_base(*, weights: Optional[Raft_Stereo_Base_Weights] = None, progress=True, **kwargs) -> RaftStereo:
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights
:members:
"""
weights = Raft_Stereo_Base_Weights.verify(weights)
return _raft_stereo(
weights=weights,
progress=progress,
shared_encoder_weight=False,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(1, 1, 2, 2),
feature_encoder_block=ResidualBlock,
# Context encoder
context_encoder_layers=(64, 64, 96, 128, 256),
context_encoder_strides=(1, 1, 2, 2),
context_encoder_out_with_blocks=[True, True, False],
context_encoder_block=ResidualBlock,
# Correlation block
corr_num_levels=4,
corr_radius=4,
# Motion encoder
motion_encoder_corr_layers=(64, 64),
motion_encoder_flow_layers=(64, 64),
motion_encoder_out_channels=128,
# Update block
update_block_hidden_dims=[128, 128, 128],
# Flow head
flow_head_hidden_size=256,
# Mask predictor
mask_predictor_hidden_size=256,
use_mask_predictor=True,
slow_fast=False,
**kwargs,
)
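# A minimal usage sketch (not part of the original module); it only exercises the public builders above.
# The input sizes are arbitrary; H and W just need to be divisible by the encoder downsampling ratio
# (4 for raft_stereo_base, 8 for raft_stereo_realtime).
if __name__ == "__main__":
    model = raft_stereo_base().eval()
    left = torch.rand(1, 3, 256, 512)
    right = torch.rand(1, 3, 256, 512)
    with torch.no_grad():
        predictions = model(left, right, num_iters=12)
    # The model returns one disparity map per iteration; the last prediction is the most refined.
    print(len(predictions), predictions[-1].shape)  # expected: 12 torch.Size([1, 1, 256, 512])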
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import functional # usort: skip
from ._transform import Transform # usort: skip
from ._presets import StereoMatching # usort: skip
from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
RandomAdjustSharpness,
RandomAutocontrast,
RandomEqualize,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
RandomSolarize,
)
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
CenterCrop,
ElasticTransform,
FiveCrop,
FixedSizeCrop,
Pad,
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
RandomIoUCrop,
RandomPerspective,
RandomResize,
RandomResizedCrop,
RandomRotation,
RandomShortestSize,
RandomVerticalFlip,
RandomZoomOut,
Resize,
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Optional, Tuple
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, InterpolationMode
from ._transform import _RandomApplyTransform
from ._utils import has_any, query_chw
class RandomErasing(_RandomApplyTransform):
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
def __init__(
self,
p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0,
inplace: bool = False,
):
super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
self.scale = scale
self.ratio = ratio
self.value = value
self.inplace = inplace
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, sample: Any) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(sample)
if isinstance(self.value, (int, float)):
value = [self.value]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and not (len(value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
)
area = img_h * img_w
log_ratio = self._log_ratio
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, None
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
return inpt
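# A minimal usage sketch for RandomErasing (illustrative, not part of the original file); it applies to
# plain tensor images, prototype Images, and PIL images, as listed in _transformed_types above:
#
#     transform = RandomErasing(p=1.0, value="random")
#     img = torch.rand(3, 224, 224)
#     out = transform(img)  # same shape, with one random rectangle replaced by Gaussian noise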
class _BaseMixupCutmix(_RandomApplyTransform):
def __init__(self, alpha: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
)
return super().forward(*inputs)
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = inpt.clone()
output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
return features.OneHotLabel.new_like(inpt, output)
class RandomMixup(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
output = inpt.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output)
return output
elif isinstance(inpt, features.OneHotLabel):
return self._mixup_onehotlabel(inpt, lam)
else:
return inpt
class RandomCutmix(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(()))
_, H, W = query_chw(sample)
r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
r_h_half = int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
box = (x1, y1, x2, y2)
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
box = params["box"]
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
x1, y1, x2, y2 = box
image_rolled = inpt.roll(1, -4)
output = inpt.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output)
return output
elif isinstance(inpt, features.OneHotLabel):
lam_adjusted = params["lam_adjusted"]
return self._mixup_onehotlabel(inpt, lam_adjusted)
else:
return inpt
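# A minimal usage sketch for RandomMixup / RandomCutmix (illustrative, not part of the original file).
# Both transforms require a batch of tensor images together with one-hot labels; the OneHotLabel
# construction below is an assumption about the prototype features API:
#
#     mixup = RandomMixup(alpha=0.2, p=1.0)
#     images = torch.rand(8, 3, 224, 224)  # a plain tensor batch is accepted via features.is_simple_tensor
#     one_hot = torch.nn.functional.one_hot(torch.randint(10, (8,)), num_classes=10).float()
#     labels = features.OneHotLabel(one_hot)  # assumed constructor that wraps the tensor
#     mixed_images, mixed_labels = mixup(images, labels)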
class SimpleCopyPaste(_RandomApplyTransform):
def __init__(
self,
p: float = 0.5,
blending: bool = True,
resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__(p=p)
self.resize_interpolation = resize_interpolation
self.blending = blending
self.antialias = antialias
def _copy_paste(
self,
image: features.TensorImageType,
target: Dict[str, Any],
paste_image: features.TensorImageType,
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[features.TensorImageType, Dict[str, Any]]:
paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])
masks = target["masks"]
# We resize source and paste data if they have different sizes
# This is something we introduced here that differs from the TF implementation, as the
# original algorithm works on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = cast(List[int], image.shape[-2:])
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
paste_masks = F.resize(paste_masks, size=size1)
paste_boxes = F.resize(paste_boxes, size=size1)
paste_alpha_mask = paste_masks.sum(dim=0) > 0
if blending:
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
# Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
# Copy-paste masks:
masks = masks * (~paste_alpha_mask)
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]
# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}
out_target["masks"] = torch.cat([masks, paste_masks])
# Copy-paste boxes and labels
bbox_format = target["boxes"].format
xyxy_boxes = masks_to_boxes(masks)
# masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be exclusive,
# so we need to add +1 to x2y2.
# There is a similar +1 in other reference implementations:
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_format_bounding_box(
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])
# Check for degenerate boxes and remove them
boxes = F.convert_format_bounding_box(
out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)
out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]
return image, out_target
def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from the unstructured input
# as List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.Mask):
masks.append(obj)
elif isinstance(obj, (features.Label, features.OneHotLabel)):
labels.append(obj)
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxes, Masks and Labels or OneHotLabels."
)
targets = []
for bbox, mask, label in zip(bboxes, masks, labels):
targets.append({"boxes": bbox, "masks": mask, "labels": label})
return images, targets
def _insert_outputs(
self,
flat_sample: List[Any],
output_images: List[features.TensorImageType],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, features.Image):
flat_sample[i] = features.Image.new_like(obj, output_images[c0])
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1
elif features.is_simple_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, features.BoundingBox):
flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
c1 += 1
elif isinstance(obj, features.Mask):
flat_sample[i] = features.Mask.new_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (features.Label, features.OneHotLabel)):
flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
c3 += 1
def forward(self, *inputs: Any) -> Any:
flat_sample, spec = tree_flatten(inputs)
images, targets = self._extract_image_targets(flat_sample)
# images = [t1, t2, ..., tN]
# Let's define paste_images as a shifted list of input images
# paste_images = [t2, t3, ..., tN, t1]
# FYI: in TF they mix data on the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]
output_images, output_targets = [], []
for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
# Random paste targets selection:
num_masks = len(paste_target["masks"])
if num_masks < 1:
# Such a degenerate case with num_masks=0 can happen with LSJ
# Let's just return (image, target)
output_image, output_target = image, target
else:
random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
random_selection = torch.unique(random_selection)
output_image, output_target = self._copy_paste(
image,
target,
paste_image,
paste_target,
random_selection=random_selection,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
antialias=self.antialias,
)
output_images.append(output_image)
output_targets.append(output_target)
# Insert updated images and targets into input flat_sample
self._insert_outputs(flat_sample, output_images, output_targets)
return tree_unflatten(flat_sample, spec)
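# Sketch of what SimpleCopyPaste expects (illustrative description, not part of the original file):
# the flattened input sample must contain, per image, one image entry (features.Image, plain tensor, or
# PIL image), one features.BoundingBox, one features.Mask, and one features.Label / features.OneHotLabel,
# e.g. a batch of (image, target) pairs from a detection dataset. Masks and boxes from the "next" sample
# in the batch are pasted onto the current one, with optional Gaussian blending of the paste mask.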
import math
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw
from ._utils import _isinstance, _setup_fill_arg
K = TypeVar("K")
V = TypeVar("V")
class _AutoAugmentBase(Transform):
def __init__(
self,
*,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__()
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _extract_image(
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
) -> Tuple[int, features.ImageType]:
sample_flat, _ = tree_flatten(sample)
images = []
for id, inpt in enumerate(sample_flat):
if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor)):
images.append((id, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
if not images:
raise TypeError("Found no image in the sample.")
if len(images) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
)
return images[0]
def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
sample_flat, spec = tree_flatten(sample)
sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
def _apply_image_transform(
self,
image: features.ImageType,
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Dict[Type, features.FillType],
) -> features.ImageType:
fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)
if transform_id == "Identity":
return image
elif transform_id == "ShearX":
# magnitude should be arctan(magnitude)
# official autoaug: (1, level, 0, 0, 1, 0)
# https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
# compared to
# torchvision: (1, tan(level), 0, 0, 1, 0)
# https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(math.atan(magnitude)), 0.0],
interpolation=interpolation,
fill=fill_,
center=[0, 0],
)
elif transform_id == "ShearY":
# magnitude should be arctan(magnitude)
# See above
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(math.atan(magnitude))],
interpolation=interpolation,
fill=fill_,
center=[0, 0],
)
elif transform_id == "TranslateX":
return F.affine(
image,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_,
)
elif transform_id == "TranslateY":
return F.affine(
image,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_,
)
elif transform_id == "Rotate":
return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
elif transform_id == "Brightness":
return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
elif transform_id == "Color":
return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
elif transform_id == "Contrast":
return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
elif transform_id == "Sharpness":
return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
elif transform_id == "Posterize":
return F.posterize(image, bits=int(magnitude))
elif transform_id == "Solarize":
return F.solarize(image, threshold=magnitude)
elif transform_id == "AutoContrast":
return F.autocontrast(image)
elif transform_id == "Equalize":
return F.equalize(image)
elif transform_id == "Invert":
return F.invert(image)
else:
raise ValueError(f"No transform available for {transform_id}")
class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
"Invert": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
self._policies = self._get_policies(policy)
def _get_policies(
self, policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
else:
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
_, height, width = get_chw(image)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability:
continue
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image)
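# A minimal usage sketch (illustrative, not part of the original file). AutoAugment expects a sample
# containing exactly one image (plain tensor, features.Image, or PIL image) and no boxes or masks:
#
#     transform = AutoAugment(policy=AutoAugmentPolicy.IMAGENET)
#     img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
#     out = transform(img)  # same shape and dtype, with one randomly sampled policy pair applied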
class RandAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
_, height, width = get_chw(image)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[self.magnitude])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image)
class TrivialAugmentWide(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
_, height, width = get_chw(image)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image)
class AugMix(_AutoAugmentBase):
_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
**_PARTIAL_AUGMENTATION_SPACE,
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
}
def __init__(
self,
severity: int = 3,
mixture_width: int = 3,
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
self.severity = severity
self.mixture_width = mixture_width
self.chain_depth = chain_depth
self.alpha = alpha
self.all_ops = all_ops
def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Kept in a separate method so that we can override it in tests.
return torch._sample_dirichlet(params)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image = self._extract_image(sample)
_, height, width = get_chw(orig_image)
if isinstance(orig_image, torch.Tensor):
image = orig_image
        else:  # isinstance(orig_image, PIL.Image.Image)
image = F.pil_to_tensor(orig_image)
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image.shape)
batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
# Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
# with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
m = self._sample_dirichlet(
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
)
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
) * m[:, 1].view([batch_dims[0], -1])
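        # Shape sketch (illustrative): with B = batch.size(0) and W = self.mixture_width, m has shape
        # (B, 2) with rows summing to 1, and combined_weights has shape (B, W) with rows summing to
        # m[:, 1], so m[:, 0] + combined_weights.sum(dim=1) == 1 for every sample in the batch.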
mix = m[:, 0].view(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
aug = self._apply_image_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
mix = mix.view(orig_dims).to(dtype=image.dtype)
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
mix = F.to_image_pil(mix)
return self._put_into_sample(sample, id, mix)
import collections.abc
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from ._transform import _RandomApplyTransform
from ._utils import query_chw
class ColorJitter(Transform):
def __init__(
self,
brightness: Optional[Union[float, Sequence[float]]] = None,
contrast: Optional[Union[float, Sequence[float]]] = None,
saturation: Optional[Union[float, Sequence[float]]] = None,
hue: Optional[Union[float, Sequence[float]]] = None,
) -> None:
super().__init__()
self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, "saturation")
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
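        # Illustrative: brightness=0.2 becomes the sampling range (0.8, 1.2) and hue=0.1 becomes
        # (-0.1, 0.1), while a value of 0 (or an explicit (1.0, 1.0)) collapses to None, i.e. a no-op.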
def _check_input(
self,
value: Optional[Union[float, Sequence[float]]],
name: str,
center: float = 1.0,
bound: Tuple[float, float] = (0, float("inf")),
clip_first_on_zero: bool = True,
) -> Optional[Tuple[float, float]]:
if value is None:
return None
if isinstance(value, float):
if value < 0:
raise ValueError(f"If {name} is a single number, it must be non negative.")
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
else:
raise TypeError(f"{name} should be a single number or a sequence with length 2.")
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
@staticmethod
def _generate_value(left: float, right: float) -> float:
return float(torch.distributions.Uniform(left, right).sample())
def _get_params(self, sample: Any) -> Dict[str, Any]:
fn_idx = torch.randperm(4)
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = inpt
brightness_factor = params["brightness_factor"]
contrast_factor = params["contrast_factor"]
saturation_factor = params["saturation_factor"]
hue_factor = params["hue_factor"]
for fn_id in params["fn_idx"]:
if fn_id == 0 and brightness_factor is not None:
output = F.adjust_brightness(output, brightness_factor=brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
output = F.adjust_contrast(output, contrast_factor=contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
output = F.adjust_saturation(output, saturation_factor=saturation_factor)
elif fn_id == 3 and hue_factor is not None:
output = F.adjust_hue(output, hue_factor=hue_factor)
return output
class RandomPhotometricDistort(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(
self,
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
brightness: Tuple[float, float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__()
self.brightness = brightness
self.contrast = contrast
self.hue = hue
self.saturation = saturation
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
num_channels, _, _ = query_chw(sample)
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2"],
(torch.rand(5) < self.p).tolist(),
),
contrast_before=bool(torch.rand(()) < 0.5),
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
)
def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) -> features.ImageType:
if isinstance(inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
output = inpt[..., permutation, :, :]
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = F.to_image_pil(output)
return output
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
if params["brightness"]:
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
)
if params["contrast1"] and params["contrast_before"]:
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["saturation"]:
inpt = F.adjust_saturation(
inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
)
if params["hue"]:
inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
if params["contrast2"] and not params["contrast_before"]:
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
return inpt
class RandomEqualize(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.equalize(inpt)
class RandomInvert(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.invert(inpt)
class RandomPosterize(_RandomApplyTransform):
def __init__(self, bits: int, p: float = 0.5) -> None:
super().__init__(p=p)
self.bits = bits
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.posterize(inpt, bits=self.bits)
class RandomSolarize(_RandomApplyTransform):
def __init__(self, threshold: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.threshold = threshold
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.solarize(inpt, threshold=self.threshold)
class RandomAutocontrast(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.autocontrast(inpt)
class RandomAdjustSharpness(_RandomApplyTransform):
def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.sharpness_factor = sharpness_factor
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)
import warnings
from typing import Any, Callable, List, Optional, Sequence
import torch
from torchvision.prototype.transforms import Transform
class Compose(Transform):
def __init__(self, transforms: Sequence[Callable]) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
self.transforms = transforms
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms:
sample = transform(sample)
return sample
def extra_repr(self) -> str:
format_string = []
for t in self.transforms:
format_string.append(f" {t}")
return "\n".join(format_string)
class RandomApply(Compose):
def __init__(self, transforms: Sequence[Callable], p: float = 0.5) -> None:
super().__init__(transforms)
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
class RandomChoice(Transform):
def __init__(
self,
transforms: Sequence[Callable],
probabilities: Optional[List[float]] = None,
p: Optional[List[float]] = None,
) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is not None:
warnings.warn(
"Argument p is deprecated and will be removed in a future release. "
"Please use probabilities argument instead."
)
probabilities = p
if probabilities is None:
probabilities = [1] * len(transforms)
elif len(probabilities) != len(transforms):
raise ValueError(
f"The number of probabilities doesn't match the number of transforms: "
f"{len(probabilities)} != {len(transforms)}"
)
super().__init__()
self.transforms = transforms
total = sum(probabilities)
self.probabilities = [prob / total for prob in probabilities]
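        # Illustrative: probabilities=[1, 1, 2] is normalized to [0.25, 0.25, 0.5] before sampling.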
def forward(self, *inputs: Any) -> Any:
idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
transform = self.transforms[idx]
return transform(*inputs)
class RandomOrder(Transform):
def __init__(self, transforms: Sequence[Callable]) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
super().__init__()
self.transforms = transforms
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
sample = transform(sample)
return sample
import warnings
from typing import Any, Dict, Union
import numpy as np
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import query_chw
class ToTensor(Transform):
_transformed_types = (PIL.Image.Image, np.ndarray)
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`."
)
super().__init__()
def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
return _F.to_tensor(inpt)
class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = (
f"The transform `Grayscale(num_output_channels={num_output_channels})` "
f"is deprecated and will be removed in a future release."
)
if num_output_channels == 1:
replacement_msg = (
"transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
)
else:
replacement_msg = (
"transforms.Compose(\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
")"
)
warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}")
super().__init__()
self.num_output_channels = num_output_channels
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
return output
class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, p: float = 0.1) -> None:
warnings.warn(
"The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
"Instead, please use\n\n"
"transforms.RandomApply(\n"
" transforms.Compose(\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
" )\n"
" p=...,\n"
")"
)
super().__init__(p=p)
def _get_params(self, sample: Any) -> Dict[str, Any]:
num_input_channels, _, _ = query_chw(sample)
return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
return output
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.transforms.functional import _get_perspective_coeffs
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
_setup_size,
has_all,
has_any,
query_bounding_box,
query_chw,
)
class RandomHorizontalFlip(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.horizontal_flip(inpt)
class RandomVerticalFlip(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.vertical_flip(inpt)
class Resize(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.size = (
[size]
if isinstance(size, int)
else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
)
self.interpolation = interpolation
self.max_size = max_size
self.antialias = antialias
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(
inpt,
self.size,
interpolation=self.interpolation,
max_size=self.max_size,
antialias=self.antialias,
)
class CenterCrop(Transform):
def __init__(self, size: Union[int, Sequence[int]]):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.size)
class RandomResizedCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
scale = cast(Tuple[float, float], scale)
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
ratio = cast(Tuple[float, float], ratio)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
self.scale = scale
self.ratio = ratio
self.interpolation = interpolation
self.antialias = antialias
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, sample: Any) -> Dict[str, Any]:
        # vfdev-5: technically, this op can work on inputs that contain only bboxes/segm masks and no image
        # What if we have multiple images/bboxes/masks of different sizes?
# TODO: let's support bbox or mask in samples without image
_, height, width = query_chw(sample)
area = height * width
log_ratio = self._log_ratio
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
break
else:
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resized_crop(
inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
)
class FiveCrop(Transform):
"""
Example:
>>> class BatchMultiCrop(transforms.Transform):
... def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
... images, labels = sample
... batch_size = len(images)
... images = features.Image.new_like(images[0], torch.stack(images))
... labels = features.Label.new_like(labels, labels.repeat(batch_size))
... return images, labels
...
>>> image = features.Image(torch.rand(3, 256, 256))
>>> label = features.Label(0)
        >>> transform = transforms.Compose([transforms.FiveCrop(size=(224, 224)), BatchMultiCrop()])
>>> images, labels = transform(image, label)
>>> images.shape
torch.Size([5, 3, 224, 224])
>>> labels.shape
torch.Size([5])
"""
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(
self, inpt: features.ImageType, params: Dict[str, Any]
) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
return F.five_crop(inpt, self.size)
def forward(self, *inputs: Any) -> Any:
if has_any(inputs, features.BoundingBox, features.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
return super().forward(*inputs)
class TenCrop(Transform):
"""
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
def forward(self, *inputs: Any) -> Any:
if has_any(inputs, features.BoundingBox, features.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
return super().forward(*inputs)
class Pad(Transform):
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
self.padding = padding
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
padding = self.padding
if not isinstance(padding, int):
padding = list(padding)
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
super().__init__(p=p)
self.fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, sample: Any) -> Dict[str, Any]:
_, orig_h, orig_w = query_chw(sample)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
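        # Illustrative: for a 100x100 image with a sampled side ratio of 2.0 the canvas is 200x200;
        # if the second draw is (0.25, 0.75), the padding becomes [left=25, top=75, right=75, bottom=25].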
return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, **params, fill=fill)
class RandomRotation(Transform):
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = interpolation
self.expand = expand
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, sample: Any) -> Dict[str, Any]:
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
fill=fill,
center=self.center,
)
class RandomAffine(Transform):
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
# TODO: make it work with bboxes and segm masks
_, height, width = query_chw(sample)
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translate = (tx, ty)
else:
translate = (0, 0)
if self.scale is not None:
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
else:
scale = 1.0
shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
if len(self.shear) == 4:
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.affine(
inpt,
**params,
interpolation=self.interpolation,
fill=fill,
center=self.center,
)
class RandomCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if pad_if_needed or padding is not None:
if padding is not None:
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]:
_, padded_height, padded_width = query_chw(sample)
if self.padding is not None:
pad_left, pad_right, pad_top, pad_bottom = self.padding
padded_height += pad_top + pad_bottom
padded_width += pad_left + pad_right
else:
pad_left = pad_right = pad_top = pad_bottom = 0
cropped_height, cropped_width = self.size
if self.pad_if_needed:
if padded_height < cropped_height:
diff = cropped_height - padded_height
pad_top += diff
pad_bottom += diff
padded_height += 2 * diff
if padded_width < cropped_width:
diff = cropped_width - padded_width
pad_left += diff
pad_right += diff
padded_width += 2 * diff
if padded_height < cropped_height or padded_width < cropped_width:
raise ValueError(
f"Required crop size {(cropped_height, cropped_width)} is larger than "
f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
)
# We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
padding = [pad_left, pad_top, pad_right, pad_bottom]
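        # Illustrative: cropping 224x224 from a 200x300 image with pad_if_needed=True adds 24 px of
        # padding at the top and bottom (padded height 248), while the width needs no padding.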
needs_pad = any(padding)
needs_vert_crop, top = (
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
if padded_height > cropped_height
else (False, 0)
)
needs_horz_crop, left = (
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
if padded_width > cropped_width
else (False, 0)
)
return dict(
needs_crop=needs_vert_crop or needs_horz_crop,
top=top,
left=left,
height=cropped_height,
width=cropped_width,
needs_pad=needs_pad,
padding=padding,
)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
return inpt
class RandomPerspective(_RandomApplyTransform):
def __init__(
self,
distortion_scale: float = 0.5,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
super().__init__(p=p)
if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
# TODO: make it work with bboxes and segm masks
_, height, width = query_chw(sample)
distortion_scale = self.distortion_scale
half_height = height // 2
half_width = width // 2
topleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
]
topright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
]
botright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
]
botleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
return dict(perspective_coeffs=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.perspective(
inpt,
**params,
fill=fill,
interpolation=self.interpolation,
)
class ElasticTransform(Transform):
def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
# TODO: make it work with bboxes and segm masks
_, *size = query_chw(sample)
dx = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
if kx % 2 == 0:
kx += 1
dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / size[0]
dy = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
if ky % 2 == 0:
ky += 1
dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
dy = dy * self.alpha[1] / size[1]
displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.elastic(
inpt,
**params,
fill=fill,
interpolation=self.interpolation,
)
class RandomIoUCrop(Transform):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
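        # Illustrative: an option of 0.0 imposes no IoU requirement on the crop, 0.5 requires at least
        # one box with IoU >= 0.5 against the crop, and 1.0 (or anything >= 1) leaves the sample unchanged.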
self.trials = trials
def _get_params(self, sample: Any) -> Dict[str, Any]:
_, orig_h, orig_w = query_chw(sample)
bboxes = query_bounding_box(sample)
while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
            if min_jaccard_overlap >= 1.0:  # a value of 1.0 or larger encodes the leave-as-is option
return dict()
for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue
# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_box(
bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue
# check at least 1 box with jaccard limitations
xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
ious = box_iou(
xyxy_bboxes,
torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
)
if ious.max() < min_jaccard_overlap:
continue
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if len(params) < 1:
return inpt
is_within_crop_area = params["is_within_crop_area"]
if isinstance(inpt, (features.Label, features.OneHotLabel)):
return inpt.new_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type]
output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
if isinstance(output, features.BoundingBox):
bboxes = output[is_within_crop_area]
bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
output = features.BoundingBox.new_like(output, bboxes)
elif isinstance(output, features.Mask):
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = features.Mask.new_like(output, masks)
return output
def forward(self, *inputs: Any) -> Any:
if not (
has_all(inputs, features.BoundingBox)
and has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
and has_any(inputs, features.Label, features.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
)
return super().forward(*inputs)
class ScaleJitter(Transform):
def __init__(
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
_, orig_height, orig_width = query_chw(sample)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
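        # Illustrative: with target_size=(1024, 1024), a 512x256 (HxW) input and a sampled scale of 1.0,
        # r = min(1024 / 512, 1024 / 256) * 1.0 = 2.0, so the returned size is (1024, 512).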
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class RandomShortestSize(Transform):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
_, orig_height, orig_width = query_chw(sample)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
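        # Illustrative: with min_size=[320], max_size=1333 and a 640x960 (HxW) input,
        # r = min(320 / 640, 1333 / 960) = 0.5, so the returned size is (320, 480).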
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]:
_, height, width = query_chw(sample)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
needs_crop = new_height != height or new_width != width
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)
r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)
try:
bounding_boxes = query_bounding_box(sample)
except ValueError:
bounding_boxes = None
if needs_crop and bounding_boxes is not None:
bounding_boxes = cast(
features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
)
bounding_boxes = features.BoundingBox.new_like(
bounding_boxes,
F.clamp_bounding_box(
bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
),
)
height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
is_valid = None
pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
needs_pad = pad_bottom != 0 or pad_right != 0
return dict(
needs_crop=needs_crop,
top=top,
left=left,
height=new_height,
width=new_width,
is_valid=is_valid,
padding=[0, 0, pad_right, pad_bottom],
needs_pad=needs_pad,
)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_crop"]:
inpt = F.crop(
inpt,
top=params["top"],
left=params["left"],
height=params["height"],
width=params["width"],
)
if params["is_valid"] is not None:
if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)):
inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox):
inpt = features.BoundingBox.new_like(
inpt,
F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
)
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
def forward(self, *inputs: Any) -> Any:
if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor):
raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.")
if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
raise TypeError(
f"If a BoundingBox is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
)
return super().forward(*inputs)
class RandomResize(Transform):
def __init__(
self,
min_size: int,
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.min_size = min_size
self.max_size = max_size
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
size = int(torch.randint(self.min_size, self.max_size, ()))
return dict(size=[size])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
from typing import Any, Dict, Optional, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
class ConvertBoundingBoxFormat(Transform):
_transformed_types = (features.BoundingBox,)
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
self.format = format
def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
        output = F.convert_format_bounding_box(inpt, old_format=inpt.format, new_format=self.format)
        return features.BoundingBox.new_like(inpt, output, format=self.format)
class ConvertImageDtype(Transform):
_transformed_types = (features.is_simple_tensor, features.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
output = F.convert_image_dtype(inpt, dtype=self.dtype)
return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) # type: ignore[arg-type]
class ConvertColorSpace(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
def __init__(
self,
color_space: Union[str, features.ColorSpace],
old_color_space: Optional[Union[str, features.ColorSpace]] = None,
copy: bool = True,
) -> None:
super().__init__()
if isinstance(color_space, str):
color_space = features.ColorSpace.from_str(color_space)
self.color_space = color_space
if isinstance(old_color_space, str):
old_color_space = features.ColorSpace.from_str(old_color_space)
self.old_color_space = old_color_space
self.copy = copy
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
return F.convert_color_space(
inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
)
class ClampBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox,)
def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size)
return features.BoundingBox.new_like(inpt, output)
import functools
from typing import Any, Callable, Dict, Sequence, Type, Union
import PIL.Image
import torch
from torchvision.ops import remove_small_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box
class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt
class Lambda(Transform):
def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__()
self.lambd = lambd
self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types):
return self.lambd(inpt)
else:
return inpt
def extra_repr(self) -> str:
extras = []
name = getattr(self.lambd, "__name__", None)
if name:
extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}")
return ", ".join(extras)
class LinearTransformation(Transform):
_transformed_types = (features.is_simple_tensor, features.Image)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError(
"transformation_matrix should be square. Got "
f"{tuple(transformation_matrix.size())} rectangular matrix."
)
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError(
f"mean_vector should have the same length {mean_vector.size(0)}"
f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
)
if transformation_matrix.device != mean_vector.device:
raise ValueError(
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def forward(self, *inputs: Any) -> Any:
if has_any(inputs, PIL.Image.Image):
raise TypeError("LinearTransformation does not work on PIL Images")
return super().forward(*inputs)
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
        # After a linear transformation the result is no longer an Image, since its data range is unknown.
        # Thus we return a plain Tensor for Image inputs.
shape = inpt.shape
n = shape[-3] * shape[-2] * shape[-1]
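        # Illustrative: for a 3 x 32 x 32 image, n = 3 * 32 * 32 = 3072, so transformation_matrix
        # must be a 3072 x 3072 matrix and mean_vector a 3072-element vector.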
if n != self.transformation_matrix.shape[0]:
raise ValueError(
"Input tensor and transformation matrix have incompatible shape."
+ f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
+ f"{self.transformation_matrix.shape[0]}"
)
if inpt.device.type != self.mean_vector.device.type:
raise ValueError(
"Input tensor should be on the same device as transformation matrix and mean vector. "
f"Got {inpt.device} vs {self.mean_vector.device}"
)
flat_tensor = inpt.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
return transformed_tensor.view(shape)
class Normalize(Transform):
_transformed_types = (features.Image, features.is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
self.mean = list(mean)
self.std = list(std)
self.inplace = inplace
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
def forward(self, *inpts: Any) -> Any:
if has_any(inpts, PIL.Image.Image):
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
return super().forward(*inpts)
class GaussianBlur(Transform):
def __init__(
self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
) -> None:
super().__init__()
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
for ks in self.kernel_size:
if ks <= 0 or ks % 2 == 0:
raise ValueError("Kernel size value should be an odd and positive number.")
if isinstance(sigma, (int, float)):
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = float(sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
def _get_params(self, sample: Any) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
return dict(sigma=[sigma, sigma])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, self.kernel_size, **params)
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))
def extra_repr(self) -> str:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
def __init__(self, min_size: float = 1.0) -> None:
super().__init__()
self.min_size = min_size
def _get_params(self, sample: Any) -> Dict[str, Any]:
bounding_box = query_bounding_box(sample)
        # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box
        # to be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or
        # CXCYWH format, we need to convert it first just to compute the width and height again afterwards,
        # although they were there in the first place for these formats.
bounding_box = F.convert_format_bounding_box(
bounding_box, old_format=bounding_box.format, new_format=features.BoundingBoxFormat.XYXY
)
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
return dict(valid_indices=valid_indices)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.new_like(inpt, inpt[params["valid_indices"]])
"""
This file is part of the private API. Please do not use these classes directly, as they will be modified in
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
from torch import Tensor
from . import functional as F, InterpolationMode
__all__ = ["StereoMatching"]
class StereoMatching(torch.nn.Module):
def __init__(
self,
*,
use_gray_scale: bool = False,
resize_size: Optional[Tuple[int, ...]],
mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
std: Tuple[float, ...] = (0.5, 0.5, 0.5),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
# pacify mypy
self.resize_size: Union[None, List]
if resize_size is not None:
self.resize_size = list(resize_size)
else:
self.resize_size = None
self.mean = list(mean)
self.std = list(std)
self.interpolation = interpolation
self.use_gray_scale = use_gray_scale
def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
def _process_image(img: PIL.Image.Image) -> Tensor:
if self.resize_size is not None:
img = F.resize(img, self.resize_size, interpolation=self.interpolation)
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
            if self.use_gray_scale:
img = F.rgb_to_grayscale(img)
img = F.convert_image_dtype(img, torch.float)
img = F.normalize(img, mean=self.mean, std=self.std)
img = img.contiguous()
return img
left_image = _process_image(left_image)
right_image = _process_image(right_image)
return left_image, right_image
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
)
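    # Illustrative usage (hypothetical values): preset = StereoMatching(resize_size=(384, 768));
    # preset(left, right) resizes both views, converts them to float tensors (rescaling integer
    # inputs to [0, 1]) and normalizes them with the configured mean/std, returning a
    # (left, right) tuple ready to feed a stereo model.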
import enum
from typing import Any, Callable, Dict, Tuple, Type, Union
import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.utils import _log_api_usage_once
class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
features.is_simple_tensor,
features._Feature,
PIL.Image.Image,
)
def __init__(self) -> None:
super().__init__()
_log_api_usage_once(self)
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
raise NotImplementedError
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = self._get_params(sample)
flat_inputs, spec = tree_flatten(sample)
flat_outputs = [
self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
for inpt in flat_inputs
]
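        # Illustrative: for a hypothetical sample like (image, {"boxes": bbox, "label": 3}),
        # tree_flatten yields [image, bbox, 3]; image and bbox match _transformed_types and are
        # transformed, the plain int is passed through, and tree_unflatten restores the structure.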
return tree_unflatten(flat_outputs, spec)
def extra_repr(self) -> str:
extra = []
for name, value in self.__dict__.items():
if name.startswith("_") or name == "training":
continue
if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
continue
extra.append(f"{name}={value}")
return ", ".join(extra)
class _RandomApplyTransform(Transform):
def __init__(self, p: float = 0.5) -> None:
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
super().__init__()
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
from typing import Any, cast, Dict, Optional, Union
import numpy as np
import PIL.Image
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
return cast(features.Image, F.decode_image_with_pil(inpt))
class LabelToOneHot(Transform):
_transformed_types = (features.Label,)
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.OneHotLabel:
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt, num_classes=num_categories)
return features.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str:
if self.num_categories == -1:
return ""
return f"num_categories={self.num_categories}"
class PILToTensor(Transform):
_transformed_types = (PIL.Image.Image,)
def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor:
return F.pil_to_tensor(inpt)
class ToImageTensor(Transform):
_transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image:
return cast(features.Image, F.to_image_tensor(inpt))
class ToImagePIL(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode)
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL
import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
import PIL.Image
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType
from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from typing_extensions import Literal
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
if not isinstance(arg, (float, Sequence)):
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) != req_size:
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
if isinstance(arg, Sequence):
for element in arg:
if not isinstance(element, float):
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
if isinstance(arg, float):
arg = [float(arg), float(arg)]
if isinstance(arg, (list, tuple)) and len(arg) == 1:
arg = [arg[0], arg[0]]
return arg
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Recursively check each value; the dict keys are the input types used for dispatch
_check_fill_arg(value)
else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)
if isinstance(fill, dict):
return fill
return defaultdict(lambda: fill) # type: ignore[return-value, arg-type]
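# Illustrative sketch (not part of the original file): a scalar fill is broadcast to every
# input type through a defaultdict, while a dict fill is kept as-is so callers can set the
# fill per type (e.g. images vs. masks). The helper name is hypothetical.
def _example_setup_fill_arg() -> None:
    per_type = _setup_fill_arg(0)
    assert per_type[features.Image] == 0 and per_type[PIL.Image.Image] == 0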
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
bounding_boxes = {item for item in flat_sample if isinstance(item, features.BoundingBox)}
if not bounding_boxes:
raise TypeError("No bounding box was found in the sample")
elif len(bounding_boxes) > 1:
raise ValueError("Found multiple bounding boxes in the sample")
return bounding_boxes.pop()
def query_chw(sample: Any) -> Tuple[int, int, int]:
flat_sample, _ = tree_flatten(sample)
chws = {
get_chw(item)
for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(item)
}
if not chws:
raise TypeError("No image was found in the sample")
elif len(chws) > 1:
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
return chws.pop()
def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
for type_or_check in types_or_checks:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False
def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
for obj in flat_sample:
if _isinstance(obj, types_or_checks):
return True
return False
def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
for type_or_check in types_or_checks:
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
break
else:
return False
return True
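# Illustrative sketch (not part of the original file): has_any / has_all flatten the sample
# and accept both concrete types and boolean predicates such as features.is_simple_tensor.
# The helper name and the toy sample are hypothetical.
def _example_has_any() -> bool:
    sample = {"image": PIL.Image.new("RGB", (4, 4)), "label": 3}
    return has_any(sample, PIL.Image.Image, features.is_simple_tensor)  # True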
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space,
get_dimensions,
get_image_num_channels,
get_num_channels,
get_spatial_size,
) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor
from ._color import (
adjust_brightness,
adjust_brightness_image_pil,
adjust_brightness_image_tensor,
adjust_contrast,
adjust_contrast_image_pil,
adjust_contrast_image_tensor,
adjust_gamma,
adjust_gamma_image_pil,
adjust_gamma_image_tensor,
adjust_hue,
adjust_hue_image_pil,
adjust_hue_image_tensor,
adjust_saturation,
adjust_saturation_image_pil,
adjust_saturation_image_tensor,
adjust_sharpness,
adjust_sharpness_image_pil,
adjust_sharpness_image_tensor,
autocontrast,
autocontrast_image_pil,
autocontrast_image_tensor,
equalize,
equalize_image_pil,
equalize_image_tensor,
invert,
invert_image_pil,
invert_image_tensor,
posterize,
posterize_image_pil,
posterize_image_tensor,
solarize,
solarize_image_pil,
solarize_image_tensor,
)
from ._geometry import (
affine,
affine_bounding_box,
affine_image_pil,
affine_image_tensor,
affine_mask,
center_crop,
center_crop_bounding_box,
center_crop_image_pil,
center_crop_image_tensor,
center_crop_mask,
crop,
crop_bounding_box,
crop_image_pil,
crop_image_tensor,
crop_mask,
elastic,
elastic_bounding_box,
elastic_image_pil,
elastic_image_tensor,
elastic_mask,
elastic_transform,
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file
horizontal_flip,
horizontal_flip_bounding_box,
horizontal_flip_image_pil,
horizontal_flip_image_tensor,
horizontal_flip_mask,
pad,
pad_bounding_box,
pad_image_pil,
pad_image_tensor,
pad_mask,
perspective,
perspective_bounding_box,
perspective_image_pil,
perspective_image_tensor,
perspective_mask,
resize,
resize_bounding_box,
resize_image_pil,
resize_image_tensor,
resize_mask,
resized_crop,
resized_crop_bounding_box,
resized_crop_image_pil,
resized_crop_image_tensor,
resized_crop_mask,
rotate,
rotate_bounding_box,
rotate_image_pil,
rotate_image_tensor,
rotate_mask,
ten_crop,
ten_crop_image_pil,
ten_crop_image_tensor,
vertical_flip,
vertical_flip_bounding_box,
vertical_flip_image_pil,
vertical_flip_image_tensor,
vertical_flip_mask,
vflip,
)
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
from ._type_conversion import (
convert_image_dtype,
decode_image_with_pil,
decode_video_with_av,
pil_to_tensor,
to_image_pil,
to_image_tensor,
to_pil_image,
)
from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor # usort: skip
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
erase_image_tensor = _FT.erase
@torch.jit.unused
def erase_image_pil(
image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
t_img = pil_to_tensor(image)
output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return to_pil_image(output, mode=image.mode)
def erase(
inpt: features.ImageTypeJIT,
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> features.ImageTypeJIT:
if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output)
return output
else: # isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
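# Illustrative sketch (not part of the original file): erase() overwrites the h x w patch
# starting at (i, j) with the value tensor v, here blanking a 2 x 2 region of a 3-channel
# image. The helper name is hypothetical.
def _example_erase() -> torch.Tensor:
    img = torch.ones(3, 4, 4)
    return erase(img, i=1, j=1, h=2, w=2, v=torch.zeros(3, 2, 2))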
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
else:
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
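# Illustrative sketch (not part of the original file): the dispatch pattern shared by every
# kernel in this file. Plain tensors go to the *_image_tensor kernel, features._Feature
# subclasses dispatch to their own method, and anything else is assumed to be a PIL image.
# The helper name is hypothetical.
def _example_adjust_brightness(img: torch.Tensor) -> torch.Tensor:
    return adjust_brightness(img, brightness_factor=2.0)  # doubles the brightness of a plain tensor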
adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
else:
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
else:
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
adjust_sharpness_image_tensor = _FT.adjust_sharpness
adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
else:
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
adjust_hue_image_tensor = _FT.adjust_hue
adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_hue(hue_factor=hue_factor)
else:
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
adjust_gamma_image_tensor = _FT.adjust_gamma
adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, features._Feature):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
else:
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize
def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, features._Feature):
return inpt.posterize(bits=bits)
else:
return posterize_image_pil(inpt, bits=bits)
solarize_image_tensor = _FT.solarize
solarize_image_pil = _FP.solarize
def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, features._Feature):
return inpt.solarize(threshold=threshold)
else:
return solarize_image_pil(inpt, threshold=threshold)
autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast
def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.autocontrast()
else:
return autocontrast_image_pil(inpt)
equalize_image_tensor = _FT.equalize
equalize_image_pil = _FP.equalize
def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return equalize_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.equalize()
else:
return equalize_image_pil(inpt)
invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert
def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return invert_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.invert()
else:
return invert_image_pil(inpt)
import warnings
from typing import Any, List
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional as _F
@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = "convert_color_space(..., color_space=features.ColorSpace.GRAY)"
if num_output_channels == 3:
replacement = f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB)"
warnings.warn(
f"The function `to_grayscale(...{call})` is deprecated in will be removed in a future release. "
f"Instead, please use `{replacement}`.",
)
return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
old_color_space = (
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
else None
)
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
f"convert_color_space(..., color_space=features.ColorSpace.GRAY"
f"{f', old_color_space=features.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
)
if num_output_channels == 3:
replacement = (
f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB"
f"{f', old_color_space=features.ColorSpace.GRAY' if old_color_space is not None else ''})"
)
warnings.warn(
f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. "
f"Instead, please use `{replacement}`.",
)
return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
@torch.jit.unused
def to_tensor(inpt: Any) -> torch.Tensor:
warnings.warn(
"The function `to_tensor(...)` is deprecated and will be removed in a future release. "
"Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
)
return _F.to_tensor(inpt)
def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
)
return _F.get_image_size(inpt)
import numbers
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
_compute_resized_output_size as __compute_resized_output_size,
_get_inverse_affine_matrix,
InterpolationMode,
pil_modes_mapping,
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import (
_cast_squeeze_in,
_cast_squeeze_out,
_parse_pad_padding,
interpolate,
)
from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(mask)
def horizontal_flip_bounding_box(
bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(shape)
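# Illustrative sketch (not part of the original file): mirroring an XYXY box on an image of
# width 6 reflects the x coordinates, [1, 0, 3, 2] -> [3, 0, 5, 2]. The helper name and the
# toy values are hypothetical.
def _example_horizontal_flip_bounding_box() -> torch.Tensor:
    box = torch.tensor([[1.0, 0.0, 3.0, 2.0]])
    return horizontal_flip_bounding_box(box, features.BoundingBoxFormat.XYXY, image_size=(4, 6))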
def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.horizontal_flip()
else:
return horizontal_flip_image_pil(inpt)
vertical_flip_image_tensor = _FT.vflip
vertical_flip_image_pil = _FP.vflip
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(mask)
def vertical_flip_bounding_box(
bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]]
return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(shape)
def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.vertical_flip()
else:
return vertical_flip_image_pil(inpt)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip
def _compute_resized_output_size(
image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
if isinstance(size, int):
size = [size]
return __compute_resized_output_size(image_size, size=size, max_size=max_size)
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> torch.Tensor:
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
extra_dims = image.shape[:-3]
if image.numel() > 0:
image = image.view(-1, num_channels, old_height, old_width)
# This is a perf hack to avoid slow channels_last upsample code path
# Related issue: https://github.com/pytorch/pytorch/issues/83840
# We are transforming (N, 1, H, W) into (N, 2, H, W) to force to take channels_first path
if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST:
# Below code is copied from _FT.resize
# The hack has to be applied to the image after the cast, not before.
# Otherwise the image would be copied during the cast to float and interpolate would process twice as much data.
image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64])
shape = (image.shape[0], 2, image.shape[2], image.shape[3])
image = image.expand(shape)
image = interpolate(
image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False
)
image = image[:, 0, ...]
image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
else:
image = _FT.resize(
image,
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)
return image.view(extra_dims + (num_channels, new_height, new_width))
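# Illustrative sketch (not part of the original file): leading batch dimensions are folded
# into one, resized and then restored, so a (2, 5, 3, 32, 32) input resized to 16 x 16 comes
# back as (2, 5, 3, 16, 16). The helper name is hypothetical.
def _example_resize_image_tensor() -> torch.Tensor:
    video_batch = torch.rand(2, 5, 3, 32, 32)
    return resize_image_tensor(video_batch, size=[16, 16])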
@torch.jit.unused
def resize_image_pil(
image: PIL.Image.Image,
size: Union[Sequence[int], int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type]
return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
if needs_squeeze:
output = output.squeeze(0)
return output
def resize_bounding_box(
bounding_box: torch.Tensor, image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
old_height, old_width = image_size
new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size)
ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
return (
bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape),
(new_height, new_width),
)
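# Illustrative sketch (not part of the original file): boxes are scaled by the width/height
# ratios of the resize, so doubling a (10, 10) image doubles the XYXY coordinates; the new
# spatial size is returned alongside the boxes. The helper name and values are hypothetical.
def _example_resize_bounding_box() -> torch.Tensor:
    box = torch.tensor([[1.0, 2.0, 5.0, 6.0]])
    resized_box, _ = resize_bounding_box(box, image_size=(10, 10), size=[20, 20])
    return resized_box  # tensor([[2., 4., 10., 12.]])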
def resize(
inpt: features.InputTypeJIT,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
antialias = False if antialias is None else antialias
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
else:
if antialias is not None and not antialias:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
def _affine_parse_args(
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if not isinstance(translate, (list, tuple)):
raise TypeError("Argument translate should be a sequence")
if len(translate) != 2:
raise ValueError("Argument translate should be a sequence of length 2")
if scale <= 0.0:
raise ValueError("Argument scale should be positive")
if not isinstance(shear, (numbers.Number, list, tuple)):
raise TypeError("Shear should be either a single value or a sequence of two values")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(angle, int):
angle = float(angle)
if isinstance(translate, tuple):
translate = list(translate)
if isinstance(shear, numbers.Number):
shear = [shear, 0.0]
if isinstance(shear, tuple):
shear = list(shear)
if len(shear) == 1:
shear = [shear[0], shear[0]]
if len(shear) != 2:
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
if center is not None:
if not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")
else:
center = [float(c) for c in center]
return angle, translate, shear, center
def affine_image_tensor(
image: torch.Tensor,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if image.numel() == 0:
return image
num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
image = image.view(-1, num_channels, height, width)
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
center_f = [0.0, 0.0]
if center is not None:
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
return output.view(extra_dims + (num_channels, height, width))
@torch.jit.unused
def affine_image_pil(
image: PIL.Image.Image,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if center is None:
_, height, width = get_dimensions_image_pil(image)
center = [width * 0.5, height * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
def _affine_bounding_box_xyxy(
bounding_box: torch.Tensor,
image_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
expand: bool = False,
) -> torch.Tensor:
angle, translate, shear, center = _affine_parse_args(
angle, translate, scale, shear, InterpolationMode.NEAREST, center
)
if center is None:
height, width = image_size
center = [width * 0.5, height * 0.5]
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device
affine_matrix = torch.tensor(
_get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False),
dtype=dtype,
device=device,
).view(2, 3)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, affine_matrix.T)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.view(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
if expand:
# Compute minimum point for transformed image frame:
# Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
height, width = image_size
points = torch.tensor(
[
[0.0, 0.0, 1.0],
[0.0, 1.0 * height, 1.0],
[1.0 * width, 1.0 * height, 1.0],
[1.0 * width, 0.0, 1.0],
],
dtype=dtype,
device=device,
)
new_points = torch.matmul(points, affine_matrix.T)
tr, _ = torch.min(new_points, dim=0, keepdim=True)
# Translate bounding boxes
out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]
return out_bboxes.to(bounding_box.dtype)
def affine_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
) -> torch.Tensor:
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)
# out_bboxes should be of shape [N boxes, 4]
return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
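# Illustrative sketch (not part of the original file): with an identity transform (angle=0,
# no translation, scale=1, no shear) the corner-point pipeline above maps every box onto
# itself. The helper name and the toy box are hypothetical.
def _example_affine_bounding_box_identity() -> torch.Tensor:
    box = torch.tensor([[1.0, 2.0, 5.0, 6.0]])
    return affine_bounding_box(
        box,
        features.BoundingBoxFormat.XYXY,
        image_size=(10, 10),
        angle=0.0,
        translate=[0.0, 0.0],
        scale=1.0,
        shear=[0.0, 0.0],
    )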
def affine_mask(
mask: torch.Tensor,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = affine_image_tensor(
mask,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=InterpolationMode.NEAREST,
fill=fill,
center=center,
)
if needs_squeeze:
output = output.squeeze(0)
return output
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
# fill = 0
if fill is None:
return fill
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
return fill
def affine(
inpt: features.InputTypeJIT,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> features.InputTypeJIT:
# TODO: consider deprecating integers from angle and shear on the future
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return affine_image_tensor(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
elif isinstance(inpt, features._Feature):
return inpt.affine(
angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
)
else:
return affine_image_pil(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
def rotate_image_tensor(
image: torch.Tensor,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
center_f = [0.0, 0.0]
if center is not None:
if expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
else:
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
# Due to the current inconsistency in the rotation angle direction between the affine and rotate
# implementations, we need to negate the angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
if image.numel() > 0:
image = _FT.rotate(
image.view(-1, num_channels, height, width),
matrix,
interpolation=interpolation.value,
expand=expand,
fill=fill,
)
new_height, new_width = image.shape[-2:]
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
return image.view(extra_dims + (num_channels, new_height, new_width))
@torch.jit.unused
def rotate_image_pil(
image: PIL.Image.Image,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None
return _FP.rotate(
image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
)
def rotate_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
out_bboxes = _affine_bounding_box_xyxy(
bounding_box,
image_size,
angle=-angle,
translate=[0.0, 0.0],
scale=1.0,
shear=[0.0, 0.0],
center=center,
expand=expand,
)
if expand:
# TODO: Move this computation inside of `_affine_bounding_box_xyxy` to avoid computing the rotation and points
# matrix twice
height, width = image_size
rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0])
new_width, new_height = _FT._compute_affine_output_size(rotation_matrix, width, height)
image_size = (new_height, new_width)
return (
convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape),
image_size,
)
def rotate_mask(
mask: torch.Tensor,
angle: float,
expand: bool = False,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = rotate_image_tensor(
mask,
angle=angle,
expand=expand,
interpolation=InterpolationMode.NEAREST,
fill=fill,
center=center,
)
if needs_squeeze:
output = output.squeeze(0)
return output
def rotate(
inpt: features.InputTypeJIT,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, features._Feature):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
else:
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
pad_image_pil = _FP.pad
def pad_image_tensor(
image: torch.Tensor,
padding: Union[int, List[int]],
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> torch.Tensor:
if fill is None:
# This is a JIT workaround
return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
elif isinstance(fill, (int, float)) or len(fill) == 1:
fill_number = fill[0] if isinstance(fill, list) else fill
return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
else:
return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)
def _pad_with_scalar_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
fill: Union[int, float, None],
padding_mode: str = "constant",
) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
if image.numel() > 0:
image = _FT.pad(
img=image.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = _FT._parse_pad_padding(padding)
new_height = height + top + bottom
new_width = width + left + right
return image.view(extra_dims + (num_channels, new_height, new_width))
# TODO: This should be removed once pytorch pad supports non-scalar padding values
def _pad_with_vector_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
fill: List[float],
padding_mode: str = "constant",
) -> torch.Tensor:
if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding)
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).view(-1, 1, 1)
if top > 0:
output[..., :top, :] = fill
if left > 0:
output[..., :, :left] = fill
if bottom > 0:
output[..., -bottom:, :] = fill
if right > 0:
output[..., :, -right:] = fill
return output
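# Illustrative sketch (not part of the original file): a per-channel fill paints each padded
# border with its own value, e.g. a red border around an RGB image. The helper name is
# hypothetical.
def _example_pad_with_vector_fill() -> torch.Tensor:
    img = torch.zeros(3, 4, 4)
    return pad_image_tensor(img, padding=[2], fill=[255.0, 0.0, 0.0])  # output shape (3, 8, 8)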
def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
padding_mode: str = "constant",
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
if fill is None:
fill = 0
if isinstance(fill, list):
raise ValueError("Non-scalar fill value is not supported")
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)
if needs_squeeze:
output = output.squeeze(0)
return output
def pad_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
padding: Union[int, List[int]],
padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if padding_mode not in ["constant"]:
# TODO: add support of other padding modes
raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")
left, right, top, bottom = _parse_pad_padding(padding)
bounding_box = bounding_box.clone()
# this works without conversion since padding only affects xy coordinates
bounding_box[..., 0] += left
bounding_box[..., 1] += top
if format == features.BoundingBoxFormat.XYXY:
bounding_box[..., 2] += left
bounding_box[..., 3] += top
height, width = image_size
height += top + bottom
width += left + right
return bounding_box, (height, width)
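# Illustrative sketch (not part of the original file): padding only shifts box coordinates by
# the left/top padding and grows the reported image size, here from (10, 10) to (16, 14).
# The helper name and values are hypothetical.
def _example_pad_bounding_box() -> torch.Tensor:
    box = torch.tensor([[1.0, 2.0, 5.0, 6.0]])
    padded_box, _ = pad_bounding_box(box, features.BoundingBoxFormat.XYXY, image_size=(10, 10), padding=[2, 3])
    return padded_box  # tensor([[3., 5., 7., 9.]])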
def pad(
inpt: features.InputTypeJIT,
padding: Union[int, List[int]],
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
elif isinstance(inpt, features._Feature):
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
else:
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
crop_image_tensor = _FT.crop
crop_image_pil = _FP.crop
def crop_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
top: int,
left: int,
height: int,
width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
)
# Crop or implicit pad if left and/or top have negative values:
bounding_box[..., 0::2] -= left
bounding_box[..., 1::2] -= top
return (
convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
),
(height, width),
)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(mask, top, left, height, width)
def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, features._Feature):
return inpt.crop(top, left, height, width)
else:
return crop_image_pil(inpt, top, left, height, width)
def perspective_image_tensor(
image: torch.Tensor,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
return _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
@torch.jit.unused
def perspective_image_pil(
image: PIL.Image.Image,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
fill: features.FillTypeJIT = None,
) -> PIL.Image.Image:
return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
def perspective_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
perspective_coeffs: List[float],
) -> torch.Tensor:
if len(perspective_coeffs) != 8:
raise ValueError("Argument perspective_coeffs should have 8 float values")
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device
# perspective_coeffs are computed as endpoint -> start point
# We have to invert perspective_coeffs for bboxes:
# (x, y) - end point and (x_out, y_out) - start point
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
# and we would like to get:
# x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
# y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
# and compute inv_coeffs in terms of coeffs
denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
if denom == 0:
raise RuntimeError(
f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
f"Denominator is zero, denom={denom}"
)
inv_coeffs = [
(perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
(-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
(perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
(-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
(perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
(-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
(-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
]
theta1 = torch.tensor(
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
dtype=dtype,
device=device,
)
theta2 = torch.tensor(
[[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using perspective matrices
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
numer_points = torch.matmul(points, theta1.T)
denom_points = torch.matmul(points, theta2.T)
transformed_points = numer_points / denom_points
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.view(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
# out_bboxes should be of shape [N boxes, 4]
return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
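# Illustrative sketch (not part of the original file): the identity coefficients
# [1, 0, 0, 0, 1, 0, 0, 0] invert to themselves, so boxes pass through unchanged. The helper
# name and the toy box are hypothetical.
def _example_perspective_bounding_box_identity() -> torch.Tensor:
    box = torch.tensor([[1.0, 2.0, 5.0, 6.0]])
    identity_coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
    return perspective_bounding_box(box, features.BoundingBoxFormat.XYXY, identity_coeffs)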
def perspective_mask(
mask: torch.Tensor,
perspective_coeffs: List[float],
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = perspective_image_tensor(
mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST, fill=fill
)
if needs_squeeze:
output = output.squeeze(0)
return output
def perspective(
inpt: features.InputTypeJIT,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None,
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
elif isinstance(inpt, features._Feature):
return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
else:
return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
def elastic_image_tensor(
image: torch.Tensor,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
return _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
@torch.jit.unused
def elastic_image_pil(
image: PIL.Image.Image,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None,
) -> PIL.Image.Image:
t_img = pil_to_tensor(image)
output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
return to_pil_image(output, mode=image.mode)
def elastic_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
displacement: torch.Tensor,
) -> torch.Tensor:
# TODO: add in docstring about approximation we are doing for grid inversion
displacement = displacement.to(bounding_box.device)
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
# Question (vfdev-5): should we rely on the displacement having a valid shape and fetch the image size from it,
# or add an image_size arg and check the displacement shape against it?
image_size = displacement.shape[-3], displacement.shape[-2]
id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid - displacement
# Get points from bboxes
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
# Transform points:
t_size = torch.tensor(image_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
transformed_points = transformed_points.view(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
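# Illustrative sketch (not part of the original file): with a zero displacement field the
# approximate inverse grid is the identity grid, so boxes stay (approximately) where they
# were. The helper name and sizes are hypothetical.
def _example_elastic_bounding_box_identity() -> torch.Tensor:
    box = torch.tensor([[1.0, 2.0, 5.0, 6.0]])
    zero_displacement = torch.zeros(1, 10, 10, 2)
    return elastic_bounding_box(box, features.BoundingBoxFormat.XYXY, zero_displacement)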
def elastic_mask(
mask: torch.Tensor,
displacement: torch.Tensor,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)
if needs_squeeze:
output = output.squeeze(0)
return output
def elastic(
inpt: features.InputTypeJIT,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None,
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, features._Feature):
return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
else:
return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
elastic_transform = elastic
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
return [int(output_size), int(output_size)]
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
return [output_size[0], output_size[0]]
else:
return list(output_size)
def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
return [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2 if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
def _center_crop_compute_crop_anchor(
crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return crop_top, crop_left
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
_, image_height, image_width = get_dimensions_image_tensor(image)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = pad_image_tensor(image, padding_ltrb, fill=0)
_, image_height, image_width = get_dimensions_image_tensor(image)
if crop_width == image_width and crop_height == image_height:
return image
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
_, image_height, image_width = get_dimensions_image_pil(image)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = pad_image_pil(image, padding_ltrb, fill=0)
_, image_height, image_width = get_dimensions_image_pil(image)
if crop_width == image_width and crop_height == image_height:
return image
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
def center_crop_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = center_crop_image_tensor(image=mask, output_size=output_size)
if needs_squeeze:
output = output.squeeze(0)
return output
def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, features._Feature):
return inpt.center_crop(output_size)
else:
return center_crop_image_pil(inpt, output_size)
def resized_crop_image_tensor(
image: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> torch.Tensor:
image = crop_image_tensor(image, top, left, height, width)
return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias)
@torch.jit.unused
def resized_crop_image_pil(
image: PIL.Image.Image,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
image = crop_image_pil(image, top, left, height, width)
return resize_image_pil(image, size, interpolation=interpolation)
def resized_crop_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
top: int,
left: int,
height: int,
width: int,
size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
return resize_bounding_box(bounding_box, (height, width), size)
def resized_crop_mask(
mask: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
) -> torch.Tensor:
mask = crop_mask(mask, top, left, height, width)
return resize_mask(mask, size)
def resized_crop(
inpt: features.InputTypeJIT,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
antialias = False if antialias is None else antialias
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
elif isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
else:
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
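# Illustrative example (a sketch, not part of the original file): `resized_crop` composes a crop with
# a resize, so the output spatial size is `size` regardless of the crop window.
#
#   >>> img = torch.rand(3, 256, 256)
#   >>> resized_crop(img, top=16, left=16, height=128, width=128, size=[224, 224]).shape
#   torch.Size([3, 224, 224])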
def _parse_five_crop_size(size: List[int]) -> List[int]:
if isinstance(size, numbers.Number):
size = [int(size), int(size)]
elif isinstance(size, (tuple, list)) and len(size) == 1:
size = [size[0], size[0]]
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
return size
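# Illustrative example (a sketch, not part of the original file): a scalar or single-element size is
# expanded to a square (height, width) pair.
#
#   >>> _parse_five_crop_size(224)
#   [224, 224]
#   >>> _parse_five_crop_size([224])
#   [224, 224]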
def five_crop_image_tensor(
image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
crop_height, crop_width = _parse_five_crop_size(size)
_, image_height, image_width = get_dimensions_image_tensor(image)
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))
tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
bl = crop_image_tensor(image, image_height - crop_height, 0, crop_height, crop_width)
br = crop_image_tensor(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
center = center_crop_image_tensor(image, [crop_height, crop_width])
return tl, tr, bl, br, center
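# Illustrative example (a sketch, not part of the original file): the five crops are the four corners
# plus the center, each with the requested size.
#
#   >>> crops = five_crop_image_tensor(torch.rand(3, 256, 256), size=[224, 224])
#   >>> len(crops), crops[0].shape
#   (5, torch.Size([3, 224, 224]))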
@torch.jit.unused
def five_crop_image_pil(
image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
crop_height, crop_width = _parse_five_crop_size(size)
_, image_height, image_width = get_dimensions_image_pil(image)
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))
tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
bl = crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
br = crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
center = center_crop_image_pil(image, [crop_height, crop_width])
return tl, tr, bl, br, center
def five_crop(
inpt: features.ImageTypeJIT, size: List[int]
) -> Tuple[
features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
]:
# TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
if isinstance(inpt, torch.Tensor):
output = five_crop_image_tensor(inpt, size)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = tuple(features.Image.new_like(inpt, item) for item in output) # type: ignore[assignment]
return output
else: # isinstance(inpt, PIL.Image.Image):
return five_crop_image_pil(inpt, size)
def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
tl, tr, bl, br, center = five_crop_image_tensor(image, size)
if vertical_flip:
image = vertical_flip_image_tensor(image)
else:
image = horizontal_flip_image_tensor(image)
tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(image, size)
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
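# Illustrative example (a sketch, not part of the original file): ten-crop is five-crop applied to the
# original image and to its flipped version (horizontal by default, vertical if requested).
#
#   >>> crops = ten_crop_image_tensor(torch.rand(3, 256, 256), size=[224, 224])
#   >>> len(crops)
#   10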
@torch.jit.unused
def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
tl, tr, bl, br, center = five_crop_image_pil(image, size)
if vertical_flip:
image = vertical_flip_image_pil(image)
else:
image = horizontal_flip_image_pil(image)
tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(image, size)
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = [features.Image.new_like(inpt, item) for item in output]
return output
else: # isinstance(inpt, PIL.Image.Image):
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
from typing import cast, List, Optional, Tuple
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
get_dimensions_image_tensor = _FT.get_dimensions
get_dimensions_image_pil = _FP.get_dimensions
# TODO: Should this be prefixed with `_`, similar to other methods that are not exposed by `__init__`?
def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]:
if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)):
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, features.Image):
channels = image.num_channels
height, width = image.image_size
else: # isinstance(image, PIL.Image.Image)
channels, height, width = get_dimensions_image_pil(image)
return channels, height, width
# The three functions below are here for BC. Whether we want to have two different kernels and how they and the
# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491
# Given that these kernels should also support boxes, masks, and videos, it is unlikely that their names will stay.
# They will either be deprecated or simply aliased to the new kernels once we have reached consensus on the issue
# detailed above.
def get_dimensions(image: features.ImageTypeJIT) -> List[int]:
return list(get_chw(image))
def get_num_channels(image: features.ImageTypeJIT) -> int:
num_channels, *_ = get_chw(image)
return num_channels
# We changed the name to ensure it can be used not only for images but also for videos. Thus, we just alias it
# without deprecating the old name.
get_image_num_channels = get_num_channels
def get_spatial_size(image: features.ImageTypeJIT) -> List[int]:
_, *size = get_chw(image)
return size
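# Illustrative example (a sketch, not part of the original file): the accessors above are thin wrappers
# around `get_chw`.
#
#   >>> img = torch.rand(3, 480, 640)
#   >>> get_chw(img)
#   (3, 480, 640)
#   >>> get_dimensions(img)
#   [3, 480, 640]
#   >>> get_spatial_size(img)
#   [480, 640]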
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
xyxy = xywh.clone()
xyxy[..., 2:] += xyxy[..., :2]
return xyxy
def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
xywh = xyxy.clone()
xywh[..., 2:] -= xywh[..., :2]
return xywh
def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
x1 = cx - 0.5 * w
y1 = cy - 0.5 * h
x2 = cx + 0.5 * w
y2 = cy + 0.5 * h
return torch.stack((x1, y1, x2, y2), dim=-1).to(cxcywh.dtype)
def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
return torch.stack((cx, cy, w, h), dim=-1).to(xyxy.dtype)
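# Illustrative example (a sketch, not part of the original file): XYXY is the pivot format, so every
# conversion goes through it. For a single box:
#
#   >>> _xywh_to_xyxy(torch.tensor([10.0, 20.0, 30.0, 40.0]))
#   tensor([10., 20., 40., 60.])
#   >>> _xyxy_to_cxcywh(torch.tensor([10.0, 20.0, 40.0, 60.0]))
#   tensor([25., 40., 30., 40.])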
def convert_format_bounding_box(
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, copy: bool = True
) -> torch.Tensor:
if new_format == old_format:
if copy:
return bounding_box.clone()
else:
return bounding_box
if old_format == BoundingBoxFormat.XYWH:
bounding_box = _xywh_to_xyxy(bounding_box)
elif old_format == BoundingBoxFormat.CXCYWH:
bounding_box = _cxcywh_to_xyxy(bounding_box)
if new_format == BoundingBoxFormat.XYWH:
bounding_box = _xyxy_to_xywh(bounding_box)
elif new_format == BoundingBoxFormat.CXCYWH:
bounding_box = _xyxy_to_cxcywh(bounding_box)
return bounding_box
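# Illustrative example (a sketch, not part of the original file): converting XYWH to CXCYWH chains the
# two helpers via the intermediate XYXY representation.
#
#   >>> box = torch.tensor([10.0, 20.0, 30.0, 40.0])  # x, y, w, h
#   >>> convert_format_bounding_box(box, BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH)
#   tensor([25., 40., 30., 40.])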
def clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
    # TODO: (PERF) Clamping could possibly be sped up with a dedicated implementation per bbox format.
    # It is not clear whether those would yield equivalent results.
xyxy_boxes = convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
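# Illustrative example (a sketch, not part of the original file): coordinates are clamped to the image
# bounds, where `image_size` is (height, width).
#
#   >>> box = torch.tensor([-5.0, 10.0, 120.0, 90.0])  # XYXY box, image_size = (height=80, width=100)
#   >>> clamp_bounding_box(box, BoundingBoxFormat.XYXY, image_size=(80, 100))
#   tensor([  0.,  10., 100.,  80.])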
def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return image[..., :-1, :, :], image[..., -1:, :, :]
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
image, alpha = _split_alpha(image)
if not torch.all(alpha == _FT._max_value(alpha.dtype)):
raise RuntimeError(
"Stripping the alpha channel if it contains values other than the max value is not supported."
)
return image
def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
if alpha is None:
shape = list(image.shape)
shape[-3] = 1
alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
return torch.cat((image, alpha), dim=-3)
def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
repeats = [1] * grayscale.ndim
repeats[-3] = 3
return grayscale.repeat(repeats)
_rgb_to_gray = _FT.rgb_to_grayscale
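# Illustrative example (a sketch, not part of the original file): the helpers above only add, strip, or
# repeat channels along dim=-3.
#
#   >>> _gray_to_rgb(torch.rand(1, 1, 32, 32)).shape
#   torch.Size([1, 3, 32, 32])
#   >>> _add_alpha(torch.rand(3, 32, 32)).shape
#   torch.Size([4, 32, 32])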
def convert_color_space_image_tensor(
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
) -> torch.Tensor:
if new_color_space == old_color_space:
if copy:
return image.clone()
else:
return image
if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
return _add_alpha(image)
elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(image)
elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
return _add_alpha(_gray_to_rgb(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
return _strip_alpha(image)
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(_strip_alpha(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
image, alpha = _split_alpha(image)
return _add_alpha(_gray_to_rgb(image), alpha)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(image)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
return _add_alpha(_rgb_to_gray(image))
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
return _add_alpha(image)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(_strip_alpha(image))
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
image, alpha = _split_alpha(image)
return _add_alpha(_rgb_to_gray(image), alpha)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
return _strip_alpha(image)
else:
raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
_COLOR_SPACE_TO_PIL_MODE = {
ColorSpace.GRAY: "L",
ColorSpace.GRAY_ALPHA: "LA",
ColorSpace.RGB: "RGB",
ColorSpace.RGB_ALPHA: "RGBA",
}
@torch.jit.unused
def convert_color_space_image_pil(
image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
) -> PIL.Image.Image:
old_mode = image.mode
try:
new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
except KeyError:
raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
if not copy and image.mode == new_mode:
return image
return image.convert(new_mode)
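# Illustrative example (a sketch, not part of the original file): for PIL images the conversion maps the
# target color space to a PIL mode and defers to `Image.convert`.
#
#   >>> pil_img = PIL.Image.new("RGB", (64, 64))
#   >>> convert_color_space_image_pil(pil_img, ColorSpace.GRAY).mode
#   'L'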
def convert_color_space(
inpt: features.ImageTypeJIT,
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
copy: bool = True,
) -> features.ImageTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)):
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensor images, "
"the `old_color_space=...` parameter needs to be passed."
)
return convert_color_space_image_tensor(
inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
)
elif isinstance(inpt, features.Image):
return inpt.to_color_space(color_space, copy=copy)
else:
return cast(features.ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
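# Illustrative example (a sketch, not part of the original file): plain tensors need an explicit
# `old_color_space`, since it cannot be inferred from the data alone.
#
#   >>> img = torch.rand(3, 64, 64)  # plain RGB tensor
#   >>> convert_color_space(img, ColorSpace.GRAY, old_color_space=ColorSpace.RGB).shape
#   torch.Size([1, 64, 64])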