Unverified commit fbb4cc54 authored by Philip Meier, committed by GitHub

remove torchvision.prototype module and related tests / CI from release branch (#7983)

parent a90e5846
from functools import partial
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
from torchvision.utils import _log_api_usage_once
__all__ = (
"RaftStereo",
"raft_stereo_base",
"raft_stereo_realtime",
"Raft_Stereo_Base_Weights",
"Raft_Stereo_Realtime_Weights",
)
class BaseEncoder(raft.FeatureEncoder):
"""Base encoder for FeatureEncoder and ContextEncoder in which weight may be shared.
See the Raft-Stereo paper section 4.6 on backbone part.
"""
def __init__(
self,
*,
block: Callable[..., nn.Module] = ResidualBlock,
layers: Tuple[int, int, int, int] = (64, 64, 96, 128),
strides: Tuple[int, int, int, int] = (2, 1, 2, 2),
norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d,
):
# We use layers + (256,) because raft.FeatureEncoder requires 5 layers,
# but here we set the last conv layer to the identity
super().__init__(block=block, layers=layers + (256,), strides=strides, norm_layer=norm_layer)
# The base encoder doesn't have the last conv layer of the feature encoder
self.conv = nn.Identity()
self.output_dim = layers[3]
num_downsampling = sum([x - 1 for x in strides])
self.downsampling_ratio = 2 ** (num_downsampling)
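# For illustration: with the default strides (2, 1, 2, 2), num_downsampling = 1 + 0 + 1 + 1 = 3,
# so downsampling_ratio == 2 ** 3 == 8, i.e. the base encoder returns feature maps at 1/8 of the
# input resolution.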
class FeatureEncoder(nn.Module):
"""Feature Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Context Encoder.
The FeatureEncoder takes concatenation of left and right image as input. It produces feature embedding that later
will be used to construct correlation volume.
"""
def __init__(
self,
base_encoder: BaseEncoder,
output_dim: int = 256,
shared_base: bool = False,
block: Callable[..., nn.Module] = ResidualBlock,
):
super().__init__()
self.base_encoder = base_encoder
self.base_downsampling_ratio = base_encoder.downsampling_ratio
base_dim = base_encoder.output_dim
if not shared_base:
self.residual_block: nn.Module = nn.Identity()
self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=1)
else:
# If we share the base encoder weights between the Feature and Context Encoders,
# we need to add a residual block with InstanceNorm2d and change the kernel size of the conv layer,
# see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L35-L37
self.residual_block = block(base_dim, base_dim, norm_layer=nn.InstanceNorm2d, stride=1)
self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=3, padding=1)
def forward(self, x: Tensor) -> Tensor:
x = self.base_encoder(x)
x = self.residual_block(x)
x = self.conv(x)
return x
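# For illustration, a rough usage sketch (shapes assume the default BaseEncoder strides):
#
#   base = BaseEncoder(norm_layer=nn.InstanceNorm2d)
#   encoder = FeatureEncoder(base, output_dim=256, shared_base=False)
#   images = torch.randn(2, 3, 256, 512)  # left and right images concatenated along the batch dim
#   fmaps = encoder(images)  # (2, 256, 32, 64): H and W reduced by base.downsampling_ratio == 8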
class MultiLevelContextEncoder(nn.Module):
"""Context Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Feature Encoder.
The ContextEncoder takes left image as input, and it outputs concatenated hidden_states and contexts.
In Raft-Stereo we have multi level GRUs and this context encoder will also multi outputs (list of Tensor)
that correspond to each GRUs.
Take note that the length of "out_with_blocks" parameter represent the number of GRU's level.
args:
base_encoder (nn.Module): The base encoder part that can have a shared weight with feature_encoder's
base_encoder because they have same architecture.
out_with_blocks (List[bool]): The length represent the number of GRU's level (length of output), and
if the element is True then the output layer on that position will have additional block
output_dim (int): The dimension of output on each level (default: 256)
block (Callable[..., nn.Module]): The type of basic block used for downsampling and output layer
(default: ResidualBlock)
"""
def __init__(
self,
base_encoder: nn.Module,
out_with_blocks: List[bool],
output_dim: int = 256,
block: Callable[..., nn.Module] = ResidualBlock,
):
super().__init__()
self.num_level = len(out_with_blocks)
self.base_encoder = base_encoder
self.base_downsampling_ratio = base_encoder.downsampling_ratio
base_dim = base_encoder.output_dim
self.downsample_and_out_layers = nn.ModuleList(
[
nn.ModuleDict(
{
"downsampler": self._make_downsampler(block, base_dim, base_dim) if i > 0 else nn.Identity(),
"out_hidden_state": self._make_out_layer(
base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
),
"out_context": self._make_out_layer(
base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
),
}
)
for i in range(self.num_level)
]
)
def _make_out_layer(self, in_channels, out_channels, with_block=True, block=ResidualBlock):
layers = []
if with_block:
layers.append(block(in_channels, in_channels, norm_layer=nn.BatchNorm2d, stride=1))
layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
return nn.Sequential(*layers)
def _make_downsampler(self, block, in_channels, out_channels):
block1 = block(in_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=2)
block2 = block(out_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=1)
return nn.Sequential(block1, block2)
def forward(self, x: Tensor) -> List[Tensor]:
x = self.base_encoder(x)
outs = []
for layer_dict in self.downsample_and_out_layers:
x = layer_dict["downsampler"](x)
outs.append(torch.cat([layer_dict["out_hidden_state"](x), layer_dict["out_context"](x)], dim=1))
return outs
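# For illustration, a shape sketch under the default BaseEncoder strides (assumed sizes):
# with out_with_blocks=[True, True, False] and output_dim=256 on a (1, 3, 256, 512) left image,
# forward() returns a list of 3 tensors, each the concatenation of a 128-channel hidden state
# and a 128-channel context:
#   level 0: (1, 256, 32, 64)   i.e. H/8  x W/8
#   level 1: (1, 256, 16, 32)   i.e. H/16 x W/16
#   level 2: (1, 256,  8, 16)   i.e. H/32 x W/32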
class ConvGRU(raft.ConvGRU):
"""Convolutional Gru unit."""
# Modified from raft.ConvGRU to accept pre-convolved contexts,
# see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/update.py#L23
def forward(self, h: Tensor, x: Tensor, context: List[Tensor]) -> Tensor: # type: ignore[override]
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx) + context[0])
r = torch.sigmoid(self.convr(hx) + context[1])
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)) + context[2])
h = (1 - z) * h + z * q
return h
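# Note: `context` here is a sequence of three tensors with ``hidden_size`` channels each
# (the pre-convolved context produced by RaftStereo.context_convs), added as biases to the
# update gate (z), reset gate (r), and candidate state (q) respectively.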
class MultiLevelUpdateBlock(nn.Module):
"""The update block which contains the motion encoder and grus
It must expose a ``hidden_dims`` attribute which is the hidden dimension size of its gru blocks
"""
def __init__(self, *, motion_encoder: MotionEncoder, hidden_dims: List[int]):
super().__init__()
self.motion_encoder = motion_encoder
# The GRU input size is the size of the previous level's hidden_dim plus the next level's hidden_dim.
# If this is the first GRU, then we replace the previous level with the motion_encoder output channels.
# For the last GRU, we don't add the next level's hidden_dim.
gru_input_dims = []
for i in range(len(hidden_dims)):
input_dim = hidden_dims[i - 1] if i > 0 else motion_encoder.out_channels
if i < len(hidden_dims) - 1:
input_dim += hidden_dims[i + 1]
gru_input_dims.append(input_dim)
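# Worked example for illustration: with hidden_dims=[128, 128, 128] and
# motion_encoder.out_channels == 128, the loop above yields gru_input_dims == [256, 256, 128]:
#   i=0: motion encoder output (128) + hidden_dims[1] (128) = 256
#   i=1: hidden_dims[0] (128) + hidden_dims[2] (128)        = 256
#   i=2: hidden_dims[1] (128), no next level to add         = 128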
self.grus = nn.ModuleList(
[
ConvGRU(input_size=gru_input_dims[i], hidden_size=hidden_dims[i], kernel_size=3, padding=1)
# Ideally we should reverse the direction during forward to use the GRU with the smallest resolution
# first. However, currently there is no way to reverse a ModuleList that is jit-script compatible,
# hence we reverse the ordering of self.grus in the constructor instead,
# see: https://github.com/pytorch/pytorch/issues/31772
for i in reversed(list(range(len(hidden_dims))))
]
)
self.hidden_dims = hidden_dims
def forward(
self,
hidden_states: List[Tensor],
contexts: List[List[Tensor]],
corr_features: Tensor,
disparity: Tensor,
level_processed: List[bool],
) -> List[Tensor]:
# We call it reverse_i because it has a reversed ordering compared to hidden_states;
# see self.grus in the constructor for more detail
for reverse_i, gru in enumerate(self.grus):
i = len(self.grus) - 1 - reverse_i
if level_processed[i]:
# The GRU input is the concatenation of the 2x downsampled hidden state from the larger level
# (or the motion features if there is no larger level) with the 2x upsampled hidden state from
# the smaller level (or nothing if it does not exist).
if i == 0:
features = self.motion_encoder(disparity, corr_features)
else:
# 2x downsampled features from larger hidden states
features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1)
if i < len(self.grus) - 1:
# Concat with 2x upsampled features from smaller hidden states
_, _, h, w = hidden_states[i + 1].shape
features = torch.cat(
[
features,
F.interpolate(
hidden_states[i + 1], size=(2 * h, 2 * w), mode="bilinear", align_corners=True
),
],
dim=1,
)
hidden_states[i] = gru(hidden_states[i], features, contexts[i])
# NOTE: For the slow-fast GRU, we don't always want to calculate the delta disparity for every call to the UpdateBlock.
# Hence we move the delta disparity calculation to the RAFT-Stereo main forward
return hidden_states
class MaskPredictor(raft.MaskPredictor):
"""Mask predictor to be used when upsampling the predicted disparity."""
# We add out_channels compared to raft.MaskPredictor
def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25):
super(raft.MaskPredictor, self).__init__()
self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
self.conv = nn.Conv2d(hidden_size, out_channels, kernel_size=1, padding=0)
self.multiplier = multiplier
class CorrPyramid1d(nn.Module):
"""Row-wise correlation pyramid.
Create a row-wise correlation pyramid with ``num_levels`` level from the outputs of the feature encoder,
this correlation pyramid will later be used as index to create correlation features using CorrBlock1d.
"""
def __init__(self, num_levels: int = 4):
super().__init__()
self.num_levels = num_levels
def forward(self, fmap1: Tensor, fmap2: Tensor) -> List[Tensor]:
"""Build the correlation pyramid from two feature maps.
The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2) on the same row.
The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions
to build the correlation pyramid.
"""
torch._assert(
fmap1.shape == fmap2.shape,
f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)",
)
batch_size, num_channels, h, w = fmap1.shape
fmap1 = fmap1.view(batch_size, num_channels, h, w)
fmap2 = fmap2.view(batch_size, num_channels, h, w)
corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
corr = corr.view(batch_size, h, w, 1, w)
corr_volume = corr / torch.sqrt(torch.tensor(num_channels, device=corr.device))
corr_volume = corr_volume.reshape(batch_size * h * w, 1, 1, w)
corr_pyramid = [corr_volume]
for _ in range(self.num_levels - 1):
corr_volume = F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2))
corr_pyramid.append(corr_volume)
return corr_pyramid
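# For illustration, a shape sketch (assumed sizes): with fmap1 and fmap2 of shape (1, 256, 32, 64)
# and num_levels=4, the einsum above yields a (1, 32, 64, 64) row-wise correlation volume, which is
# scaled by 1/sqrt(256), reshaped to (1 * 32 * 64, 1, 1, 64), and then average-pooled along the last
# dimension, so the pyramid levels have widths 64, 32, 16 and 8.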
class CorrBlock1d(nn.Module):
"""The row-wise correlation block.
Use indexes from correlation pyramid to create correlation features.
The "indexing" of a given centroid pixel x' is done by concatenating its surrounding row neighbours
within radius
"""
def __init__(self, *, num_levels: int = 4, radius: int = 4):
super().__init__()
self.radius = radius
self.out_channels = num_levels * (2 * radius + 1)
def forward(self, centroids_coords: Tensor, corr_pyramid: List[Tensor]) -> Tensor:
"""Return correlation features by indexing from the pyramid."""
neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels
di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, device=centroids_coords.device)
di = di.view(1, 1, neighborhood_side_len, 1).to(centroids_coords.device)
batch_size, _, h, w = centroids_coords.shape # _ = 2 but we only use the first one
# We only consider 1d and take the first dim only
centroids_coords = centroids_coords[:, :1].permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 1)
indexed_pyramid = []
for corr_volume in corr_pyramid:
x0 = centroids_coords + di # end shape is (batch_size * h * w, 1, side_len, 1)
y0 = torch.zeros_like(x0)
sampling_coords = torch.cat([x0, y0], dim=-1)
indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
batch_size, h, w, -1
)
indexed_pyramid.append(indexed_corr_volume)
centroids_coords = centroids_coords / 2
corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()
expected_output_shape = (batch_size, self.out_channels, h, w)
torch._assert(
corr_features.shape == expected_output_shape,
f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}",
)
return corr_features
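# For illustration, a shape sketch (assumed sizes): with num_levels=4 and radius=4, out_channels is
# 4 * (2 * 4 + 1) = 36. Given centroids_coords of shape (B, 2, H, W) and the pyramid built by
# CorrPyramid1d from (B, C, H, W) feature maps, forward() returns correlation features of shape
# (B, 36, H, W).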
class RaftStereo(nn.Module):
def __init__(
self,
*,
feature_encoder: FeatureEncoder,
context_encoder: MultiLevelContextEncoder,
corr_pyramid: CorrPyramid1d,
corr_block: CorrBlock1d,
update_block: MultiLevelUpdateBlock,
disparity_head: nn.Module,
mask_predictor: Optional[nn.Module] = None,
slow_fast: bool = False,
):
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
Args:
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``.
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``.
It has multi-level output and each level will have 2 parts:
- one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
- one part will be used to initialize the hidden state of the recurrent unit of
the ``update_block``
corr_pyramid (CorrPyramid1d): Module to build the correlation pyramid from feature encoder output
corr_block (CorrBlock1d): The correlation block, which uses the correlation pyramid indexes
to create correlation features. It takes the coordinate of the centroid pixel and correlation pyramid
as input and returns the correlation features.
It must expose an ``out_channels`` attribute.
update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit.
It takes as input the hidden state of its recurrent unit, the context, the correlation
features, and the current predicted disparity. It outputs an updated hidden state
disparity_head (nn.Module): The disparity head block converts the hidden state into changes in disparity.
mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
If ``None`` (default), the flow is upsampled using interpolation.
slow_fast (bool): A boolean that specifies whether we should use the slow-fast GRU or not. See section 3.4
of the RAFT-Stereo paper for more detail.
"""
super().__init__()
_log_api_usage_once(self)
# This indicates that the disparity output will only have 1 channel (representing the horizontal axis).
# We need this because some stereo matching models like CREStereo might have 2 channels in the output
self.output_channels = 1
self.feature_encoder = feature_encoder
self.context_encoder = context_encoder
self.base_downsampling_ratio = feature_encoder.base_downsampling_ratio
self.num_level = self.context_encoder.num_level
self.corr_pyramid = corr_pyramid
self.corr_block = corr_block
self.update_block = update_block
self.disparity_head = disparity_head
self.mask_predictor = mask_predictor
hidden_dims = self.update_block.hidden_dims
# Following the original implementation, we do a pre-convolution on the contexts
# See: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L32
self.context_convs = nn.ModuleList(
[nn.Conv2d(hidden_dims[i], hidden_dims[i] * 3, kernel_size=3, padding=1) for i in range(self.num_level)]
)
self.slow_fast = slow_fast
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12
) -> List[Tensor]:
"""
Return disparity predictions on every iteration as a list of Tensor.
Args:
left_image (Tensor): The input left image with layout B, C, H, W
right_image (Tensor): The input right image with layout B, C, H, W
flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
num_iters (int): Number of update block iterations at the largest resolution. Default: 12
"""
batch_size, _, h, w = left_image.shape
torch._assert(
(h, w) == right_image.shape[-2:],
f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}",
)
torch._assert(
(h % self.base_downsampling_ratio == 0 and w % self.base_downsampling_ratio == 0),
f"input image H and W should be divisible by {self.base_downsampling_ratio}, instead got H={h} and W={w}",
)
fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0))
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
torch._assert(
fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio),
f"The feature encoder should downsample H and W by {self.base_downsampling_ratio}",
)
corr_pyramid = self.corr_pyramid(fmap1, fmap2)
# Multi-level contexts
context_outs = self.context_encoder(left_image)
hidden_dims = self.update_block.hidden_dims
context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))]
hidden_states: List[Tensor] = []
contexts: List[List[Tensor]] = []
for i, context_conv in enumerate(self.context_convs):
# As in the original paper, the actual output of the context encoder is split into 2 parts:
# - one part is used to initialize the hidden state of the recurrent units of the update block
# - the rest is the "actual" context.
hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1)
hidden_states.append(torch.tanh(hidden_state))
contexts.append(
# mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with
# `List[int]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by
# JIT and thus we have to keep the wrong annotation here and silence mypy.
torch.split( # type: ignore[arg-type]
context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1
)
)
_, Cf, Hf, Wf = fmap1.shape
coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
# We use flow_init for cascade inference
if flow_init is not None:
coords1 = coords1 + flow_init
disparity_predictions = []
for _ in range(num_iters):
coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid)
disparity = coords1 - coords0
if self.slow_fast:
# Using the slow-fast GRU (see paper section 3.4). The lower resolutions are processed more often
for i in range(1, self.num_level):
# We only process the smallest i levels
level_processed = [False] * (self.num_level - i) + [True] * i
hidden_states = self.update_block(
hidden_states, contexts, corr_features, disparity, level_processed=level_processed
)
hidden_states = self.update_block(
hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level
)
# Take the hidden_state at the largest resolution to get the disparity
hidden_state = hidden_states[0]
delta_disparity = self.disparity_head(hidden_state)
# In stereo mode, project the disparity onto the epipolar line by zeroing out the vertical component
delta_disparity[:, 1] = 0.0
coords1 = coords1 + delta_disparity
up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
upsampled_disparity = upsample_flow(
(coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio
)
disparity_predictions.append(upsampled_disparity[:, :1])
return disparity_predictions
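# For illustration, a sketch of the expected shapes: for (B, 3, H, W) inputs, the list returned above
# contains ``num_iters`` tensors, each of shape (B, 1, H, W), since the per-iteration disparity is
# upsampled back to the input resolution and only the horizontal channel is kept.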
def _raft_stereo(
*,
weights: Optional[WeightsEnum],
progress: bool,
shared_encoder_weight: bool,
# Feature encoder
feature_encoder_layers: Tuple[int, int, int, int, int],
feature_encoder_strides: Tuple[int, int, int, int],
feature_encoder_block: Callable[..., nn.Module],
# Context encoder
context_encoder_layers: Tuple[int, int, int, int, int],
context_encoder_strides: Tuple[int, int, int, int],
# if an element of the context_encoder's `out_with_blocks` param is True, then
# the output at that level will have an additional `context_encoder_block` layer
context_encoder_out_with_blocks: List[bool],
context_encoder_block: Callable[..., nn.Module],
# Correlation block
corr_num_levels: int,
corr_radius: int,
# Motion encoder
motion_encoder_corr_layers: Tuple[int, int],
motion_encoder_flow_layers: Tuple[int, int],
motion_encoder_out_channels: int,
# Update block
update_block_hidden_dims: List[int],
# Flow Head
flow_head_hidden_size: int,
# Mask predictor
mask_predictor_hidden_size: int,
use_mask_predictor: bool,
slow_fast: bool,
**kwargs,
):
if len(context_encoder_out_with_blocks) != len(update_block_hidden_dims):
raise ValueError(
"Length of context_encoder_out_with_blocks and update_block_hidden_dims must be the same"
+ "because both of them represent the number of GRUs level"
)
if shared_encoder_weight:
if (
feature_encoder_layers[:-1] != context_encoder_layers[:-1]
or feature_encoder_strides != context_encoder_strides
):
raise ValueError(
"If shared_encoder_weight is True, then the feature_encoder_layers[:-1]"
+ " and feature_encoder_strides must be the same with context_encoder_layers[:-1] and context_encoder_strides!"
)
base_encoder = kwargs.pop("base_encoder", None) or BaseEncoder(
block=context_encoder_block,
layers=context_encoder_layers[:-1],
strides=context_encoder_strides,
norm_layer=nn.BatchNorm2d,
)
feature_base_encoder = base_encoder
context_base_encoder = base_encoder
else:
feature_base_encoder = BaseEncoder(
block=feature_encoder_block,
layers=feature_encoder_layers[:-1],
strides=feature_encoder_strides,
norm_layer=nn.InstanceNorm2d,
)
context_base_encoder = BaseEncoder(
block=context_encoder_block,
layers=context_encoder_layers[:-1],
strides=context_encoder_strides,
norm_layer=nn.BatchNorm2d,
)
feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
feature_base_encoder,
output_dim=feature_encoder_layers[-1],
shared_base=shared_encoder_weight,
block=feature_encoder_block,
)
context_encoder = kwargs.pop("context_encoder", None) or MultiLevelContextEncoder(
context_base_encoder,
out_with_blocks=context_encoder_out_with_blocks,
output_dim=context_encoder_layers[-1],
block=context_encoder_block,
)
feature_downsampling_ratio = feature_encoder.base_downsampling_ratio
corr_pyramid = kwargs.pop("corr_pyramid", None) or CorrPyramid1d(num_levels=corr_num_levels)
corr_block = kwargs.pop("corr_block", None) or CorrBlock1d(num_levels=corr_num_levels, radius=corr_radius)
motion_encoder = kwargs.pop("motion_encoder", None) or MotionEncoder(
in_channels_corr=corr_block.out_channels,
corr_layers=motion_encoder_corr_layers,
flow_layers=motion_encoder_flow_layers,
out_channels=motion_encoder_out_channels,
)
update_block = kwargs.pop("update_block", None) or MultiLevelUpdateBlock(
motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims
)
# We use the largest scale hidden_dims of update_block to get the predicted disparity
disparity_head = kwargs.pop("disparity_head", None) or FlowHead(
in_channels=update_block_hidden_dims[0],
hidden_size=flow_head_hidden_size,
)
mask_predictor = kwargs.pop("mask_predictor", None)
if use_mask_predictor:
mask_predictor = MaskPredictor(
in_channels=update_block.hidden_dims[0],
hidden_size=mask_predictor_hidden_size,
out_channels=9 * feature_downsampling_ratio * feature_downsampling_ratio,
)
else:
mask_predictor = None
model = RaftStereo(
feature_encoder=feature_encoder,
context_encoder=context_encoder,
corr_pyramid=corr_pyramid,
corr_block=corr_block,
update_block=update_block,
disparity_head=disparity_head,
mask_predictor=mask_predictor,
slow_fast=slow_fast,
**kwargs, # not really needed, all params should be consumed by now
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
class Raft_Stereo_Realtime_Weights(WeightsEnum):
SCENEFLOW_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 8077152,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
"Kitty2015": {
"3px": 0.9409,
}
},
},
)
DEFAULT = SCENEFLOW_V1
class Raft_Stereo_Base_Weights(WeightsEnum):
SCENEFLOW_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 11116176,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
# Using standard metrics for each dataset
"Kitty2015": {
# Ratio of pixels with difference less than 3px from ground truth
"3px": 0.9426,
},
# For middlebury, ratio of pixels with difference less than 2px from ground truth
# on full, half, and quarter image resolution
"Middlebury2014-val-full": {
"2px": 0.8167,
},
"Middlebury2014-val-half": {
"2px": 0.8741,
},
"Middlebury2014-val-quarter": {
"2px": 0.9064,
},
"ETH3D-val": {
# Ratio of pixels with difference less than 1px from ground truth
"1px": 0.9672,
},
},
},
)
MIDDLEBURY_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 11116176,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
"Middlebury-test": {
"mae": 1.27,
"1px": 0.9063,
"2px": 0.9526,
"5px": 0.9725,
}
},
},
)
ETH3D_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 11116176,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
"ETH3D-test": {
"mae": 0.18,
"1px": 0.9756,
"2px": 0.9956,
}
},
},
)
DEFAULT = MIDDLEBURY_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_realtime(
*, weights: Optional[Raft_Stereo_Realtime_Weights] = None, progress=True, **kwargs
) -> RaftStereo:
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
This is the realtime variant of the RAFT-Stereo model that is described in section 4.7 of the paper.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights
:members:
"""
weights = Raft_Stereo_Realtime_Weights.verify(weights)
return _raft_stereo(
weights=weights,
progress=progress,
shared_encoder_weight=True,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(2, 1, 2, 2),
feature_encoder_block=ResidualBlock,
# Context encoder
context_encoder_layers=(64, 64, 96, 128, 256),
context_encoder_strides=(2, 1, 2, 2),
context_encoder_out_with_blocks=[True, True],
context_encoder_block=ResidualBlock,
# Correlation block
corr_num_levels=4,
corr_radius=4,
# Motion encoder
motion_encoder_corr_layers=(64, 64),
motion_encoder_flow_layers=(64, 64),
motion_encoder_out_channels=128,
# Update block
update_block_hidden_dims=[128, 128],
# Flow head
flow_head_hidden_size=256,
# Mask predictor
mask_predictor_hidden_size=256,
use_mask_predictor=True,
slow_fast=True,
**kwargs,
)
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_base(*, weights: Optional[Raft_Stereo_Base_Weights] = None, progress=True, **kwargs) -> RaftStereo:
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights
:members:
"""
weights = Raft_Stereo_Base_Weights.verify(weights)
return _raft_stereo(
weights=weights,
progress=progress,
shared_encoder_weight=False,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(1, 1, 2, 2),
feature_encoder_block=ResidualBlock,
# Context encoder
context_encoder_layers=(64, 64, 96, 128, 256),
context_encoder_strides=(1, 1, 2, 2),
context_encoder_out_with_blocks=[True, True, False],
context_encoder_block=ResidualBlock,
# Correlation block
corr_num_levels=4,
corr_radius=4,
# Motion encoder
motion_encoder_corr_layers=(64, 64),
motion_encoder_flow_layers=(64, 64),
motion_encoder_out_channels=128,
# Update block
update_block_hidden_dims=[128, 128, 128],
# Flow head
flow_head_hidden_size=256,
# Mask predictor
mask_predictor_hidden_size=256,
use_mask_predictor=True,
slow_fast=False,
**kwargs,
)
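# For illustration, a rough inference sketch (hypothetical inputs, not a full tutorial):
#
#   model = raft_stereo_base(weights=Raft_Stereo_Base_Weights.DEFAULT).eval()
#   left = torch.rand(1, 3, 256, 512)   # H and W must be divisible by the encoder's downsampling ratio
#   right = torch.rand(1, 3, 256, 512)
#   with torch.inference_mode():
#       disparity_predictions = model(left, right)
#   disparity = disparity_predictions[-1]  # (1, 1, 256, 512); the last iteration is the final estimate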
from ._presets import StereoMatching # usort: skip
from ._augment import SimpleCopyPaste
from ._geometry import FixedSizeCrop
from ._misc import PermuteDimensions, TransposeDimensions
from ._type_conversion import LabelToOneHot
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.ops import masks_to_boxes
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _check_interpolation
class SimpleCopyPaste(Transform):
def __init__(
self,
blending: bool = True,
resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.resize_interpolation = _check_interpolation(resize_interpolation)
self.blending = blending
self.antialias = antialias
def _copy_paste(
self,
image: Union[torch.Tensor, tv_tensors.Image],
target: Dict[str, Any],
paste_image: Union[torch.Tensor, tv_tensors.Image],
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[torch.Tensor, Dict[str, Any]]:
paste_masks = tv_tensors.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
paste_boxes = tv_tensors.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
paste_labels = tv_tensors.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])
masks = target["masks"]
# We resize source and paste data if they have different sizes.
# This is something we introduce here that differs from the TF implementation,
# since originally the algorithm works on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = cast(List[int], image.shape[-2:])
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
paste_masks = F.resize(paste_masks, size=size1)
paste_boxes = F.resize(paste_boxes, size=size1)
paste_alpha_mask = paste_masks.sum(dim=0) > 0
if blending:
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
# Copy-paste images:
image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
# Copy-paste masks:
masks = masks * inverse_paste_alpha_mask
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]
# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}
out_target["masks"] = torch.cat([masks, paste_masks])
# Copy-paste boxes and labels
bbox_format = target["boxes"].format
xyxy_boxes = masks_to_boxes(masks)
# masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be exclusive,
# so we need to add +1 to x2y2.
# There is a similar +1 in other reference implementations:
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])
# Check for degenerated boxes and remove them
boxes = F.convert_bounding_box_format(
out_target["boxes"], old_format=bbox_format, new_format=tv_tensors.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)
out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]
return image, out_target
def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[Union[torch.Tensor, tv_tensors.Image]], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from the unstructured input
# as List[Image], List[BoundingBoxes], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, tv_tensors.Image) or is_pure_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image(obj))
elif isinstance(obj, tv_tensors.BoundingBoxes):
bboxes.append(obj)
elif isinstance(obj, tv_tensors.Mask):
masks.append(obj)
elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
labels.append(obj)
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxeses, Masks and Labels or OneHotLabels."
)
targets = []
for bbox, mask, label in zip(bboxes, masks, labels):
targets.append({"boxes": bbox, "masks": mask, "labels": label})
return images, targets
def _insert_outputs(
self,
flat_sample: List[Any],
output_images: List[torch.Tensor],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, tv_tensors.Image):
flat_sample[i] = tv_tensors.wrap(output_images[c0], like=obj)
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_pil_image(output_images[c0])
c0 += 1
elif is_pure_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, tv_tensors.BoundingBoxes):
flat_sample[i] = tv_tensors.wrap(output_targets[c1]["boxes"], like=obj)
c1 += 1
elif isinstance(obj, tv_tensors.Mask):
flat_sample[i] = tv_tensors.wrap(output_targets[c2]["masks"], like=obj)
c2 += 1
elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
flat_sample[i] = tv_tensors.wrap(output_targets[c3]["labels"], like=obj)
c3 += 1
def forward(self, *inputs: Any) -> Any:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
images, targets = self._extract_image_targets(flat_inputs)
# images = [t1, t2, ..., tN]
# Let's define paste_images as a shifted list of the input images
# paste_images = [t2, t3, ..., tN, t1]
# FYI: in TF they mix data at the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]
output_images, output_targets = [], []
for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
# Random paste targets selection:
num_masks = len(paste_target["masks"])
if num_masks < 1:
# Such a degenerate case with num_masks=0 can happen with LSJ.
# Let's just return (image, target)
output_image, output_target = image, target
else:
random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
random_selection = torch.unique(random_selection)
output_image, output_target = self._copy_paste(
image,
target,
paste_image,
paste_target,
random_selection=random_selection,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
antialias=self.antialias,
)
output_images.append(output_image)
output_targets.append(output_target)
# Insert updated images and targets into input flat_sample
self._insert_outputs(flat_inputs, output_images, output_targets)
return tree_unflatten(flat_inputs, spec)
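# For illustration, a rough usage sketch (the exact sample layout is an assumption; any pytree that
# contains matching lists of images, BoundingBoxes, Masks and prototype Label tensors works):
#
#   transform = SimpleCopyPaste(blending=True)
#   batch = [
#       (image_a, dict(boxes=boxes_a, masks=masks_a, labels=labels_a)),
#       (image_b, dict(boxes=boxes_b, masks=masks_b, labels=labels_b)),
#   ]
#   out = transform(batch)  # objects from each image are pasted onto the next image in the batch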
from typing import Any, Dict, List, Optional, Sequence, Type, Union
import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.prototype.tv_tensors import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import (
_FillType,
_get_fill,
_setup_fill_arg,
_setup_size,
get_bounding_boxes,
has_any,
is_pure_tensor,
query_size,
)
class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not has_any(
flat_inputs,
PIL.Image.Image,
tv_tensors.Image,
is_pure_tensor,
tv_tensors.Video,
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
)
if has_any(flat_inputs, tv_tensors.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel):
raise TypeError(
f"If a BoundingBoxes is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_size(flat_inputs)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
needs_crop = new_height != height or new_width != width
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)
r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)
bounding_boxes: Optional[torch.Tensor]
try:
bounding_boxes = get_bounding_boxes(flat_inputs)
except ValueError:
bounding_boxes = None
if needs_crop and bounding_boxes is not None:
format = bounding_boxes.format
bounding_boxes, canvas_size = F.crop_bounding_boxes(
bounding_boxes.as_subclass(torch.Tensor),
format=format,
top=top,
left=left,
height=new_height,
width=new_width,
)
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
height_and_width = F.convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYWH
)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
is_valid = None
pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
needs_pad = pad_bottom != 0 or pad_right != 0
return dict(
needs_crop=needs_crop,
top=top,
left=left,
height=new_height,
width=new_width,
is_valid=is_valid,
padding=[0, 0, pad_right, pad_bottom],
needs_pad=needs_pad,
)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_crop"]:
inpt = self._call_kernel(
F.crop,
inpt,
top=params["top"],
left=params["left"],
height=params["height"],
width=params["width"],
)
if params["is_valid"] is not None:
if isinstance(inpt, (Label, OneHotLabel, tv_tensors.Mask)):
inpt = tv_tensors.wrap(inpt[params["is_valid"]], like=inpt)
elif isinstance(inpt, tv_tensors.BoundingBoxes):
inpt = tv_tensors.wrap(
F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
like=inpt,
)
if params["needs_pad"]:
fill = _get_fill(self._fill, type(inpt))
inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
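# For illustration, a rough usage sketch (hypothetical inputs, with labels being prototype Label tensors):
# smaller samples are padded up to the requested size and larger ones are randomly cropped,
# dropping annotations that become degenerate after the crop:
#
#   crop = FixedSizeCrop(size=(512, 512), fill=0, padding_mode="constant")
#   image, target = crop(image, dict(boxes=boxes, masks=masks, labels=labels))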
import functools
import warnings
from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2._utils import is_pure_tensor
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird-looking construct only exists because `lambda`s cannot be serialized by pickle.
# If that were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
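# For illustration: passing a plain sequence to the transforms below turns it into a mapping that
# returns the same value for every type, e.g.
#
#   dims = _get_defaultdict((0, 2, 3, 1))
#   dims[tv_tensors.Image]  # (0, 2, 3, 1)
#   dims[tv_tensors.Video]  # (0, 2, 3, 1), same default for any key
#
# and, unlike ``defaultdict(lambda: (0, 2, 3, 1))``, it can still be pickled.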
class PermuteDimensions(Transform):
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dims = dims
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
return inpt.as_subclass(torch.Tensor)
return inpt.permute(*dims)
class TransposeDimensions(Transform):
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dims = dims
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
return inpt.as_subclass(torch.Tensor)
return inpt.transpose(*dims)
"""
This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
from torch import Tensor
from torchvision.transforms.v2 import functional as F, InterpolationMode
from torchvision.transforms.v2.functional._geometry import _check_interpolation
__all__ = ["StereoMatching"]
class StereoMatching(torch.nn.Module):
def __init__(
self,
*,
use_gray_scale: bool = False,
resize_size: Optional[Tuple[int, ...]],
mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
std: Tuple[float, ...] = (0.5, 0.5, 0.5),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
# pacify mypy
self.resize_size: Union[None, List]
if resize_size is not None:
self.resize_size = list(resize_size)
else:
self.resize_size = None
self.mean = list(mean)
self.std = list(std)
self.interpolation = _check_interpolation(interpolation)
self.use_gray_scale = use_gray_scale
def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
def _process_image(img: PIL.Image.Image) -> Tensor:
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
if self.resize_size is not None:
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the stereo models with antialias=True?
img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=False)
if self.use_gray_scale is True:
img = F.rgb_to_grayscale(img)
img = F.convert_image_dtype(img, torch.float)
img = F.normalize(img, mean=self.mean, std=self.std)
img = img.contiguous()
return img
left_image = _process_image(left_image)
right_image = _process_image(right_image)
return left_image, right_image
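# For illustration, a rough usage sketch (hypothetical inputs):
#
#   preset = StereoMatching(resize_size=(224, 224))
#   left = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
#   right = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
#   left, right = preset(left, right)  # float tensors of shape (3, 224, 224), rescaled and normalized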
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
)
from typing import Any, Dict
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import Transform
class LabelToOneHot(Transform):
_transformed_types = (proto_tv_tensors.Label,)
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel:
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return proto_tv_tensors.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str:
if self.num_categories == -1:
return ""
return f"num_categories={self.num_categories}"
from ._label import Label, OneHotLabel
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch
from torch.utils._pytree import tree_map
from torchvision.tv_tensors._tv_tensor import TVTensor
L = TypeVar("L", bound="_LabelBase")
class _LabelBase(TVTensor):
categories: Optional[Sequence[str]]
@classmethod
def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
label_base = tensor.as_subclass(cls)
label_base.categories = categories
return label_base
def __new__(
cls: Type[L],
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
) -> L:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor, categories=categories)
@classmethod
def from_category(
cls: Type[L],
category: str,
*,
categories: Sequence[str],
**kwargs: Any,
) -> L:
return cls(categories.index(category), categories=categories, **kwargs)
class Label(_LabelBase):
def to_categories(self) -> Any:
if self.categories is None:
raise RuntimeError("Label does not have categories")
return tree_map(lambda idx: self.categories[idx], self.tolist())
class OneHotLabel(_LabelBase):
def __new__(
cls,
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> OneHotLabel:
one_hot_label = super().__new__(
cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
)
if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError()
return one_hot_label
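# For illustration, a rough usage sketch (hypothetical categories):
#
#   label = Label.from_category("dog", categories=["cat", "dog", "bird"])  # wraps the index 1
#   label.to_categories()  # "dog"
#   OneHotLabel([0, 1, 0], categories=["cat", "dog", "bird"])  # last dim must match len(categories)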
import collections.abc
import difflib
import io
import mmap
import platform
from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union
import numpy as np
import torch
from torchvision._utils import sequence_to_str
__all__ = [
"add_suggestion",
"fromfile",
"ReadOnlyTensorBuffer",
]
def add_suggestion(
msg: str,
*,
word: str,
possibilities: Collection[str],
close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
alternative_hint: Callable[
[Sequence[str]], str
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
if not isinstance(possibilities, collections.abc.Sequence):
possibilities = sorted(possibilities)
suggestions = difflib.get_close_matches(word, possibilities, 1)
hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
if not hint:
return msg
return f"{msg.strip()} {hint}"
D = TypeVar("D")
def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
return bytearray(file.read(-1 if count == -1 else count * item_size))
def fromfile(
file: BinaryIO,
*,
dtype: torch.dtype,
byte_order: str,
count: int = -1,
) -> torch.Tensor:
"""Construct a tensor from a binary file.
.. note::
This function is similar to :func:`numpy.fromfile` with two notable differences:
1. This function only accepts an open binary file, but not a path to it.
2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
concept.
.. note::
If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
Args:
file (IO): Open binary file.
dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
byte_order (str): Byte order of the data. Can be "little" or "big" endian.
count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
"""
byte_order = "<" if byte_order == "little" else ">"
char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
np_dtype = byte_order + char + str(item_size)
buffer: Union[memoryview, bytearray]
if platform.system() != "Windows":
# PyTorch does not support tensors with underlying read-only memory. In case
# - the file has a .fileno(),
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
# - the file is seekable
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
# to a mutable location afterwards.
try:
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
except (AttributeError, PermissionError, io.UnsupportedOperation):
buffer = _read_mutable_buffer_fallback(file, count, item_size)
else:
# On Windows, just trying to call mmap.mmap() on a file that does not support it may corrupt the internal
# state, so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
buffer = _read_mutable_buffer_fallback(file, count, item_size)
# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
# successive .astype() call.
return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
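# For illustration, a rough usage sketch (hypothetical file):
#
#   with open("values.bin", "r+b") as f:  # update mode enables the fast mmap path
#       t = fromfile(f, dtype=torch.float32, byte_order="little", count=10)
#   # t is a 1D float32 tensor with 10 elements read from the current file position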
class ReadOnlyTensorBuffer:
def __init__(self, tensor: torch.Tensor) -> None:
self._memory = memoryview(tensor.numpy())
self._cursor: int = 0
def tell(self) -> int:
return self._cursor
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_SET:
self._cursor = offset
elif whence == io.SEEK_CUR:
self._cursor += offset
elif whence == io.SEEK_END:
self._cursor = len(self._memory) + offset
else:
raise ValueError(
f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, "
f"but got {repr(whence)} instead"
)
return self.tell()
def read(self, size: int = -1) -> bytes:
cursor = self.tell()
offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()