Commit 48b1edff authored by Nicolas Hug, committed by GitHub

Remove prototype area for 0.19 (#8491)

parent f44f20cf
from .raft_stereo import *
from .crestereo import *
import math
from functools import partial
from typing import Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
__all__ = (
"CREStereo",
"CREStereo_Base_Weights",
"crestereo_base",
)
class ConvexMaskPredictor(nn.Module):
def __init__(
self,
*,
in_channels: int,
hidden_size: int,
upsample_factor: int,
multiplier: float = 0.25,
) -> None:
super().__init__()
self.mask_head = nn.Sequential(
Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3),
# https://arxiv.org/pdf/2003.12039.pdf (Annex section B) for the
# following convolution output size
nn.Conv2d(hidden_size, upsample_factor**2 * 9, 1, padding=0),
)
self.multiplier = multiplier
def forward(self, x: Tensor) -> Tensor:
x = self.mask_head(x) * self.multiplier
return x
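# Illustrative sketch, not part of the original module: a minimal shape check for
# ``ConvexMaskPredictor``. The mask head outputs ``upsample_factor**2 * 9`` channels,
# i.e. one 3x3 convex-combination weight set per pixel of the convex-upsampled grid
# (see the RAFT paper, Annex B). The helper name below is hypothetical.
def _demo_convex_mask_predictor_shape() -> None:
    predictor = ConvexMaskPredictor(in_channels=128, hidden_size=256, upsample_factor=4)
    x = torch.rand(1, 128, 32, 64)
    mask = predictor(x)
    # 4**2 * 9 = 144 channels, spatial size unchanged
    assert mask.shape == (1, 144, 32, 64)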
def get_correlation(
left_feature: Tensor,
right_feature: Tensor,
window_size: Tuple[int, int] = (3, 3),
dilate: Tuple[int, int] = (1, 1),
) -> Tensor:
"""Function that computes a correlation product between the left and right features.
The correlation is computed in a sliding window fashion, namely the left features are fixed
and for each ``(i, j)`` location we compute the correlation with a sliding window anchored in
``(i, j)`` from the right feature map. The sliding window selects pixels obtained in the range of the sliding
window; i.e ``(i - window_size // 2, i + window_size // 2)`` respectively ``(j - window_size // 2, j + window_size // 2)``.
"""
B, C, H, W = left_feature.shape
di_y, di_x = dilate[0], dilate[1]
pad_y, pad_x = window_size[0] // 2 * di_y, window_size[1] // 2 * di_x
right_padded = F.pad(right_feature, (pad_x, pad_x, pad_y, pad_y), mode="replicate")
# in order to vectorize the correlation computation over all pixel candidates
# we create multiple shifted right images which we stack on an extra dimension
right_padded = F.unfold(right_padded, kernel_size=(H, W), dilation=dilate)
# torch unfold returns a tensor of shape [B, flattened_values, n_selections]
right_padded = right_padded.permute(0, 2, 1)
# reshape back into [B, n_views, C, H, W]
right_padded = right_padded.reshape(B, (window_size[0] * window_size[1]), C, H, W)
# we expand the left features for broadcasting
left_feature = left_feature.unsqueeze(1)
# this computes an element-wise product between [B, 1, C, H, W] and [B, n_views, C, H, W];
# to obtain correlations over the pixel candidates we take the mean over the C dimension
correlation = torch.mean(left_feature * right_padded, dim=2, keepdim=False)
# the final correlation tensor shape will be [B, n_views, H, W]
# where on the i-th position of the n_views dimension we will have
# the correlation value between the left pixel
# and the i-th candidate on the right feature map
return correlation
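# Illustrative sketch, not part of the original module: demonstrates the shape contract
# of ``get_correlation``. A (3, 3) window yields 3 * 3 = 9 candidates, stacked on dim=1
# (one correlation map per candidate). The helper name below is hypothetical.
def _demo_get_correlation_shapes() -> None:
    left = torch.rand(2, 32, 16, 16)
    right = torch.rand(2, 32, 16, 16)
    corr = get_correlation(left, right, window_size=(3, 3), dilate=(1, 1))
    # output is [B, n_views, H, W] with n_views = window_size[0] * window_size[1]
    assert corr.shape == (2, 9, 16, 16)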
def _check_window_specs(
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
if not np.prod(search_window_1d) == np.prod(search_window_2d):
raise ValueError(
f"The 1D and 2D windows should contain the same number of elements. "
f"1D shape: {search_window_1d} 2D shape: {search_window_2d}"
)
if not np.prod(search_window_1d) % 2 == 1:
raise ValueError(
f"Search windows should contain an odd number of elements in them."
f"Window of shape {search_window_1d} has {np.prod(search_window_1d)} elements."
)
if not any(size == 1 for size in search_window_1d):
raise ValueError(f"The 1D search window should have at least one size equal to 1. 1D shape: {search_window_1d}")
if any(size == 1 for size in search_window_2d):
raise ValueError(
f"The 2D search window should have all dimensions greater than 1. 2D shape: {search_window_2d}"
)
if any(dilate < 1 for dilate in search_dilate_1d):
raise ValueError(
f"The 1D search dilation should have all elements equal or greater than 1. 1D shape: {search_dilate_1d}"
)
if any(dilate < 1 for dilate in search_dilate_2d):
raise ValueError(
f"The 2D search dilation should have all elements equal greater than 1. 2D shape: {search_dilate_2d}"
)
class IterativeCorrelationLayer(nn.Module):
def __init__(
self,
groups: int = 4,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
_check_window_specs(
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
self.search_pixels = np.prod(search_window_1d)
self.groups = groups
# two selection tables for dealing with the small_patch argument in the forward function
self.patch_sizes = {
"2d": [search_window_2d for _ in range(self.groups)],
"1d": [search_window_1d for _ in range(self.groups)],
}
self.dilate_sizes = {
"2d": [search_dilate_2d for _ in range(self.groups)],
"1d": [search_dilate_1d for _ in range(self.groups)],
}
def forward(self, left_feature: Tensor, right_feature: Tensor, flow: Tensor, window_type: str = "1d") -> Tensor:
"""Function that computes 1 pass of non-offsetted Group-Wise correlation"""
coords = make_coords_grid(
left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device)
)
# we offset the coordinate grid in the flow direction
coords = coords + flow
coords = coords.permute(0, 2, 3, 1)
# resample the right features according to the offset grid
right_feature = grid_sample(right_feature, coords, mode="bilinear", align_corners=True)
# ``window_type`` controls along how many axes the candidate search is performed:
# "1d" searches along the X axis only, "2d" along both axes.
# See section 3.1 ``Deformable search window`` & Figure 4 in the paper.
patch_size_list = self.patch_sizes[window_type]
dilate_size_list = self.dilate_sizes[window_type]
# chunking the left and right feature to perform group-wise correlation
# mechanism similar to GroupNorm. See section 3.1 ``Group-wise correlation``.
left_groups = torch.chunk(left_feature, self.groups, dim=1)
right_groups = torch.chunk(right_feature, self.groups, dim=1)
correlations = []
# this boils down to rather than performing the correlation product
# over the entire C dimensions, we use subsets of C to get multiple correlation sets
for i in range(len(patch_size_list)):
correlation = get_correlation(left_groups[i], right_groups[i], patch_size_list[i], dilate_size_list[i])
correlations.append(correlation)
final_correlations = torch.cat(correlations, dim=1)
return final_correlations
class AttentionOffsetCorrelationLayer(nn.Module):
def __init__(
self,
groups: int = 4,
attention_module: Optional[nn.Module] = None,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
_check_window_specs(
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
# convert to python scalar
self.search_pixels = int(np.prod(search_window_1d))
self.groups = groups
# two selection tables for dealing with the small_patch argument in the forward function
self.patch_sizes = {
"2d": [search_window_2d for _ in range(self.groups)],
"1d": [search_window_1d for _ in range(self.groups)],
}
self.dilate_sizes = {
"2d": [search_dilate_2d for _ in range(self.groups)],
"1d": [search_dilate_1d for _ in range(self.groups)],
}
self.attention_module = attention_module
def forward(
self,
left_feature: Tensor,
right_feature: Tensor,
flow: Tensor,
extra_offset: Tensor,
window_type: str = "1d",
) -> Tensor:
"""Function that computes 1 pass of offsetted Group-Wise correlation
If the class was provided with an attention layer, the left and right feature maps
will be passed through a transformer first
"""
B, C, H, W = left_feature.shape
if self.attention_module is not None:
# prepare for transformer required input shapes
left_feature = left_feature.permute(0, 2, 3, 1).reshape(B, H * W, C)
right_feature = right_feature.permute(0, 2, 3, 1).reshape(B, H * W, C)
# this can be either self attention or cross attention, hence the tuple return
left_feature, right_feature = self.attention_module(left_feature, right_feature)
left_feature = left_feature.reshape(B, H, W, C).permute(0, 3, 1, 2)
right_feature = right_feature.reshape(B, H, W, C).permute(0, 3, 1, 2)
left_groups = torch.chunk(left_feature, self.groups, dim=1)
right_groups = torch.chunk(right_feature, self.groups, dim=1)
num_search_candidates = self.search_pixels
# for each pixel (i, j) we have a number of search candidates
# thus, for each candidate we should have an X-axis and Y-axis offset value
extra_offset = extra_offset.reshape(B, num_search_candidates, 2, H, W).permute(0, 1, 3, 4, 2)
patch_size_list = self.patch_sizes[window_type]
dilate_size_list = self.dilate_sizes[window_type]
group_channels = C // self.groups
correlations = []
for i in range(len(patch_size_list)):
left_group, right_group = left_groups[i], right_groups[i]
patch_size, dilate = patch_size_list[i], dilate_size_list[i]
di_y, di_x = dilate
ps_y, ps_x = patch_size
# define the search based on the window patch shape
ry, rx = ps_y // 2 * di_y, ps_x // 2 * di_x
# base offsets for search (i.e. where to look on the search index)
x_grid, y_grid = torch.meshgrid(
torch.arange(-rx, rx + 1, di_x), torch.arange(-ry, ry + 1, di_y), indexing="xy"
)
x_grid, y_grid = x_grid.to(flow.device), y_grid.to(flow.device)
offsets = torch.stack((x_grid, y_grid))
offsets = offsets.reshape(2, -1).permute(1, 0)
for d in (0, 2, 3):
offsets = offsets.unsqueeze(d)
# extra offsets for search (i.e. deformed search indexes. Similar concept to deformable convolutions)
offsets = offsets + extra_offset
coords = (
make_coords_grid(
left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device)
)
+ flow
)
coords = coords.permute(0, 2, 3, 1).unsqueeze(1)
coords = coords + offsets
coords = coords.reshape(B, -1, W, 2)
right_group = grid_sample(right_group, coords, mode="bilinear", align_corners=True)
# we do not need to perform any window shifting because the grid sample op
# will return a multi-view right based on the num_search_candidates dimension in the offsets
right_group = right_group.reshape(B, group_channels, -1, H, W)
left_group = left_group.reshape(B, group_channels, -1, H, W)
correlation = torch.mean(left_group * right_group, dim=1)
correlations.append(correlation)
final_correlation = torch.cat(correlations, dim=1)
return final_correlation
class AdaptiveGroupCorrelationLayer(nn.Module):
"""
Container for computing various correlation types between a left and right feature map.
This module does not contain any optimisable parameters; it is solely a collection of ops.
We wrap it in an nn.Module for torch.jit.script compatibility.
Adaptive Group Correlation operations from: https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf
Canonical reference implementation: https://github.com/megvii-research/CREStereo/blob/master/nets/corr.py
"""
def __init__(
self,
iterative_correlation_layer: IterativeCorrelationLayer,
attention_offset_correlation_layer: AttentionOffsetCorrelationLayer,
) -> None:
super().__init__()
self.iterative_correlation_layer = iterative_correlation_layer
self.attention_offset_correlation_layer = attention_offset_correlation_layer
def forward(
self,
left_features: Tensor,
right_features: Tensor,
flow: torch.Tensor,
extra_offset: Optional[Tensor],
window_type: str = "1d",
iter_mode: bool = False,
) -> Tensor:
if iter_mode or extra_offset is None:
corr = self.iterative_correlation_layer(left_features, right_features, flow, window_type)
else:
corr = self.attention_offset_correlation_layer(
left_features, right_features, flow, extra_offset, window_type
) # type: ignore
return corr
def elu_feature_map(x: Tensor) -> Tensor:
"""Elu feature map operation from: https://arxiv.org/pdf/2006.16236.pdf"""
return F.elu(x) + 1
class LinearAttention(nn.Module):
"""
Linear attention operation from: https://arxiv.org/pdf/2006.16236.pdf
Canonical implementation reference: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py
"""
def __init__(self, eps: float = 1e-6, feature_map_fn: Callable[[Tensor], Tensor] = elu_feature_map) -> None:
super().__init__()
self.eps = eps
self.feature_map_fn = feature_map_fn
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
q_mask: Optional[Tensor] = None,
kv_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
queries (torch.Tensor): [N, S1, H, D]
keys (torch.Tensor): [N, S2, H, D]
values (torch.Tensor): [N, S2, H, D]
q_mask (torch.Tensor): [N, S1] (optional)
kv_mask (torch.Tensor): [N, S2] (optional)
Returns:
queried_values (torch.Tensor): [N, S1, H, D]
"""
queries = self.feature_map_fn(queries)
keys = self.feature_map_fn(keys)
if q_mask is not None:
queries = queries * q_mask[:, :, None, None]
if kv_mask is not None:
keys = keys * kv_mask[:, :, None, None]
values = values * kv_mask[:, :, None, None]
# mitigates fp16 overflows
values_length = values.shape[1]
values = values / values_length
kv = torch.einsum("NSHD, NSHV -> NHDV", keys, values)
z = 1 / (torch.einsum("NLHD, NHD -> NLH", queries, keys.sum(dim=1)) + self.eps)
# rescale at the end to account for fp16 mitigation
queried_values = torch.einsum("NLHD, NHDV, NLH -> NLHV", queries, kv, z) * values_length
return queried_values
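# Illustrative sketch, not part of the original module: linear attention contracts the
# key/value sequence length away (cost O(S * D^2) rather than O(S1 * S2 * D)), so the
# output keeps the query sequence length. The helper name below is hypothetical.
def _demo_linear_attention_shapes() -> None:
    attention = LinearAttention()
    queries = torch.rand(2, 100, 8, 32)  # [N, S1, H, D]
    keys = torch.rand(2, 120, 8, 32)     # [N, S2, H, D]
    values = torch.rand(2, 120, 8, 32)   # [N, S2, H, D]
    out = attention(queries, keys, values)
    assert out.shape == (2, 100, 8, 32)  # [N, S1, H, D]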
class SoftmaxAttention(nn.Module):
"""
A simple softmax attention operation
LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py
"""
def __init__(self, dropout: float = 0.0) -> None:
super().__init__()
self.dropout = nn.Dropout(dropout) if dropout else nn.Identity()
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
q_mask: Optional[Tensor] = None,
kv_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Computes classical softmax full-attention between all queries and keys.
Args:
queries (torch.Tensor): [N, S1, H, D]
keys (torch.Tensor): [N, S2, H, D]
values (torch.Tensor): [N, S2, H, D]
q_mask (torch.Tensor): [N, S1] (optional)
kv_mask (torch.Tensor): [N, S2] (optional)
Returns:
queried_values: [N, S1, H, D]
"""
scale_factor = 1.0 / queries.shape[3] ** 0.5 # rsqrt(D) scaling
queries = queries * scale_factor
qk = torch.einsum("NLHD, NSHD -> NLSH", queries, keys)
if kv_mask is not None and q_mask is not None:
qk.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf"))
attention = torch.softmax(qk, dim=2)
attention = self.dropout(attention)
queried_values = torch.einsum("NLSH, NSHD -> NLHD", attention, values)
return queried_values
class PositionalEncodingSine(nn.Module):
"""
Sinusoidal positional encodings
Using the scaling term from https://github.com/megvii-research/CREStereo/blob/master/nets/attention/position_encoding.py
Reference implementation from https://github.com/facebookresearch/detr/blob/8a144f83a287f4d3fece4acdf073f387c5af387d/models/position_encoding.py#L28-L48
"""
def __init__(self, dim_model: int, max_size: int = 256) -> None:
super().__init__()
self.dim_model = dim_model
self.max_size = max_size
# pre-registered for memory efficiency during forward pass
pe = self._make_pe_of_size(self.max_size)
self.register_buffer("pe", pe)
def _make_pe_of_size(self, size: int) -> Tensor:
pe = torch.zeros((self.dim_model, *(size, size)), dtype=torch.float32)
y_positions = torch.ones((size, size)).cumsum(0).float().unsqueeze(0)
x_positions = torch.ones((size, size)).cumsum(1).float().unsqueeze(0)
div_term = torch.exp(torch.arange(0.0, self.dim_model // 2, 2) * (-math.log(10000.0) / self.dim_model // 2))
div_term = div_term[:, None, None]
pe[0::4, :, :] = torch.sin(x_positions * div_term)
pe[1::4, :, :] = torch.cos(x_positions * div_term)
pe[2::4, :, :] = torch.sin(y_positions * div_term)
pe[3::4, :, :] = torch.cos(y_positions * div_term)
pe = pe.unsqueeze(0)
return pe
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: [B, C, H, W]
Returns:
x: [B, C, H, W]
"""
torch._assert(
len(x.shape) == 4,
f"PositionalEncodingSine requires a 4-D dimensional input. Provided tensor is of shape {x.shape}",
)
B, C, H, W = x.shape
return x + self.pe[:, :, :H, :W] # type: ignore
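# Illustrative sketch, not part of the original module: the positional encoding is
# precomputed once for a ``max_size`` x ``max_size`` grid and cropped to the input's
# spatial size inside ``forward``. The helper name below is hypothetical; ``dim_model``
# is assumed to be divisible by 4.
def _demo_positional_encoding_shape() -> None:
    pos_enc = PositionalEncodingSine(dim_model=128, max_size=64)
    x = torch.rand(2, 128, 32, 48)
    out = pos_enc(x)
    assert out.shape == x.shape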
class LocalFeatureEncoderLayer(nn.Module):
"""
LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf
Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
"""
def __init__(
self,
*,
dim_model: int,
num_heads: int,
attention_module: Callable[..., nn.Module] = LinearAttention,
) -> None:
super().__init__()
self.attention_op = attention_module()
if not isinstance(self.attention_op, (LinearAttention, SoftmaxAttention)):
raise ValueError(
f"attention_module must be an instance of LinearAttention or SoftmaxAttention. Got {type(self.attention_op)}"
)
self.dim_head = dim_model // num_heads
self.num_heads = num_heads
# multi-head attention
self.query_proj = nn.Linear(dim_model, dim_model, bias=False)
self.key_proj = nn.Linear(dim_model, dim_model, bias=False)
self.value_proj = nn.Linear(dim_model, dim_model, bias=False)
self.merge = nn.Linear(dim_model, dim_model, bias=False)
# feed forward network
self.ffn = nn.Sequential(
nn.Linear(dim_model * 2, dim_model * 2, bias=False),
nn.ReLU(),
nn.Linear(dim_model * 2, dim_model, bias=False),
)
# norm layers
self.attention_norm = nn.LayerNorm(dim_model)
self.ffn_norm = nn.LayerNorm(dim_model)
def forward(
self, x: Tensor, source: Tensor, x_mask: Optional[Tensor] = None, source_mask: Optional[Tensor] = None
) -> Tensor:
"""
Args:
x (torch.Tensor): [B, S1, D]
source (torch.Tensor): [B, S2, D]
x_mask (torch.Tensor): [B, S1] (optional)
source_mask (torch.Tensor): [B, S2] (optional)
"""
B, S, D = x.shape
queries, keys, values = x, source, source
queries = self.query_proj(queries).reshape(B, S, self.num_heads, self.dim_head)
keys = self.key_proj(keys).reshape(B, S, self.num_heads, self.dim_head)
values = self.value_proj(values).reshape(B, S, self.num_heads, self.dim_head)
# attention operation
message = self.attention_op(queries, keys, values, x_mask, source_mask)
# concatenating attention heads together before passing through projection layer
message = self.merge(message.reshape(B, S, D))
message = self.attention_norm(message)
# ffn operation
message = self.ffn(torch.cat([x, message], dim=2))
message = self.ffn_norm(message)
return x + message
class LocalFeatureTransformer(nn.Module):
"""
LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf
Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
"""
def __init__(
self,
*,
dim_model: int,
num_heads: int,
attention_directions: List[str],
attention_module: Callable[..., nn.Module] = LinearAttention,
) -> None:
super(LocalFeatureTransformer, self).__init__()
self.attention_module = attention_module
self.attention_directions = attention_directions
for direction in attention_directions:
if direction not in ["self", "cross"]:
raise ValueError(
f"Attention direction {direction} unsupported. LocalFeatureTransformer accepts only ``attention_type`` in ``[self, cross]``."
)
self.layers = nn.ModuleList(
[
LocalFeatureEncoderLayer(dim_model=dim_model, num_heads=num_heads, attention_module=attention_module)
for _ in attention_directions
]
)
def forward(
self,
left_features: Tensor,
right_features: Tensor,
left_mask: Optional[Tensor] = None,
right_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Args:
left_features (torch.Tensor): [N, S1, D]
right_features (torch.Tensor): [N, S2, D]
left_mask (torch.Tensor): [N, S1] (optional)
right_mask (torch.Tensor): [N, S2] (optional)
Returns:
left_features (torch.Tensor): [N, S1, D]
right_features (torch.Tensor): [N, S2, D]
"""
torch._assert(
left_features.shape[2] == right_features.shape[2],
f"left_features and right_features should have the same embedding dimensions. left_features: {left_features.shape[2]} right_features: {right_features.shape[2]}",
)
for idx, layer in enumerate(self.layers):
attention_direction = self.attention_directions[idx]
if attention_direction == "self":
left_features = layer(left_features, left_features, left_mask, left_mask)
right_features = layer(right_features, right_features, right_mask, right_mask)
elif attention_direction == "cross":
left_features = layer(left_features, right_features, left_mask, right_mask)
right_features = layer(right_features, left_features, right_mask, left_mask)
return left_features, right_features
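# Illustrative sketch, not part of the original module: a one-layer self-attention
# ``LocalFeatureTransformer`` returns both feature maps with unchanged [N, S, D] shapes;
# the same holds for "cross" layers. The helper name below is hypothetical.
def _demo_local_feature_transformer_shapes() -> None:
    transformer = LocalFeatureTransformer(
        dim_model=64, num_heads=4, attention_directions=["self"], attention_module=LinearAttention
    )
    left = torch.rand(2, 100, 64)   # [N, S1, D]
    right = torch.rand(2, 100, 64)  # [N, S2, D]
    left_out, right_out = transformer(left, right)
    assert left_out.shape == left.shape and right_out.shape == right.shape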
class PyramidDownsample(nn.Module):
"""
A simple wrapper that returns an average-pool feature pyramid based on the provided scales.
Implicitly returns the input as well.
"""
def __init__(self, factors: Iterable[int]) -> None:
super().__init__()
self.factors = factors
def forward(self, x: torch.Tensor) -> List[Tensor]:
results = [x]
for factor in self.factors:
results.append(F.avg_pool2d(x, kernel_size=factor, stride=factor))
return results
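# Illustrative sketch, not part of the original module: ``PyramidDownsample`` returns the
# input followed by average-pooled copies at each requested factor. The helper name
# below is hypothetical.
def _demo_pyramid_downsample_shapes() -> None:
    pyramid = PyramidDownsample(factors=(2, 4))
    x = torch.rand(1, 64, 32, 64)
    outputs = pyramid(x)
    assert [tuple(o.shape[-2:]) for o in outputs] == [(32, 64), (16, 32), (8, 16)]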
class CREStereo(nn.Module):
"""
Implements CREStereo from the `"Practical Stereo Matching via Cascaded Recurrent Network
With Adaptive Correlation" <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_ paper.
Args:
feature_encoder (raft.FeatureEncoder): Raft-like Feature Encoder module that extracts low-level features from the inputs.
update_block (raft.UpdateBlock): Raft-like Update Block which recursively refines a flow-map.
flow_head (raft.FlowHead): Raft-like Flow Head which predicts a flow-map from some inputs.
self_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs self attention on the two feature maps.
cross_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs cross attention between the two feature maps
used in the Adaptive Group Correlation module.
feature_downsample_rates (List[int]): The downsample rates used to build a feature pyramid from the outputs of the `feature_encoder`. Default: [2, 4]
correlation_groups (int): The number of groups into which the features are split when computing per-pixel correlation. Default: 4.
search_window_1d (Tuple[int, int]): The alternate search window size in the x and y directions for the 1D case. Defaults to (1, 9).
search_dilate_1d (Tuple[int, int]): The dilation used in the `search_window_1d` when selecting pixels. Similar to `nn.Conv2d` dilate. Defaults to (1, 1).
search_window_2d (Tuple[int, int]): The alternate search window size in the x and y directions for the 2D case. Defaults to (3, 3).
search_dilate_2d (Tuple[int, int]): The dilation used in the `search_window_2d` when selecting pixels. Similar to `nn.Conv2d` dilate. Defaults to (1, 1).
"""
def __init__(
self,
*,
feature_encoder: raft.FeatureEncoder,
update_block: raft.UpdateBlock,
flow_head: raft.FlowHead,
self_attn_block: LocalFeatureTransformer,
cross_attn_block: LocalFeatureTransformer,
feature_downsample_rates: Tuple[int, ...] = (2, 4),
correlation_groups: int = 4,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
self.output_channels = 2
self.feature_encoder = feature_encoder
self.update_block = update_block
self.flow_head = flow_head
self.self_attn_block = self_attn_block
# average pooling for the feature encoder outputs
self.downsampling_pyramid = PyramidDownsample(feature_downsample_rates)
self.downsampling_factors: List[int] = [feature_encoder.downsample_factor]
base_downsample_factor: int = self.downsampling_factors[0]
for rate in feature_downsample_rates:
self.downsampling_factors.append(base_downsample_factor * rate)
# output resolution tracking
self.resolutions: List[str] = [f"1 / {factor}" for factor in self.downsampling_factors]
self.search_pixels = int(np.prod(search_window_1d))
# flow convex upsampling mask predictor
self.mask_predictor = ConvexMaskPredictor(
in_channels=feature_encoder.output_dim // 2,
hidden_size=feature_encoder.output_dim,
upsample_factor=feature_encoder.downsample_factor,
multiplier=0.25,
)
# offset modules for offset-based feature selection
self.offset_convs = nn.ModuleDict()
self.correlation_layers = nn.ModuleDict()
offset_conv_layer = partial(
Conv2dNormActivation,
in_channels=feature_encoder.output_dim,
out_channels=self.search_pixels * 2,
norm_layer=None,
activation_layer=None,
)
# populate the dicts in top-to-bottom order,
# which is useful when iterating through the module under torch.jit.script following the network forward pass
#
# Ignore the largest resolution. We handle it separately because torch.jit.script
# cannot access runtime-generated keys in ModuleDicts.
# This way we keep a generic way of processing all pyramid levels except
# the final one
iterative_correlation_layer = partial(
IterativeCorrelationLayer,
groups=correlation_groups,
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
attention_offset_correlation_layer = partial(
AttentionOffsetCorrelationLayer,
groups=correlation_groups,
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
for idx, resolution in enumerate(reversed(self.resolutions[1:])):
# the largest resolution does not use offset convolutions for sampling grid coords
offset_conv = None if idx == len(self.resolutions) - 1 else offset_conv_layer()
if offset_conv:
self.offset_convs[resolution] = offset_conv
# only the lowest resolution uses the cross attention module when computing correlation scores
attention_module = cross_attn_block if idx == 0 else None
self.correlation_layers[resolution] = AdaptiveGroupCorrelationLayer(
iterative_correlation_layer=iterative_correlation_layer(),
attention_offset_correlation_layer=attention_offset_correlation_layer(
attention_module=attention_module
),
)
# correlation layer for the largest resolution
self.max_res_correlation_layer = AdaptiveGroupCorrelationLayer(
iterative_correlation_layer=iterative_correlation_layer(),
attention_offset_correlation_layer=attention_offset_correlation_layer(),
)
# simple 2D Positional Encodings
self.positional_encodings = PositionalEncodingSine(feature_encoder.output_dim)
def _get_window_type(self, iteration: int) -> str:
return "1d" if iteration % 2 == 0 else "2d"
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10
) -> List[Tensor]:
features = torch.cat([left_image, right_image], dim=0)
features = self.feature_encoder(features)
left_features, right_features = features.chunk(2, dim=0)
# update block network state and input context are derived from the left feature map
net, ctx = left_features.chunk(2, dim=1)
net = torch.tanh(net)
ctx = torch.relu(ctx)
# these calls output lists of tensors
l_pyramid = self.downsampling_pyramid(left_features)
r_pyramid = self.downsampling_pyramid(right_features)
net_pyramid = self.downsampling_pyramid(net)
ctx_pyramid = self.downsampling_pyramid(ctx)
# we store in reversed order because we process the pyramid from top to bottom
l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
# offsets for sampling pixel candidates in the correlation ops
offsets: Dict[str, Tensor] = {}
for resolution, offset_conv in self.offset_convs.items():
feature_map = l_pyramid[resolution]
offset = offset_conv(feature_map)
offsets[resolution] = (torch.sigmoid(offset) - 0.5) * 2.0
# the smallest resolution is prepared for passing through self attention
min_res = self.resolutions[-1]
max_res = self.resolutions[0]
B, C, MIN_H, MIN_W = l_pyramid[min_res].shape
# add positional encodings
l_pyramid[min_res] = self.positional_encodings(l_pyramid[min_res])
r_pyramid[min_res] = self.positional_encodings(r_pyramid[min_res])
# reshaping for transformer
l_pyramid[min_res] = l_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C)
r_pyramid[min_res] = r_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C)
# perform self attention
l_pyramid[min_res], r_pyramid[min_res] = self.self_attn_block(l_pyramid[min_res], r_pyramid[min_res])
# now we need to reshape back into [B, C, H, W] format
l_pyramid[min_res] = l_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2)
r_pyramid[min_res] = r_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2)
predictions: List[Tensor] = []
flow_estimates: Dict[str, Tensor] = {}
# we added this because of torch.jit.script
# also, the prediction prior is always going to have the
# spatial size of the features output by the feature encoder
flow_pred_prior: Tensor = torch.empty(
size=(B, 2, left_features.shape[2], left_features.shape[3]),
dtype=l_pyramid[max_res].dtype,
device=l_pyramid[max_res].device,
)
if flow_init is not None:
scale = l_pyramid[max_res].shape[2] / flow_init.shape[2]
# in the CREStereo implementation they multiply with -scale instead of scale
# this can be either a downsample or an upsample based on the cascaded inference
# configuration
# we use a -scale because the flow used inside the network is a negative flow
# from the right to the left, so we flip the flow direction
flow_estimates[max_res] = -scale * F.interpolate(
input=flow_init,
size=l_pyramid[max_res].shape[2:],
mode="bilinear",
align_corners=True,
)
# when not provided with a flow prior, we construct one using the lower resolution maps
else:
# initialize a zero flow with the smallest resolution
flow = torch.zeros(size=(B, 2, MIN_H, MIN_W), device=left_features.device, dtype=left_features.dtype)
# flows from coarse resolutions are refined similarly
# we always need to fetch the next pyramid feature map as well
# when updating coarse resolutions, therefore we create a reversed
# view which has its order synced with the ModuleDict keys iterator
coarse_resolutions: List[str] = self.resolutions[::-1] # using slicing because of torch.jit.script
fine_grained_resolution = max_res
# set the coarsest flow to the zero flow
flow_estimates[coarse_resolutions[0]] = flow
# the correlation layer ModuleDict will contain layers ordered from coarse to fine resolution
# i.e ["1 / 16", "1 / 8", "1 / 4"]
# the correlation layer ModuleDict has layers for all the resolutions except the fine one
# i.e {"1 / 16": Module, "1 / 8": Module}
# for these resolutions we perform only half of the number of refinement iterations
for idx, (resolution, correlation_layer) in enumerate(self.correlation_layers.items()):
# compute the scale difference between the first pyramid scale and the current pyramid scale
scale_to_base = l_pyramid[fine_grained_resolution].shape[2] // l_pyramid[resolution].shape[2]
for it in range(num_iters // 2):
# set whether we want to search on (X, Y) axes for correlation or just on X axis
window_type = self._get_window_type(it)
# we consider this a prior, therefore we do not want to back-propagate through it
flow_estimates[resolution] = flow_estimates[resolution].detach()
correlations = correlation_layer(
l_pyramid[resolution], # left
r_pyramid[resolution], # right
flow_estimates[resolution],
offsets[resolution],
window_type,
)
# update the recurrent network state and the flow deltas
net_pyramid[resolution], delta_flow = self.update_block(
net_pyramid[resolution], ctx_pyramid[resolution], correlations, flow_estimates[resolution]
)
# the convex upsampling weights are computed w.r.t.
# the recurrent update state
up_mask = self.mask_predictor(net_pyramid[resolution])
flow_estimates[resolution] = flow_estimates[resolution] + delta_flow
# convex upsampling with the initial feature encoder downsampling rate
flow_pred_prior = upsample_flow(
flow_estimates[resolution], up_mask, factor=self.downsampling_factors[0]
)
# we then bilinear upsample to the final resolution
# we use a factor that's equivalent to the difference between
# the current downsample resolution and the base downsample resolution
#
# i.e. if a 1 / 16 flow is upsampled by 4 (base downsampling) we get a 1 / 4 flow.
# therefore we have to further upscale it by the difference between
# the current level 1 / 16 and the base level 1 / 4.
#
# we use a -scale because the flow used inside the network is a negative flow
# from the right to the left, so we flip the flow direction in order to get the
# left to right flow
flow_pred = -upsample_flow(flow_pred_prior, None, factor=scale_to_base)
predictions.append(flow_pred)
# when constructing the next resolution prior, we resample w.r.t.
# to the scale of the next level in the pyramid
next_resolution = coarse_resolutions[idx + 1]
scale_to_next = l_pyramid[next_resolution].shape[2] / flow_pred_prior.shape[2]
# we use flow_pred_prior because it is a more accurate estimate of the true flow
# due to the convex upsample, which resembles a learned super-resolution module.
# this is not necessarily an upsample, it can be a downsample, based on the provided configuration
flow_estimates[next_resolution] = -scale_to_next * F.interpolate(
input=flow_pred_prior,
size=l_pyramid[next_resolution].shape[2:],
mode="bilinear",
align_corners=True,
)
# finally we will be doing a full pass through the fine-grained resolution
# this coincides with the maximum resolution
# we keep a separate loop here in order to avoid python control flow
# to decide how many iterations we should do based on the current resolution
# furthermore, if provided with an initial flow, there is no need to generate
# a prior estimate when moving into the final refinement stage
for it in range(num_iters):
search_window_type = self._get_window_type(it)
flow_estimates[max_res] = flow_estimates[max_res].detach()
# we run the fine-grained resolution correlations in iterative mode
# this means that we are using the fixed window pixel selections
# instead of the deformed ones as with the previous steps
correlations = self.max_res_correlation_layer(
l_pyramid[max_res],
r_pyramid[max_res],
flow_estimates[max_res],
extra_offset=None,
window_type=search_window_type,
iter_mode=True,
)
net_pyramid[max_res], delta_flow = self.update_block(
net_pyramid[max_res], ctx_pyramid[max_res], correlations, flow_estimates[max_res]
)
up_mask = self.mask_predictor(net_pyramid[max_res])
flow_estimates[max_res] = flow_estimates[max_res] + delta_flow
# at the final resolution we simply do a convex upsample using the base downsample rate
flow_pred = -upsample_flow(flow_estimates[max_res], up_mask, factor=self.downsampling_factors[0])
predictions.append(flow_pred)
return predictions
def _crestereo(
*,
weights: Optional[WeightsEnum],
progress: bool,
# Feature Encoder
feature_encoder_layers: Tuple[int, int, int, int, int],
feature_encoder_strides: Tuple[int, int, int, int],
feature_encoder_block: Callable[..., nn.Module],
feature_encoder_norm_layer: Callable[..., nn.Module],
# Average Pooling Pyramid
feature_downsample_rates: Tuple[int, ...],
# Adaptive Correlation Layer
corr_groups: int,
corr_search_window_2d: Tuple[int, int],
corr_search_dilate_2d: Tuple[int, int],
corr_search_window_1d: Tuple[int, int],
corr_search_dilate_1d: Tuple[int, int],
# Flow head
flow_head_hidden_size: int,
# Recurrent block
recurrent_block_hidden_state_size: int,
recurrent_block_kernel_size: Tuple[Tuple[int, int], Tuple[int, int]],
recurrent_block_padding: Tuple[Tuple[int, int], Tuple[int, int]],
# Motion Encoder
motion_encoder_corr_layers: Tuple[int, int],
motion_encoder_flow_layers: Tuple[int, int],
motion_encoder_out_channels: int,
# Transformer Blocks
num_attention_heads: int,
num_self_attention_layers: int,
num_cross_attention_layers: int,
self_attention_module: Callable[..., nn.Module],
cross_attention_module: Callable[..., nn.Module],
**kwargs,
) -> CREStereo:
feature_encoder = kwargs.pop("feature_encoder", None) or raft.FeatureEncoder(
block=feature_encoder_block,
layers=feature_encoder_layers,
strides=feature_encoder_strides,
norm_layer=feature_encoder_norm_layer,
)
if feature_encoder.output_dim % corr_groups != 0:
raise ValueError(
f"Final ``feature_encoder_layers`` size should be divisible by ``corr_groups`` argument."
f"Feature encoder output size : {feature_encoder.output_dim}, Correlation groups: {corr_groups}."
)
motion_encoder = kwargs.pop("motion_encoder", None) or raft.MotionEncoder(
in_channels_corr=corr_groups * int(np.prod(corr_search_window_1d)),
corr_layers=motion_encoder_corr_layers,
flow_layers=motion_encoder_flow_layers,
out_channels=motion_encoder_out_channels,
)
out_channels_context = feature_encoder_layers[-1] - recurrent_block_hidden_state_size
recurrent_block = kwargs.pop("recurrent_block", None) or raft.RecurrentBlock(
input_size=motion_encoder.out_channels + out_channels_context,
hidden_size=recurrent_block_hidden_state_size,
kernel_size=recurrent_block_kernel_size,
padding=recurrent_block_padding,
)
flow_head = kwargs.pop("flow_head", None) or raft.FlowHead(
in_channels=out_channels_context, hidden_size=flow_head_hidden_size
)
update_block = raft.UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)
self_attention_module = kwargs.pop("self_attention_module", None) or LinearAttention
self_attn_block = LocalFeatureTransformer(
dim_model=feature_encoder.output_dim,
num_heads=num_attention_heads,
attention_directions=["self"] * num_self_attention_layers,
attention_module=self_attention_module,
)
cross_attention_module = kwargs.pop("cross_attention_module", None) or LinearAttention
cross_attn_block = LocalFeatureTransformer(
dim_model=feature_encoder.output_dim,
num_heads=num_attention_heads,
attention_directions=["cross"] * num_cross_attention_layers,
attention_module=cross_attention_module,
)
model = CREStereo(
feature_encoder=feature_encoder,
update_block=update_block,
flow_head=flow_head,
self_attn_block=self_attn_block,
cross_attn_block=cross_attn_block,
feature_downsample_rates=feature_downsample_rates,
correlation_groups=corr_groups,
search_window_1d=corr_search_window_1d,
search_window_2d=corr_search_window_2d,
search_dilate_1d=corr_search_dilate_1d,
search_dilate_2d=corr_search_dilate_2d,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
_COMMON_META = {
"resize_size": (384, 512),
}
class CREStereo_Base_Weights(WeightsEnum):
"""The metrics reported here are as follows.
``mae`` is the "mean-average-error" and indicates how far (in pixels) the
predicted disparity is from its true value (equivalent to ``epe``). This is averaged over all pixels
of all images. ``1px``, ``3px``, ``5px`` and indicate the percentage of pixels that have a lower
error than that of the ground truth. ``relepe`` is the "relative-end-point-error" and is the
average ``epe`` divided by the average ground truth disparity. ``fl-all`` corresponds to the average of pixels whose epe
is either <3px, or whom's ``relepe`` is lower than 0.05 (therefore higher is better).
"""
MEGVII_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-756c8b0f.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/megvii-research/CREStereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 1.704,
"rmse": 3.738,
"1px": 0.738,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.157,
"fl-all": 76.464,
},
5: {
"mae": 0.956,
"rmse": 2.963,
"1px": 0.88,
"3px": 0.948,
"5px": 0.965,
"relepe": 0.124,
"fl-all": 88.186,
},
10: {
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
},
20: {
"mae": 0.749,
"rmse": 2.706,
"1px": 0.907,
"3px": 0.961,
"5px": 0.972,
"relepe": 0.113,
"fl-all": 90.807,
},
},
2: {
2: {
"mae": 1.702,
"rmse": 3.784,
"1px": 0.784,
"3px": 0.894,
"5px": 0.924,
"relepe": 0.172,
"fl-all": 80.313,
},
5: {
"mae": 0.932,
"rmse": 2.907,
"1px": 0.877,
"3px": 0.944,
"5px": 0.963,
"relepe": 0.125,
"fl-all": 87.979,
},
10: {
"mae": 0.773,
"rmse": 2.768,
"1px": 0.901,
"3px": 0.958,
"5px": 0.972,
"relepe": 0.117,
"fl-all": 90.43,
},
20: {
"mae": 0.854,
"rmse": 2.971,
"1px": 0.9,
"3px": 0.957,
"5px": 0.97,
"relepe": 0.122,
"fl-all": 90.269,
},
},
},
}
},
"_docs": """These weights were ported from the original paper. They
are trained on a dataset mixture of the author's choice.""",
},
)
CRESTEREO_ETH_MBL_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-8f0e0e9a.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 2.363,
"rmse": 4.352,
"1px": 0.611,
"3px": 0.828,
"5px": 0.891,
"relepe": 0.176,
"fl-all": 64.511,
},
5: {
"mae": 1.618,
"rmse": 3.71,
"1px": 0.761,
"3px": 0.879,
"5px": 0.918,
"relepe": 0.154,
"fl-all": 77.128,
},
10: {
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
},
20: {
"mae": 1.448,
"rmse": 3.583,
"1px": 0.771,
"3px": 0.893,
"5px": 0.931,
"relepe": 0.145,
"fl-all": 77.7,
},
},
2: {
2: {
"mae": 1.972,
"rmse": 4.125,
"1px": 0.73,
"3px": 0.865,
"5px": 0.908,
"relepe": 0.169,
"fl-all": 74.396,
},
5: {
"mae": 1.403,
"rmse": 3.448,
"1px": 0.793,
"3px": 0.905,
"5px": 0.937,
"relepe": 0.151,
"fl-all": 80.186,
},
10: {
"mae": 1.312,
"rmse": 3.368,
"1px": 0.799,
"3px": 0.912,
"5px": 0.943,
"relepe": 0.148,
"fl-all": 80.379,
},
20: {
"mae": 1.376,
"rmse": 3.542,
"1px": 0.796,
"3px": 0.91,
"5px": 0.942,
"relepe": 0.149,
"fl-all": 80.054,
},
},
},
}
},
"_docs": """These weights were trained from scratch on
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo`.""",
},
)
CRESTEREO_FINETUNE_MULTI_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-697c38f4.pth ",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.038,
"rmse": 3.108,
"1px": 0.852,
"3px": 0.942,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.522,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 1.85,
"rmse": 3.797,
"1px": 0.673,
"3px": 0.862,
"5px": 0.917,
"relepe": 0.171,
"fl-all": 69.736,
},
5: {
"mae": 1.111,
"rmse": 3.166,
"1px": 0.838,
"3px": 0.93,
"5px": 0.957,
"relepe": 0.134,
"fl-all": 84.596,
},
10: {
"mae": 1.02,
"rmse": 3.073,
"1px": 0.854,
"3px": 0.938,
"5px": 0.96,
"relepe": 0.129,
"fl-all": 86.042,
},
20: {
"mae": 0.993,
"rmse": 3.059,
"1px": 0.855,
"3px": 0.942,
"5px": 0.967,
"relepe": 0.126,
"fl-all": 85.784,
},
},
2: {
2: {
"mae": 1.667,
"rmse": 3.867,
"1px": 0.78,
"3px": 0.891,
"5px": 0.922,
"relepe": 0.165,
"fl-all": 78.89,
},
5: {
"mae": 1.158,
"rmse": 3.278,
"1px": 0.843,
"3px": 0.926,
"5px": 0.955,
"relepe": 0.135,
"fl-all": 84.556,
},
10: {
"mae": 1.046,
"rmse": 3.13,
"1px": 0.85,
"3px": 0.934,
"5px": 0.96,
"relepe": 0.13,
"fl-all": 85.464,
},
20: {
"mae": 1.021,
"rmse": 3.102,
"1px": 0.85,
"3px": 0.935,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.417,
},
},
},
},
},
"_docs": """These weights were finetuned on a mixture of
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo` +
:class:`~torchvision.datasets._stereo_matching.InStereo2k` +
:class:`~torchvision.datasets._stereo_matching.CarlaStereo` +
:class:`~torchvision.datasets._stereo_matching.SintelStereo` +
:class:`~torchvision.datasets._stereo_matching.FallingThingsStereo` +
.""",
},
)
DEFAULT = MEGVII_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", CREStereo_Base_Weights.MEGVII_V1))
def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress=True, **kwargs) -> CREStereo:
"""CREStereo model from
`Practical Stereo Matching via Cascaded Recurrent Network
With Adaptive Correlation <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.CREStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/crestereo.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members:
"""
weights = CREStereo_Base_Weights.verify(weights)
return _crestereo(
weights=weights,
progress=progress,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(2, 1, 2, 1),
feature_encoder_block=partial(raft.ResidualBlock, always_project=True),
feature_encoder_norm_layer=nn.InstanceNorm2d,
# Average pooling pyramid
feature_downsample_rates=(2, 4),
# Motion encoder
motion_encoder_corr_layers=(256, 192),
motion_encoder_flow_layers=(128, 64),
motion_encoder_out_channels=128,
# Recurrent block
recurrent_block_hidden_state_size=128,
recurrent_block_kernel_size=((1, 5), (5, 1)),
recurrent_block_padding=((0, 2), (2, 0)),
# Flow head
flow_head_hidden_size=256,
# Transformer blocks
num_attention_heads=8,
num_self_attention_layers=1,
num_cross_attention_layers=1,
self_attention_module=LinearAttention,
cross_attention_module=LinearAttention,
# Adaptive Correlation layer
corr_groups=4,
corr_search_window_2d=(3, 3),
corr_search_dilate_2d=(1, 1),
corr_search_window_1d=(1, 9),
corr_search_dilate_1d=(1, 1),
)
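# Illustrative usage sketch, not part of the original module: builds the model without
# pretrained weights and runs a single forward pass. Spatial sizes are assumed to be
# divisible by the total downsampling factor (here 16); the reduced ``num_iters`` is
# only to keep the sketch cheap. The helper name below is hypothetical.
def _demo_crestereo_base_inference() -> None:
    model = crestereo_base(weights=None).eval()
    left = torch.rand(1, 3, 256, 512)
    right = torch.rand(1, 3, 256, 512)
    with torch.no_grad():
        # one prediction per coarse refinement step plus one per fine-grained iteration
        predictions = model(left, right, num_iters=2)
    # the last prediction is the most refined one, at full input resolution
    assert predictions[-1].shape == (1, 2, 256, 512)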
from functools import partial
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
from torchvision.utils import _log_api_usage_once
__all__ = (
"RaftStereo",
"raft_stereo_base",
"raft_stereo_realtime",
"Raft_Stereo_Base_Weights",
"Raft_Stereo_Realtime_Weights",
)
class BaseEncoder(raft.FeatureEncoder):
"""Base encoder for FeatureEncoder and ContextEncoder in which weight may be shared.
See the Raft-Stereo paper section 4.6 on backbone part.
"""
def __init__(
self,
*,
block: Callable[..., nn.Module] = ResidualBlock,
layers: Tuple[int, int, int, int] = (64, 64, 96, 128),
strides: Tuple[int, int, int, int] = (2, 1, 2, 2),
norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d,
):
# We use layers + (256,) because raft.FeatureEncoder requires 5 layers
# but here we will set the last conv layer to identity
super().__init__(block=block, layers=layers + (256,), strides=strides, norm_layer=norm_layer)
# The base encoder doesn't have the last conv layer of the feature encoder
self.conv = nn.Identity()
self.output_dim = layers[3]
num_downsampling = sum([x - 1 for x in strides])
self.downsampling_ratio = 2 ** (num_downsampling)
class FeatureEncoder(nn.Module):
"""Feature Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Context Encoder.
The FeatureEncoder takes concatenation of left and right image as input. It produces feature embedding that later
will be used to construct correlation volume.
"""
def __init__(
self,
base_encoder: BaseEncoder,
output_dim: int = 256,
shared_base: bool = False,
block: Callable[..., nn.Module] = ResidualBlock,
):
super().__init__()
self.base_encoder = base_encoder
self.base_downsampling_ratio = base_encoder.downsampling_ratio
base_dim = base_encoder.output_dim
if not shared_base:
self.residual_block: nn.Module = nn.Identity()
self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=1)
else:
# If we share the base encoder weights between the Feature and Context Encoders
# we need to add a residual block with InstanceNorm2d and change the kernel size of the conv layer
# see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L35-L37
self.residual_block = block(base_dim, base_dim, norm_layer=nn.InstanceNorm2d, stride=1)
self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=3, padding=1)
def forward(self, x: Tensor) -> Tensor:
x = self.base_encoder(x)
x = self.residual_block(x)
x = self.conv(x)
return x
class MultiLevelContextEncoder(nn.Module):
"""Context Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Feature Encoder.
The ContextEncoder takes left image as input, and it outputs concatenated hidden_states and contexts.
In Raft-Stereo we have multi level GRUs and this context encoder will also multi outputs (list of Tensor)
that correspond to each GRUs.
Take note that the length of "out_with_blocks" parameter represent the number of GRU's level.
args:
base_encoder (nn.Module): The base encoder part that can have a shared weight with feature_encoder's
base_encoder because they have same architecture.
out_with_blocks (List[bool]): The length represent the number of GRU's level (length of output), and
if the element is True then the output layer on that position will have additional block
output_dim (int): The dimension of output on each level (default: 256)
block (Callable[..., nn.Module]): The type of basic block used for downsampling and output layer
(default: ResidualBlock)
"""
def __init__(
self,
base_encoder: nn.Module,
out_with_blocks: List[bool],
output_dim: int = 256,
block: Callable[..., nn.Module] = ResidualBlock,
):
super().__init__()
self.num_level = len(out_with_blocks)
self.base_encoder = base_encoder
self.base_downsampling_ratio = base_encoder.downsampling_ratio
base_dim = base_encoder.output_dim
self.downsample_and_out_layers = nn.ModuleList(
[
nn.ModuleDict(
{
"downsampler": self._make_downsampler(block, base_dim, base_dim) if i > 0 else nn.Identity(),
"out_hidden_state": self._make_out_layer(
base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
),
"out_context": self._make_out_layer(
base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
),
}
)
for i in range(self.num_level)
]
)
def _make_out_layer(self, in_channels, out_channels, with_block=True, block=ResidualBlock):
layers = []
if with_block:
layers.append(block(in_channels, in_channels, norm_layer=nn.BatchNorm2d, stride=1))
layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
return nn.Sequential(*layers)
def _make_downsampler(self, block, in_channels, out_channels):
block1 = block(in_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=2)
block2 = block(out_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=1)
return nn.Sequential(block1, block2)
def forward(self, x: Tensor) -> List[Tensor]:
x = self.base_encoder(x)
outs = []
for layer_dict in self.downsample_and_out_layers:
x = layer_dict["downsampler"](x)
outs.append(torch.cat([layer_dict["out_hidden_state"](x), layer_dict["out_context"](x)], dim=1))
return outs
class ConvGRU(raft.ConvGRU):
"""Convolutional Gru unit."""
# Modified from raft.ConvGRU to accept pre-convolved contexts,
# see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/update.py#L23
def forward(self, h: Tensor, x: Tensor, context: List[Tensor]) -> Tensor: # type: ignore[override]
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx) + context[0])
r = torch.sigmoid(self.convr(hx) + context[1])
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)) + context[2])
h = (1 - z) * h + z * q
return h
class MultiLevelUpdateBlock(nn.Module):
"""The update block which contains the motion encoder and grus
It must expose a ``hidden_dims`` attribute which is the hidden dimension size of its gru blocks
"""
def __init__(self, *, motion_encoder: MotionEncoder, hidden_dims: List[int]):
super().__init__()
self.motion_encoder = motion_encoder
# The GRU input size is the previous level's hidden_dim plus the next level's hidden_dim
# if this is the first GRU, then we replace the previous level with the motion_encoder output channels
# for the last GRU, we don't add the next level's hidden_dim
gru_input_dims = []
for i in range(len(hidden_dims)):
input_dim = hidden_dims[i - 1] if i > 0 else motion_encoder.out_channels
if i < len(hidden_dims) - 1:
input_dim += hidden_dims[i + 1]
gru_input_dims.append(input_dim)
self.grus = nn.ModuleList(
[
ConvGRU(input_size=gru_input_dims[i], hidden_size=hidden_dims[i], kernel_size=3, padding=1)
# Ideally we should reverse the direction during forward to use the gru with the smallest resolution
# first however currently there is no way to reverse a ModuleList that is jit script compatible
# hence we reverse the ordering of self.grus on the constructor instead
# see: https://github.com/pytorch/pytorch/issues/31772
for i in reversed(list(range(len(hidden_dims))))
]
)
self.hidden_dims = hidden_dims
def forward(
self,
hidden_states: List[Tensor],
contexts: List[List[Tensor]],
corr_features: Tensor,
disparity: Tensor,
level_processed: List[bool],
) -> List[Tensor]:
# We call it reverse_i because it has a reversed ordering compared to hidden_states
# see self.grus on the constructor for more detail
for reverse_i, gru in enumerate(self.grus):
i = len(self.grus) - 1 - reverse_i
if level_processed[i]:
# x is the concatenation of the 2x downsampled hidden state from the larger level
# (or the motion features if no larger level exists) with the 2x upsampled hidden state
# from the smaller level (or nothing if it does not exist).
if i == 0:
features = self.motion_encoder(disparity, corr_features)
else:
# 2x downsampled features from larger hidden states
features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1)
if i < len(self.grus) - 1:
# Concat with 2x upsampled features from smaller hidden states
_, _, h, w = hidden_states[i + 1].shape
features = torch.cat(
[
features,
F.interpolate(
hidden_states[i + 1], size=(2 * h, 2 * w), mode="bilinear", align_corners=True
),
],
dim=1,
)
hidden_states[i] = gru(hidden_states[i], features, contexts[i])
# NOTE: For slow-fast gru, we don't always want to calculate delta disparity for every call on UpdateBlock
# Hence we move the delta disparity calculation to the RAFT-Stereo main forward
return hidden_states
class MaskPredictor(raft.MaskPredictor):
"""Mask predictor to be used when upsampling the predicted disparity."""
# We add out_channels compared to raft.MaskPredictor
def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25):
super(raft.MaskPredictor, self).__init__()
self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
self.conv = nn.Conv2d(hidden_size, out_channels, kernel_size=1, padding=0)
self.multiplier = multiplier
class CorrPyramid1d(nn.Module):
"""Row-wise correlation pyramid.
Create a row-wise correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder.
This correlation pyramid will later be indexed by ``CorrBlock1d`` to create the correlation features.
"""
def __init__(self, num_levels: int = 4):
super().__init__()
self.num_levels = num_levels
def forward(self, fmap1: Tensor, fmap2: Tensor) -> List[Tensor]:
"""Build the correlation pyramid from two feature maps.
The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2) on the same row.
The last dimension of the correlation volume is then average-pooled ``num_levels - 1`` times at different
resolutions to build the correlation pyramid.
"""
torch._assert(
fmap1.shape == fmap2.shape,
f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)",
)
batch_size, num_channels, h, w = fmap1.shape
fmap1 = fmap1.view(batch_size, num_channels, h, w)
fmap2 = fmap2.view(batch_size, num_channels, h, w)
corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
corr = corr.view(batch_size, h, w, 1, w)
corr_volume = corr / torch.sqrt(torch.tensor(num_channels, device=corr.device))
corr_volume = corr_volume.reshape(batch_size * h * w, 1, 1, w)
corr_pyramid = [corr_volume]
for _ in range(self.num_levels - 1):
corr_volume = F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2))
corr_pyramid.append(corr_volume)
return corr_pyramid
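# A minimal shape sketch (hypothetical sizes, not part of the original module):
#
#     corr_pyramid = CorrPyramid1d(num_levels=4)
#     fmap1 = torch.rand(1, 256, 32, 64)
#     fmap2 = torch.rand(1, 256, 32, 64)
#     pyramid = corr_pyramid(fmap1, fmap2)
#     # [tuple(v.shape) for v in pyramid]
#     # -> [(2048, 1, 1, 64), (2048, 1, 1, 32), (2048, 1, 1, 16), (2048, 1, 1, 8)]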
class CorrBlock1d(nn.Module):
"""The row-wise correlation block.
Use indexes from correlation pyramid to create correlation features.
The "indexing" of a given centroid pixel x' is done by concatenating its surrounding row neighbours
within ``radius``.
"""
def __init__(self, *, num_levels: int = 4, radius: int = 4):
super().__init__()
self.radius = radius
self.out_channels = num_levels * (2 * radius + 1)
def forward(self, centroids_coords: Tensor, corr_pyramid: List[Tensor]) -> Tensor:
"""Return correlation features by indexing from the pyramid."""
neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels
di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, device=centroids_coords.device)
di = di.view(1, 1, neighborhood_side_len, 1).to(centroids_coords.device)
batch_size, _, h, w = centroids_coords.shape # _ = 2 but we only use the first one
# We only consider 1d and take the first dim only
centroids_coords = centroids_coords[:, :1].permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 1)
indexed_pyramid = []
for corr_volume in corr_pyramid:
x0 = centroids_coords + di # end shape is (batch_size * h * w, 1, side_len, 1)
y0 = torch.zeros_like(x0)
sampling_coords = torch.cat([x0, y0], dim=-1)
indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
batch_size, h, w, -1
)
indexed_pyramid.append(indexed_corr_volume)
centroids_coords = centroids_coords / 2
corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()
expected_output_shape = (batch_size, self.out_channels, h, w)
torch._assert(
corr_features.shape == expected_output_shape,
f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}",
)
return corr_features
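# A minimal shape sketch continuing the CorrPyramid1d example above (hypothetical sizes):
#
#     corr_block = CorrBlock1d(num_levels=4, radius=4)
#     # corr_block.out_channels == 4 * (2 * 4 + 1) == 36
#     centroids = make_coords_grid(1, 32, 64)  # (1, 2, 32, 64), absolute pixel coordinates
#     corr_features = corr_block(centroids_coords=centroids, corr_pyramid=pyramid)
#     # corr_features.shape -> (1, 36, 32, 64)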
class RaftStereo(nn.Module):
def __init__(
self,
*,
feature_encoder: FeatureEncoder,
context_encoder: MultiLevelContextEncoder,
corr_pyramid: CorrPyramid1d,
corr_block: CorrBlock1d,
update_block: MultiLevelUpdateBlock,
disparity_head: nn.Module,
mask_predictor: Optional[nn.Module] = None,
slow_fast: bool = False,
):
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
Args:
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``.
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``.
It has multi-level output and each level will have 2 parts:
- one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
- one part will be used to initialize the hidden state of the recurrent unit of
the ``update_block``
corr_pyramid (CorrPyramid1d): Module to build the correlation pyramid from feature encoder output
corr_block (CorrBlock1d): The correlation block, which uses the correlation pyramid indexes
to create correlation features. It takes the coordinate of the centroid pixel and correlation pyramid
as input and returns the correlation features.
It must expose an ``out_channels`` attribute.
update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit.
It takes as input the hidden state of its recurrent unit, the context, the correlation
features, and the current predicted disparity. It outputs an updated hidden state
disparity_head (nn.Module): The disparity head block, which converts the hidden state into changes in disparity.
mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
If ``None`` (default), the flow is upsampled using interpolation.
slow_fast (bool): A boolean that specifies whether to use the slow-fast GRU. See the RAFT-Stereo paper,
section 3.4, for more detail.
"""
super().__init__()
_log_api_usage_once(self)
# This indicates that the disparity output will only have 1 channel (representing the horizontal axis).
# We need this because some stereo matching models like CREStereo might have 2 channels on the output
self.output_channels = 1
self.feature_encoder = feature_encoder
self.context_encoder = context_encoder
self.base_downsampling_ratio = feature_encoder.base_downsampling_ratio
self.num_level = self.context_encoder.num_level
self.corr_pyramid = corr_pyramid
self.corr_block = corr_block
self.update_block = update_block
self.disparity_head = disparity_head
self.mask_predictor = mask_predictor
hidden_dims = self.update_block.hidden_dims
# Follow the original implementation to do pre convolution on the context
# See: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L32
self.context_convs = nn.ModuleList(
[nn.Conv2d(hidden_dims[i], hidden_dims[i] * 3, kernel_size=3, padding=1) for i in range(self.num_level)]
)
self.slow_fast = slow_fast
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12
) -> List[Tensor]:
"""
Return disparity predictions on every iteration as a list of Tensor.
Args:
left_image (Tensor): The input left image with layout B, C, H, W
right_image (Tensor): The input right image with layout B, C, H, W
flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
num_iters (int): Number of update block iterations on the largest resolution. Default: 12
"""
batch_size, _, h, w = left_image.shape
torch._assert(
(h, w) == right_image.shape[-2:],
f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}",
)
torch._assert(
(h % self.base_downsampling_ratio == 0 and w % self.base_downsampling_ratio == 0),
f"input image H and W should be divisible by {self.base_downsampling_ratio}, instead got H={h} and W={w}",
)
fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0))
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
torch._assert(
fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio),
f"The feature encoder should downsample H and W by {self.base_downsampling_ratio}",
)
corr_pyramid = self.corr_pyramid(fmap1, fmap2)
# Multi level contexts
context_outs = self.context_encoder(left_image)
hidden_dims = self.update_block.hidden_dims
context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))]
hidden_states: List[Tensor] = []
contexts: List[List[Tensor]] = []
for i, context_conv in enumerate(self.context_convs):
# As in the original paper, the actual output of the context encoder is split in 2 parts:
# - one part is used to initialize the hidden state of the recurrent units of the update block
# - the rest is the "actual" context.
hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1)
hidden_states.append(torch.tanh(hidden_state))
contexts.append(
# mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with
# `List[int]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by
# JIT and thus we have to keep the wrong annotation here and silence mypy.
torch.split( # type: ignore[arg-type]
context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1
)
)
_, Cf, Hf, Wf = fmap1.shape
coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
# We use flow_init for cascade inference
if flow_init is not None:
coords1 = coords1 + flow_init
disparity_predictions = []
for _ in range(num_iters):
coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid)
disparity = coords1 - coords0
if self.slow_fast:
# Using slow_fast GRU (see paper section 3.4). The lower resolutions are processed more often
for i in range(1, self.num_level):
# We only process the i smallest (coarsest) levels
level_processed = [False] * (self.num_level - i) + [True] * i
hidden_states = self.update_block(
hidden_states, contexts, corr_features, disparity, level_processed=level_processed
)
hidden_states = self.update_block(
hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level
)
# Take the largest hidden_state to get the disparity
hidden_state = hidden_states[0]
delta_disparity = self.disparity_head(hidden_state)
# in stereo mode, project the disparity update onto the epipolar line (zero out the vertical component)
delta_disparity[:, 1] = 0.0
coords1 = coords1 + delta_disparity
up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
upsampled_disparity = upsample_flow(
(coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio
)
disparity_predictions.append(upsampled_disparity[:, :1])
return disparity_predictions
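# As a worked example of the slow-fast schedule (using the realtime configuration below,
# i.e. two GRU levels and slow_fast=True): each iteration first runs the update block with
# level_processed=[False, True] (coarsest level only) and then once with
# level_processed=[True, True], before the disparity head is applied on the finest hidden state.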
def _raft_stereo(
*,
weights: Optional[WeightsEnum],
progress: bool,
shared_encoder_weight: bool,
# Feature encoder
feature_encoder_layers: Tuple[int, int, int, int, int],
feature_encoder_strides: Tuple[int, int, int, int],
feature_encoder_block: Callable[..., nn.Module],
# Context encoder
context_encoder_layers: Tuple[int, int, int, int, int],
context_encoder_strides: Tuple[int, int, int, int],
# if the `out_with_blocks` param of the context_encoder is True, then
# the particular output on that level position will have additional `context_encoder_block` layer
context_encoder_out_with_blocks: List[bool],
context_encoder_block: Callable[..., nn.Module],
# Correlation block
corr_num_levels: int,
corr_radius: int,
# Motion encoder
motion_encoder_corr_layers: Tuple[int, int],
motion_encoder_flow_layers: Tuple[int, int],
motion_encoder_out_channels: int,
# Update block
update_block_hidden_dims: List[int],
# Flow Head
flow_head_hidden_size: int,
# Mask predictor
mask_predictor_hidden_size: int,
use_mask_predictor: bool,
slow_fast: bool,
**kwargs,
):
if len(context_encoder_out_with_blocks) != len(update_block_hidden_dims):
raise ValueError(
"Length of context_encoder_out_with_blocks and update_block_hidden_dims must be the same"
+ "because both of them represent the number of GRUs level"
)
if shared_encoder_weight:
if (
feature_encoder_layers[:-1] != context_encoder_layers[:-1]
or feature_encoder_strides != context_encoder_strides
):
raise ValueError(
"If shared_encoder_weight is True, then the feature_encoder_layers[:-1]"
+ " and feature_encoder_strides must be the same with context_encoder_layers[:-1] and context_encoder_strides!"
)
base_encoder = kwargs.pop("base_encoder", None) or BaseEncoder(
block=context_encoder_block,
layers=context_encoder_layers[:-1],
strides=context_encoder_strides,
norm_layer=nn.BatchNorm2d,
)
feature_base_encoder = base_encoder
context_base_encoder = base_encoder
else:
feature_base_encoder = BaseEncoder(
block=feature_encoder_block,
layers=feature_encoder_layers[:-1],
strides=feature_encoder_strides,
norm_layer=nn.InstanceNorm2d,
)
context_base_encoder = BaseEncoder(
block=context_encoder_block,
layers=context_encoder_layers[:-1],
strides=context_encoder_strides,
norm_layer=nn.BatchNorm2d,
)
feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
feature_base_encoder,
output_dim=feature_encoder_layers[-1],
shared_base=shared_encoder_weight,
block=feature_encoder_block,
)
context_encoder = kwargs.pop("context_encoder", None) or MultiLevelContextEncoder(
context_base_encoder,
out_with_blocks=context_encoder_out_with_blocks,
output_dim=context_encoder_layers[-1],
block=context_encoder_block,
)
feature_downsampling_ratio = feature_encoder.base_downsampling_ratio
corr_pyramid = kwargs.pop("corr_pyramid", None) or CorrPyramid1d(num_levels=corr_num_levels)
corr_block = kwargs.pop("corr_block", None) or CorrBlock1d(num_levels=corr_num_levels, radius=corr_radius)
motion_encoder = kwargs.pop("motion_encoder", None) or MotionEncoder(
in_channels_corr=corr_block.out_channels,
corr_layers=motion_encoder_corr_layers,
flow_layers=motion_encoder_flow_layers,
out_channels=motion_encoder_out_channels,
)
update_block = kwargs.pop("update_block", None) or MultiLevelUpdateBlock(
motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims
)
# We use the largest scale hidden_dims of update_block to get the predicted disparity
disparity_head = kwargs.pop("disparity_head", None) or FlowHead(
in_channels=update_block_hidden_dims[0],
hidden_size=flow_head_hidden_size,
)
mask_predictor = kwargs.pop("mask_predictor", None)
if use_mask_predictor:
mask_predictor = MaskPredictor(
in_channels=update_block.hidden_dims[0],
hidden_size=mask_predictor_hidden_size,
out_channels=9 * feature_downsampling_ratio * feature_downsampling_ratio,
)
else:
mask_predictor = None
model = RaftStereo(
feature_encoder=feature_encoder,
context_encoder=context_encoder,
corr_pyramid=corr_pyramid,
corr_block=corr_block,
update_block=update_block,
disparity_head=disparity_head,
mask_predictor=mask_predictor,
slow_fast=slow_fast,
**kwargs, # not really needed, all params should be consumed by now
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
class Raft_Stereo_Realtime_Weights(WeightsEnum):
SCENEFLOW_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 8077152,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
"Kitty2015": {
"3px": 0.9409,
}
},
},
)
DEFAULT = SCENEFLOW_V1
class Raft_Stereo_Base_Weights(WeightsEnum):
SCENEFLOW_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 11116176,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
# Using standard metrics for each dataset
"Kitty2015": {
# Ratio of pixels with difference less than 3px from ground truth
"3px": 0.9426,
},
# For middlebury, ratio of pixels with difference less than 2px from ground truth
# on full, half, and quarter image resolution
"Middlebury2014-val-full": {
"2px": 0.8167,
},
"Middlebury2014-val-half": {
"2px": 0.8741,
},
"Middlebury2014-val-quarter": {
"2px": 0.9064,
},
"ETH3D-val": {
# Ratio of pixels with difference less than 1px from ground truth
"1px": 0.9672,
},
},
},
)
MIDDLEBURY_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 11116176,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
"Middlebury-test": {
"mae": 1.27,
"1px": 0.9063,
"2px": 0.9526,
"5px": 0.9725,
}
},
},
)
ETH3D_V1 = Weights(
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth",
transforms=partial(StereoMatching, resize_size=(224, 224)),
meta={
"num_params": 11116176,
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
"_metrics": {
# Following metrics from paper: https://arxiv.org/abs/2109.07547
"ETH3D-test": {
"mae": 0.18,
"1px": 0.9756,
"2px": 0.9956,
}
},
},
)
DEFAULT = MIDDLEBURY_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_realtime(
*, weights: Optional[Raft_Stereo_Realtime_Weights] = None, progress=True, **kwargs
) -> RaftStereo:
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
This is the realtime variant of the RAFT-Stereo model, described in section 4.7 of the paper.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights
:members:
"""
weights = Raft_Stereo_Realtime_Weights.verify(weights)
return _raft_stereo(
weights=weights,
progress=progress,
shared_encoder_weight=True,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(2, 1, 2, 2),
feature_encoder_block=ResidualBlock,
# Context encoder
context_encoder_layers=(64, 64, 96, 128, 256),
context_encoder_strides=(2, 1, 2, 2),
context_encoder_out_with_blocks=[True, True],
context_encoder_block=ResidualBlock,
# Correlation block
corr_num_levels=4,
corr_radius=4,
# Motion encoder
motion_encoder_corr_layers=(64, 64),
motion_encoder_flow_layers=(64, 64),
motion_encoder_out_channels=128,
# Update block
update_block_hidden_dims=[128, 128],
# Flow head
flow_head_hidden_size=256,
# Mask predictor
mask_predictor_hidden_size=256,
use_mask_predictor=True,
slow_fast=True,
**kwargs,
)
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_base(*, weights: Optional[Raft_Stereo_Base_Weights] = None, progress=True, **kwargs) -> RaftStereo:
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights
:members:
"""
weights = Raft_Stereo_Base_Weights.verify(weights)
return _raft_stereo(
weights=weights,
progress=progress,
shared_encoder_weight=False,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(1, 1, 2, 2),
feature_encoder_block=ResidualBlock,
# Context encoder
context_encoder_layers=(64, 64, 96, 128, 256),
context_encoder_strides=(1, 1, 2, 2),
context_encoder_out_with_blocks=[True, True, False],
context_encoder_block=ResidualBlock,
# Correlation block
corr_num_levels=4,
corr_radius=4,
# Motion encoder
motion_encoder_corr_layers=(64, 64),
motion_encoder_flow_layers=(64, 64),
motion_encoder_out_channels=128,
# Update block
update_block_hidden_dims=[128, 128, 128],
# Flow head
flow_head_hidden_size=256,
# Mask predictor
mask_predictor_hidden_size=256,
use_mask_predictor=True,
slow_fast=False,
**kwargs,
)
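# A minimal inference sketch (not part of the original module); the input sizes are
# hypothetical but H and W must be divisible by the model's base downsampling ratio:
#
#     model = raft_stereo_base(weights=Raft_Stereo_Base_Weights.DEFAULT).eval()
#     left = torch.rand(1, 3, 256, 512)
#     right = torch.rand(1, 3, 256, 512)
#     with torch.no_grad():
#         predictions = model(left, right, num_iters=12)
#     disparity = predictions[-1]  # (1, 1, 256, 512); the last iteration is the final estimate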
from ._presets import StereoMatching # usort: skip
from ._augment import SimpleCopyPaste
from ._geometry import FixedSizeCrop
from ._misc import PermuteDimensions, TransposeDimensions
from ._type_conversion import LabelToOneHot
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.ops import masks_to_boxes
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _check_interpolation
class SimpleCopyPaste(Transform):
def __init__(
self,
blending: bool = True,
resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.resize_interpolation = _check_interpolation(resize_interpolation)
self.blending = blending
self.antialias = antialias
def _copy_paste(
self,
image: Union[torch.Tensor, tv_tensors.Image],
target: Dict[str, Any],
paste_image: Union[torch.Tensor, tv_tensors.Image],
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[torch.Tensor, Dict[str, Any]]:
paste_masks = tv_tensors.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
paste_boxes = tv_tensors.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
paste_labels = tv_tensors.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])
masks = target["masks"]
# We resize source and paste data if they have different sizes
# This differs from the TF implementation: we introduce it here because the original
# algorithm works on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = cast(List[int], image.shape[-2:])
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
paste_masks = F.resize(paste_masks, size=size1)
paste_boxes = F.resize(paste_boxes, size=size1)
paste_alpha_mask = paste_masks.sum(dim=0) > 0
if blending:
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
# Copy-paste images:
image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
# Copy-paste masks:
masks = masks * inverse_paste_alpha_mask
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]
# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}
out_target["masks"] = torch.cat([masks, paste_masks])
# Copy-paste boxes and labels
bbox_format = target["boxes"].format
xyxy_boxes = masks_to_boxes(masks)
# masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
# we need to add +1 to x2y2.
# There is a similar +1 in other reference implementations:
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])
# Check for degenerated boxes and remove them
boxes = F.convert_bounding_box_format(
out_target["boxes"], old_format=bbox_format, new_format=tv_tensors.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)
out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]
return image, out_target
def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[Union[torch.Tensor, tv_tensors.Image]], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBoxes], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, tv_tensors.Image) or is_pure_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image(obj))
elif isinstance(obj, tv_tensors.BoundingBoxes):
bboxes.append(obj)
elif isinstance(obj, tv_tensors.Mask):
masks.append(obj)
elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
labels.append(obj)
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxes, Masks and Labels or OneHotLabels."
)
targets = []
for bbox, mask, label in zip(bboxes, masks, labels):
targets.append({"boxes": bbox, "masks": mask, "labels": label})
return images, targets
def _insert_outputs(
self,
flat_sample: List[Any],
output_images: List[torch.Tensor],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, tv_tensors.Image):
flat_sample[i] = tv_tensors.wrap(output_images[c0], like=obj)
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_pil_image(output_images[c0])
c0 += 1
elif is_pure_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, tv_tensors.BoundingBoxes):
flat_sample[i] = tv_tensors.wrap(output_targets[c1]["boxes"], like=obj)
c1 += 1
elif isinstance(obj, tv_tensors.Mask):
flat_sample[i] = tv_tensors.wrap(output_targets[c2]["masks"], like=obj)
c2 += 1
elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
flat_sample[i] = tv_tensors.wrap(output_targets[c3]["labels"], like=obj)
c3 += 1
def forward(self, *inputs: Any) -> Any:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
images, targets = self._extract_image_targets(flat_inputs)
# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
# paste_images = [t2, t3, ..., tN, t1]
# FYI: in TF they mix data on the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]
output_images, output_targets = [], []
for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
# Random paste targets selection:
num_masks = len(paste_target["masks"])
if num_masks < 1:
# Such a degenerate case with num_masks=0 can happen with LSJ
# Let's just return (image, target)
output_image, output_target = image, target
else:
random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
random_selection = torch.unique(random_selection)
output_image, output_target = self._copy_paste(
image,
target,
paste_image,
paste_target,
random_selection=random_selection,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
antialias=self.antialias,
)
output_images.append(output_image)
output_targets.append(output_target)
# Insert updated images and targets into input flat_sample
self._insert_outputs(flat_inputs, output_images, output_targets)
return tree_unflatten(flat_inputs, spec)
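# A minimal usage sketch (hypothetical sample layout, not part of the original module);
# each target is a dict holding tv_tensors.BoundingBoxes ("boxes"), tv_tensors.Mask ("masks")
# and a prototype Label ("labels"):
#
#     copy_paste = SimpleCopyPaste(blending=True)
#     batch = [(image1, target1), (image2, target2)]
#     augmented_batch = copy_paste(batch)  # same structure, with objects pasted across samples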
from typing import Any, Dict, List, Optional, Sequence, Type, Union
import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.prototype.tv_tensors import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import (
_FillType,
_get_fill,
_setup_fill_arg,
_setup_size,
get_bounding_boxes,
has_any,
is_pure_tensor,
query_size,
)
class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not has_any(
flat_inputs,
PIL.Image.Image,
tv_tensors.Image,
is_pure_tensor,
tv_tensors.Video,
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
)
if has_any(flat_inputs, tv_tensors.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel):
raise TypeError(
f"If a BoundingBoxes is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_size(flat_inputs)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
needs_crop = new_height != height or new_width != width
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)
r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)
bounding_boxes: Optional[torch.Tensor]
try:
bounding_boxes = get_bounding_boxes(flat_inputs)
except ValueError:
bounding_boxes = None
if needs_crop and bounding_boxes is not None:
format = bounding_boxes.format
bounding_boxes, canvas_size = F.crop_bounding_boxes(
bounding_boxes.as_subclass(torch.Tensor),
format=format,
top=top,
left=left,
height=new_height,
width=new_width,
)
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
height_and_width = F.convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYWH
)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
is_valid = None
pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
needs_pad = pad_bottom != 0 or pad_right != 0
return dict(
needs_crop=needs_crop,
top=top,
left=left,
height=new_height,
width=new_width,
is_valid=is_valid,
padding=[0, 0, pad_right, pad_bottom],
needs_pad=needs_pad,
)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_crop"]:
inpt = self._call_kernel(
F.crop,
inpt,
top=params["top"],
left=params["left"],
height=params["height"],
width=params["width"],
)
if params["is_valid"] is not None:
if isinstance(inpt, (Label, OneHotLabel, tv_tensors.Mask)):
inpt = tv_tensors.wrap(inpt[params["is_valid"]], like=inpt)
elif isinstance(inpt, tv_tensors.BoundingBoxes):
inpt = tv_tensors.wrap(
F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
like=inpt,
)
if params["needs_pad"]:
fill = _get_fill(self._fill, type(inpt))
inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
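# A minimal usage sketch (hypothetical values, not part of the original module):
#
#     crop = FixedSizeCrop(size=(512, 512), fill=0)
#     image = tv_tensors.Image(torch.rand(3, 600, 800))
#     boxes = tv_tensors.BoundingBoxes([[10, 10, 100, 100]], format="XYXY", canvas_size=(600, 800))
#     labels = Label([0])
#     image, boxes, labels = crop(image, boxes, labels)
#     # the output spatial size is always (512, 512): larger inputs are randomly cropped,
#     # smaller ones are padded with ``fill``; boxes that become degenerate are dropped.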
import functools
import warnings
from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2._utils import is_pure_tensor
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird-looking construct only exists because lambdas cannot be pickled.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
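# For example, _get_defaultdict((0, 1)) behaves like a dict that returns (0, 1) for any
# missing key, while remaining picklable (unlike a defaultdict built from a lambda).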
class PermuteDimensions(Transform):
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dims = dims
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
return inpt.as_subclass(torch.Tensor)
return inpt.permute(*dims)
class TransposeDimensions(Transform):
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dims = dims
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
return inpt.as_subclass(torch.Tensor)
return inpt.transpose(*dims)
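# A minimal usage sketch (hypothetical values, not part of the original module):
#
#     permute = PermuteDimensions(dims=(0, 2, 3, 1))  # applied to every transformed type
#     video = tv_tensors.Video(torch.rand(4, 3, 32, 32))
#     permute(video).shape  # torch.Size([4, 32, 32, 3])
#
#     # per-type dims can be passed as a dict; a value of None unwraps that type to a
#     # plain tensor without transposing it
#     transpose = TransposeDimensions(dims={tv_tensors.Image: (-1, -2), tv_tensors.Video: None})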
"""
This file is part of the private API. Please do not use these classes directly, as they will be modified in
future versions without warning. The classes should only be accessed via the transforms argument of Weights.
"""
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
from torch import Tensor
from torchvision.transforms.v2 import functional as F, InterpolationMode
from torchvision.transforms.v2.functional._geometry import _check_interpolation
__all__ = ["StereoMatching"]
class StereoMatching(torch.nn.Module):
def __init__(
self,
*,
use_gray_scale: bool = False,
resize_size: Optional[Tuple[int, ...]],
mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
std: Tuple[float, ...] = (0.5, 0.5, 0.5),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
# pacify mypy
self.resize_size: Union[None, List]
if resize_size is not None:
self.resize_size = list(resize_size)
else:
self.resize_size = None
self.mean = list(mean)
self.std = list(std)
self.interpolation = _check_interpolation(interpolation)
self.use_gray_scale = use_gray_scale
def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
def _process_image(img: PIL.Image.Image) -> Tensor:
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
if self.resize_size is not None:
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the stereo models with antialias=True?
img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=False)
if self.use_gray_scale is True:
img = F.rgb_to_grayscale(img)
img = F.convert_image_dtype(img, torch.float)
img = F.normalize(img, mean=self.mean, std=self.std)
img = img.contiguous()
return img
left_image = _process_image(left_image)
right_image = _process_image(right_image)
return left_image, right_image
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
)
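# A minimal usage sketch (hypothetical sizes, not part of the original module):
#
#     preset = StereoMatching(resize_size=(224, 224))
#     left = torch.randint(0, 256, (3, 400, 880), dtype=torch.uint8)
#     right = torch.randint(0, 256, (3, 400, 880), dtype=torch.uint8)
#     left, right = preset(left, right)
#     # both outputs are float tensors of shape (3, 224, 224), normalized to roughly [-1, 1]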
from typing import Any, Dict
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import Transform
class LabelToOneHot(Transform):
_transformed_types = (proto_tv_tensors.Label,)
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel:
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return proto_tv_tensors.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str:
if self.num_categories == -1:
return ""
return f"num_categories={self.num_categories}"
from ._label import Label, OneHotLabel
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch
from torch.utils._pytree import tree_map
from torchvision.tv_tensors._tv_tensor import TVTensor
L = TypeVar("L", bound="_LabelBase")
class _LabelBase(TVTensor):
categories: Optional[Sequence[str]]
@classmethod
def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
label_base = tensor.as_subclass(cls)
label_base.categories = categories
return label_base
def __new__(
cls: Type[L],
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
) -> L:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor, categories=categories)
@classmethod
def from_category(
cls: Type[L],
category: str,
*,
categories: Sequence[str],
**kwargs: Any,
) -> L:
return cls(categories.index(category), categories=categories, **kwargs)
class Label(_LabelBase):
def to_categories(self) -> Any:
if self.categories is None:
raise RuntimeError("Label does not have categories")
return tree_map(lambda idx: self.categories[idx], self.tolist()) # type: ignore[index]
class OneHotLabel(_LabelBase):
def __new__(
cls,
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> OneHotLabel:
one_hot_label = super().__new__(
cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
)
if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError("The number of categories does not match the size of the last dimension of the data.")
return one_hot_label
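# A minimal usage sketch (hypothetical categories, not part of the original module):
#
#     label = Label.from_category("dog", categories=["cat", "dog"])
#     label.to_categories()                           # 'dog'
#     OneHotLabel([0, 1], categories=["cat", "dog"])  # last dimension must match len(categories)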
import collections.abc
import difflib
import io
import mmap
import platform
from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union
import numpy as np
import torch
from torchvision._utils import sequence_to_str
__all__ = [
"add_suggestion",
"fromfile",
"ReadOnlyTensorBuffer",
]
def add_suggestion(
msg: str,
*,
word: str,
possibilities: Collection[str],
close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
alternative_hint: Callable[
[Sequence[str]], str
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
if not isinstance(possibilities, collections.abc.Sequence):
possibilities = sorted(possibilities)
suggestions = difflib.get_close_matches(word, possibilities, 1)
hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
if not hint:
return msg
return f"{msg.strip()} {hint}"
D = TypeVar("D")
def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
return bytearray(file.read(-1 if count == -1 else count * item_size))
def fromfile(
file: BinaryIO,
*,
dtype: torch.dtype,
byte_order: str,
count: int = -1,
) -> torch.Tensor:
"""Construct a tensor from a binary file.
.. note::
This function is similar to :func:`numpy.fromfile` with two notable differences:
1. This function only accepts an open binary file, but not a path to it.
2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``s do not support that
concept.
.. note::
If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
Args:
file (IO): Open binary file.
dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
byte_order (str): Byte order of the data. Can be "little" or "big" endian.
count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
"""
byte_order = "<" if byte_order == "little" else ">"
char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
np_dtype = byte_order + char + str(item_size)
buffer: Union[memoryview, bytearray]
if platform.system() != "Windows":
# PyTorch does not support tensors with underlying read-only memory. In case
# - the file has a .fileno(),
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
# - the file is seekable
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
# to a mutable location afterwards.
try:
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
except (AttributeError, PermissionError, io.UnsupportedOperation):
buffer = _read_mutable_buffer_fallback(file, count, item_size)
else:
# On Windows, just trying to call mmap.mmap() on a file that does not support it may corrupt the internal state
# so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
buffer = _read_mutable_buffer_fallback(file, count, item_size)
# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
# successive .astype() call.
return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
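# A minimal usage sketch (hypothetical file, not part of the original module):
#
#     with open("data.bin", "r+b") as file:  # update mode enables the mmap fast path
#         tensor = fromfile(file, dtype=torch.float32, byte_order="little")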
class ReadOnlyTensorBuffer:
def __init__(self, tensor: torch.Tensor) -> None:
self._memory = memoryview(tensor.numpy()) # type: ignore[arg-type]
self._cursor: int = 0
def tell(self) -> int:
return self._cursor
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_SET:
self._cursor = offset
elif whence == io.SEEK_CUR:
self._cursor += offset
elif whence == io.SEEK_END:
self._cursor = len(self._memory) + offset
else:
raise ValueError(
f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, "
f"but got {repr(whence)} instead"
)
return self.tell()
def read(self, size: int = -1) -> bytes:
cursor = self.tell()
offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
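# A minimal usage sketch (not part of the original module):
#
#     buffer = ReadOnlyTensorBuffer(torch.arange(10, dtype=torch.uint8))
#     buffer.read(4)  # first four bytes
#     buffer.tell()   # 4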