Unverified Commit 0fcfaa13 authored by Ponku's avatar Ponku Committed by GitHub
Browse files

Add stereo preset transforms (#6549)



* Added transforms for Stereo Matching

* changed implicit Y scaling to 0.

* Adressed some comments

* addressed type hint

* Added interpolation random interpolation strategy

* Aligned crop get params

* fixed bug in RandomErase

* Adressed scaling and typos

* Adressed occlusion typo

* Changed parameter order in F.erase

* fixed random erase

* Added inference preset transform for stereo matching

* added contiguous reshape to output tensors

* Adressed comments

* Modified the transform preset to use Tuple[int, int]

* adressed NITs

* added grayscale transform, align resize -> mask

* changed max disparity default behaviour

* added fixed resize, changed masking in sparse flow masking

* update to align with argparse

* changed default mask in asymetric pairs

* moved grayscale order

* changed grayscale api to accept to tensor variant

* mypy fix

* changed resize specs

* adressed nits

* added type hints

* mypy fix

* mypy fix

* mypy fix
Co-authored-by: default avatarJoao Gomes <jdsgomes@fb.com>
parent 2c1022e3
from typing import Optional, Tuple, Union
import torch
import transforms as T
class StereoMatchingEvalPreset(torch.nn.Module):
def __init__(
self,
mean: float = 0.5,
std: float = 0.5,
resize_size: Optional[Tuple[int, ...]] = None,
max_disparity: Optional[float] = None,
interpolation_type: str = "bilinear",
use_grayscale: bool = False,
) -> None:
super().__init__()
transforms = [
T.ToTensor(),
T.ConvertImageDtype(torch.float32),
]
if use_grayscale:
transforms.append(T.ConvertToGrayscale())
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))
transforms.extend(
[
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity=max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
def forward(self, images, disparities, masks):
return self.transforms(images, disparities, masks)
class StereoMatchingTrainPreset(torch.nn.Module):
def __init__(
self,
*,
resize_size: Optional[Tuple[int, ...]],
resize_interpolation_type: str = "bilinear",
# RandomResizeAndCrop params
crop_size: Tuple[int, int],
rescale_prob: float = 1.0,
scaling_type: str = "exponential",
scale_range: Tuple[float, float] = (-0.2, 0.5),
scale_interpolation_type: str = "bilinear",
# convert to grayscale
use_grayscale: bool = False,
# normalization params
mean: float = 0.5,
std: float = 0.5,
# processing device
gpu_transforms: bool = False,
# masking
max_disparity: Optional[int] = 256,
# SpatialShift params
spatial_shift_prob: float = 0.5,
spatial_shift_max_angle: float = 0.5,
spatial_shift_max_displacement: float = 0.5,
spatial_shift_interpolation_type: str = "bilinear",
# AssymetricColorJitter
gamma_range: Tuple[float, float] = (0.8, 1.2),
brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
saturation: Union[int, Tuple[int, int]] = 0.0,
hue: Union[int, Tuple[int, int]] = 0.0,
asymmetric_jitter_prob: float = 1.0,
# RandomHorizontalFlip
horizontal_flip_prob: float = 0.5,
# RandomOcclusion
occlusion_prob: float = 0.0,
occlusion_px_range: Tuple[int, int] = (50, 100),
# RandomErase
erase_prob: float = 0.0,
erase_px_range: Tuple[int, int] = (50, 100),
erase_num_repeats: int = 1,
) -> None:
if scaling_type not in ["linear", "exponential"]:
raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")
super().__init__()
transforms = [T.ToTensor()]
# when fixing size across multiple datasets, we ensure
# that the same size is used for all datasets when cropping
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))
if gpu_transforms:
transforms.append(T.ToGPU())
# color handling
color_transforms = [
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
]
if use_grayscale:
color_transforms.append(T.ConvertToGrayscale())
transforms.extend(color_transforms)
transforms.extend(
[
T.RandomSpatialShift(
p=spatial_shift_prob,
max_angle=spatial_shift_max_angle,
max_px_shift=spatial_shift_max_displacement,
interpolation_type=spatial_shift_interpolation_type,
),
T.ConvertImageDtype(torch.float32),
T.RandomRescaleAndCrop(
crop_size=crop_size,
scale_range=scale_range,
rescale_prob=rescale_prob,
scaling_type=scaling_type,
interpolation_type=scale_interpolation_type,
),
T.RandomHorizontalFlip(horizontal_flip_prob),
# occlusion after flip, otherwise we're occluding the reference image
T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
def forward(self, images, disparties, mask):
return self.transforms(images, disparties, mask)
import random
from typing import Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import Tensor
T_FLOW = Union[Tensor, np.ndarray, None]
T_MASK = Union[Tensor, np.ndarray, None]
T_STEREO_TENSOR = Tuple[Tensor, Tensor]
T_COLOR_AUG_PARAM = Union[float, Tuple[float, float]]
def rand_float_range(size: Sequence[int], low: float, high: float) -> Tensor:
return (low - high) * torch.rand(size) + high
class InterpolationStrategy:
_valid_modes: List[str] = ["mixed", "bicubic", "bilinear"]
def __init__(self, mode: str = "mixed") -> None:
if mode not in self._valid_modes:
raise ValueError(f"Invalid interpolation mode: {mode}. Valid modes are: {self._valid_modes}")
if mode == "mixed":
self.strategies = [F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC]
elif mode == "bicubic":
self.strategies = [F.InterpolationMode.BICUBIC]
elif mode == "bilinear":
self.strategies = [F.InterpolationMode.BILINEAR]
def __call__(self) -> F.InterpolationMode:
return random.choice(self.strategies)
@classmethod
def is_valid(mode: str) -> bool:
return mode in InterpolationStrategy._valid_modes
@property
def valid_modes() -> List[str]:
return InterpolationStrategy._valid_modes
class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, images: T_STEREO_TENSOR, disparities: T_FLOW, masks: T_MASK):
if images[0].shape != images[1].shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = images[0].shape[-2:]
if disparities[0] is not None and disparities[0].shape != (1, h, w):
raise ValueError(f"disparities[0].shape should be (1, {h}, {w}) instead of {disparities[0].shape}")
if masks[0] is not None:
if masks[0].shape != (h, w):
raise ValueError(f"masks[0].shape should be ({h}, {w}) instead of {masks[0].shape}")
if masks[0].dtype != torch.bool:
raise TypeError(f"masks[0] should be of dtype torch.bool instead of {masks[0].dtype}")
return images, disparities, masks
class ConvertToGrayscale(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
images: Tuple[PIL.Image.Image, PIL.Image.Image],
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.rgb_to_grayscale(images[0], num_output_channels=3)
img_right = F.rgb_to_grayscale(images[1], num_output_channels=3)
return (img_left, img_right), disparities, masks
class MakeValidDisparityMask(torch.nn.Module):
def __init__(self, max_disparity: Optional[int] = 256) -> None:
super().__init__()
self.max_disparity = max_disparity
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
valid_masks = tuple(
torch.ones(images[idx].shape[-2:], dtype=torch.bool, device=images[idx].device) if mask is None else mask
for idx, mask in enumerate(masks)
)
valid_masks = tuple(
torch.logical_and(mask, disparity > 0).squeeze(0) if disparity is not None else mask
for mask, disparity in zip(valid_masks, disparities)
)
if self.max_disparity is not None:
valid_masks = tuple(
torch.logical_and(mask, disparity < self.max_disparity).squeeze(0) if disparity is not None else mask
for mask, disparity in zip(valid_masks, disparities)
)
return images, disparities, valid_masks
class ToGPU(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
dev_images = tuple(image.cuda() for image in images)
dev_disparities = tuple(map(lambda x: x.cuda() if x is not None else None, disparities))
dev_masks = tuple(map(lambda x: x.cuda() if x is not None else None, masks))
return dev_images, dev_disparities, dev_masks
class ConvertImageDtype(torch.nn.Module):
def __init__(self, dtype: torch.dtype):
super().__init__()
self.dtype = dtype
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.convert_image_dtype(images[0], dtype=self.dtype)
img_right = F.convert_image_dtype(images[1], dtype=self.dtype)
img_left = img_left.contiguous()
img_right = img_right.contiguous()
return (img_left, img_right), disparities, masks
class Normalize(torch.nn.Module):
def __init__(self, mean: List[float], std: List[float]) -> None:
super().__init__()
self.mean = mean
self.std = std
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.normalize(images[0], mean=self.mean, std=self.std)
img_right = F.normalize(images[1], mean=self.mean, std=self.std)
img_left = img_left.contiguous()
img_right = img_right.contiguous()
return (img_left, img_right), disparities, masks
class ToTensor(torch.nn.Module):
def forward(
self,
images: Tuple[PIL.Image.Image, PIL.Image.Image],
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
if images[0] is None:
raise ValueError("img_left is None")
if images[1] is None:
raise ValueError("img_right is None")
img_left = F.pil_to_tensor(images[0])
img_right = F.pil_to_tensor(images[1])
disparity_tensors = ()
mask_tensors = ()
for idx in range(2):
disparity_tensors += (torch.from_numpy(disparities[idx]),) if disparities[idx] is not None else (None,)
mask_tensors += (torch.from_numpy(masks[idx]),) if masks[idx] is not None else (None,)
return (img_left, img_right), disparity_tensors, mask_tensors
class AsymmetricColorJitter(T.ColorJitter):
# p determines the probability of doing asymmetric vs symmetric color jittering
def __init__(
self,
brightness: T_COLOR_AUG_PARAM = 0,
contrast: T_COLOR_AUG_PARAM = 0,
saturation: T_COLOR_AUG_PARAM = 0,
hue: T_COLOR_AUG_PARAM = 0,
p: float = 0.2,
):
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
if torch.rand(1) < self.p:
# asymmetric: different transform for img1 and img2
img_left = super().forward(images[0])
img_right = super().forward(images[1])
else:
# symmetric: same transform for img1 and img2
batch = torch.stack(images)
batch = super().forward(batch)
img_left, img_right = batch[0], batch[1]
return (img_left, img_right), disparities, masks
class AsymetricGammaAdjust(torch.nn.Module):
def __init__(self, p: float, gamma_range: Tuple[float, float], gain: float = 1) -> None:
super().__init__()
self.gamma_range = gamma_range
self.gain = gain
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
gamma = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()
if torch.rand(1) < self.p:
# asymmetric: different transform for img1 and img2
img_left = F.adjust_gamma(images[0], gamma, gain=self.gain)
img_right = F.adjust_gamma(images[1], gamma, gain=self.gain)
else:
# symmetric: same transform for img1 and img2
batch = torch.stack(images)
batch = F.adjust_gamma(batch, gamma, gain=self.gain)
img_left, img_right = batch[0], batch[1]
return (img_left, img_right), disparities, masks
class RandomErase(torch.nn.Module):
# Produces multiple symetric random erasures
# these can be viewed as occlusions present in both camera views.
# Similarly to Optical Flow occlusion prediction tasks, we mask these pixels in the disparity map
def __init__(
self,
p: float = 0.5,
erase_px_range: Tuple[int, int] = (50, 100),
value: Union[Tensor, float] = 0,
inplace: bool = False,
max_erase: int = 2,
):
super().__init__()
self.min_px_erase = erase_px_range[0]
self.max_px_erase = erase_px_range[1]
if self.max_px_erase < 0:
raise ValueError("erase_px_range[1] should be equal or greater than 0")
if self.min_px_erase < 0:
raise ValueError("erase_px_range[0] should be equal or greater than 0")
if self.min_px_erase > self.max_px_erase:
raise ValueError("erase_prx_range[0] should be equal or lower than erase_px_range[1]")
self.p = p
self.value = value
self.inplace = inplace
self.max_erase = max_erase
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
if torch.rand(1) < self.p:
return images, disparities, masks
image_left, image_right = images
mask_left, mask_right = masks
for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
y, x, h, w, v = self._get_params(image_left)
image_right = F.erase(image_right, y, x, h, w, v, self.inplace)
image_left = F.erase(image_left, y, x, h, w, v, self.inplace)
# similarly to optical flow occlusion prediction, we consider
# any erasure pixels that are in both images to be occluded therefore
# we mark them as invalid
if mask_left is not None:
mask_left = F.erase(mask_left, y, x, h, w, False, self.inplace)
if mask_right is not None:
mask_right = F.erase(mask_right, y, x, h, w, False, self.inplace)
return (image_left, image_right), disparities, (mask_left, mask_right)
def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
img_h, img_w = img.shape[-2:]
crop_h, crop_w = (
random.randint(self.min_px_erase, self.max_px_erase),
random.randint(self.min_px_erase, self.max_px_erase),
)
crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
return crop_y, crop_x, crop_h, crop_w, self.value
class RandomOcclusion(torch.nn.Module):
# This adds an occlusion in the right image
# the occluded patch works as a patch erase where the erase value is the mean
# of the pixels from the selected zone
def __init__(self, p: float = 0.5, occlusion_px_range: Tuple[int, int] = (50, 100), inplace: bool = False):
super().__init__()
self.min_px_occlusion = occlusion_px_range[0]
self.max_px_occlusion = occlusion_px_range[1]
if self.max_px_occlusion < 0:
raise ValueError("occlusion_px_range[1] should be greater or equal than 0")
if self.min_px_occlusion < 0:
raise ValueError("occlusion_px_range[0] should be greater or equal than 0")
if self.min_px_occlusion > self.max_px_occlusion:
raise ValueError("occlusion_px_range[0] should be lower than occlusion_px_range[1]")
self.p = p
self.inplace = inplace
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
left_image, right_image = images
if torch.rand(1) < self.p:
return images, disparities, masks
y, x, h, w, v = self._get_params(right_image)
right_image = F.erase(right_image, y, x, h, w, v, self.inplace)
return ((left_image, right_image), disparities, masks)
def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
img_h, img_w = img.shape[-2:]
crop_h, crop_w = (
random.randint(self.min_px_occlusion, self.max_px_occlusion),
random.randint(self.min_px_occlusion, self.max_px_occlusion),
)
crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
occlusion_value = img[..., crop_y : crop_y + crop_h, crop_x : crop_x + crop_w].mean(dim=(-2, -1), keepdim=True)
return (crop_y, crop_x, crop_h, crop_w, occlusion_value)
class RandomSpatialShift(torch.nn.Module):
# This transform applies a vertical shift and a slight angle rotation and the same time
def __init__(
self, p: float = 0.5, max_angle: float = 0.1, max_px_shift: int = 2, interpolation_type: str = "bilinear"
) -> None:
super().__init__()
self.p = p
self.max_angle = max_angle
self.max_px_shift = max_px_shift
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
# the transform is applied only on the right image
# in order to mimic slight calibration issues
img_left, img_right = images
INTERP_MODE = self._interpolation_mode_strategy()
if torch.rand(1) < self.p:
# [0, 1] -> [-a, a]
shift = rand_float_range((1,), low=-self.max_px_shift, high=self.max_px_shift).item()
angle = rand_float_range((1,), low=-self.max_angle, high=self.max_angle).item()
# sample center point for the rotation matrix
y = torch.randint(size=(1,), low=0, high=img_right.shape[-2]).item()
x = torch.randint(size=(1,), low=0, high=img_right.shape[-1]).item()
# apply affine transformations
img_right = F.affine(
img_right,
angle=angle,
translate=[0, shift], # translation only on the y axis
center=[x, y],
scale=1.0,
shear=0.0,
interpolation=INTERP_MODE,
)
return ((img_left, img_right), disparities, masks)
class RandomHorizontalFlip(torch.nn.Module):
def __init__(self, p: float = 0.5) -> None:
super().__init__()
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left, img_right = images
dsp_left, dsp_right = disparities
mask_left, mask_right = masks
if dsp_right is not None and torch.rand(1) < self.p:
img_left, img_right = F.hflip(img_left), F.hflip(img_right)
dsp_left, dsp_right = F.hflip(dsp_left), F.hflip(dsp_right)
if mask_left is not None and mask_right is not None:
mask_left, mask_right = F.hflip(mask_left), F.hflip(mask_right)
return ((img_right, img_left), (dsp_right, dsp_left), (mask_right, mask_left))
return images, disparities, masks
class Resize(torch.nn.Module):
def __init__(self, resize_size: Tuple[int, ...], interpolation_type: str = "bilinear") -> None:
super().__init__()
self.resize_size = list(resize_size) # doing this to keep mypy happy
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
resized_images = ()
resized_disparities = ()
resized_masks = ()
INTERP_MODE = self._interpolation_mode_strategy()
for img in images:
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE),)
for dsp in disparities:
if dsp is not None:
# rescale disparity to match the new image size
scale_x = self.resize_size[1] / dsp.shape[-1]
resized_disparities += (F.resize(dsp, self.resize_size, interpolation=INTERP_MODE) * scale_x,)
else:
resized_disparities += (None,)
for mask in masks:
if mask is not None:
resized_masks += (
# we squeeze and unsqueeze because the API requires > 3D tensors
F.resize(
mask.unsqueeze(0),
self.resize_size,
interpolation=F.InterpolationMode.NEAREST,
).squeeze(0),
)
else:
resized_masks += (None,)
return resized_images, resized_disparities, resized_masks
class RandomRescaleAndCrop(torch.nn.Module):
# This transform will resize the input with a given proba, and then crop it.
# These are the reversed operations of the built-in RandomResizedCrop,
# although the order of the operations doesn't matter too much: resizing a
# crop would give the same result as cropping a resized image, up to
# interpolation artifact at the borders of the output.
#
# The reason we don't rely on RandomResizedCrop is because of a significant
# difference in the parametrization of both transforms, in particular,
# because of the way the random parameters are sampled in both transforms,
# which leads to fairly different resuts (and different epe). For more details see
# https://github.com/pytorch/vision/pull/5026/files#r762932579
def __init__(
self,
crop_size: Tuple[int, int],
scale_range: Tuple[float, float] = (-0.2, 0.5),
rescale_prob: float = 0.8,
scaling_type: str = "exponential",
interpolation_type: str = "bilinear",
) -> None:
super().__init__()
self.crop_size = crop_size
self.min_scale = scale_range[0]
self.max_scale = scale_range[1]
self.rescale_prob = rescale_prob
self.scaling_type = scaling_type
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
if self.scaling_type == "linear" and self.min_scale < 0:
raise ValueError("min_scale must be >= 0 for linear scaling")
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left, img_right = images
dsp_left, dsp_right = disparities
mask_left, mask_right = masks
INTERP_MODE = self._interpolation_mode_strategy()
# randomly sample scale
h, w = img_left.shape[-2:]
# Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
# It shouldn't matter much
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
# exponential scaling will draw a random scale in (min_scale, max_scale) and then raise
# 2 to the power of that random value. This final scale distribution will have a different
# mean and variance than a uniform distribution. Note that a scale of 1 will result in
# in a rescaling of 2X the original size, whereas a scale of -1 will result in a rescaling
# of 0.5X the original size.
if self.scaling_type == "exponential":
scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
# linear scaling will draw a random scale in (min_scale, max_scale)
elif self.scaling_type == "linear":
scale = torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
scale = max(scale, min_scale)
new_h, new_w = round(h * scale), round(w * scale)
if torch.rand(1).item() < self.rescale_prob:
# rescale the images
img_left = F.resize(img_left, size=(new_h, new_w), interpolation=INTERP_MODE)
img_right = F.resize(img_right, size=(new_h, new_w), interpolation=INTERP_MODE)
resized_masks, resized_disparities = (), ()
for disparity, mask in zip(disparities, masks):
if disparity is not None:
if mask is None:
resized_disparity = F.resize(disparity, size=(new_h, new_w), interpolation=INTERP_MODE)
# rescale the disparity
resized_disparity = (
resized_disparity * torch.tensor([scale], device=resized_disparity.device)[:, None, None]
)
resized_mask = None
else:
resized_disparity, resized_mask = _resize_sparse_flow(
disparity, mask, scale_x=scale, scale_y=scale
)
resized_masks += (resized_mask,)
resized_disparities += (resized_disparity,)
else:
resized_disparities = disparities
resized_masks = masks
disparities = resized_disparities
masks = resized_masks
# Note: For sparse datasets (Kitti), the original code uses a "margin"
# See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
# We don't, not sure it matters much
y0 = torch.randint(0, img_left.shape[1] - self.crop_size[0], size=(1,)).item()
x0 = torch.randint(0, img_right.shape[2] - self.crop_size[1], size=(1,)).item()
img_left = F.crop(img_left, y0, x0, self.crop_size[0], self.crop_size[1])
img_right = F.crop(img_right, y0, x0, self.crop_size[0], self.crop_size[1])
if dsp_left is not None:
dsp_left = F.crop(disparities[0], y0, x0, self.crop_size[0], self.crop_size[1])
if dsp_right is not None:
dsp_right = F.crop(disparities[1], y0, x0, self.crop_size[0], self.crop_size[1])
cropped_masks = ()
for mask in masks:
if mask is not None:
mask = F.crop(mask, y0, x0, self.crop_size[0], self.crop_size[1])
cropped_masks += (mask,)
return ((img_left, img_right), (dsp_left, dsp_right), cropped_masks)
def _resize_sparse_flow(
flow: Tensor, valid_flow_mask: Tensor, scale_x: float = 1.0, scale_y: float = 0.0
) -> Tuple[Tensor, Tensor]:
# This resizes both the flow and the valid_flow_mask mask (which is assumed to be reasonably sparse)
# There are as-many non-zero values in the original flow as in the resized flow (up to OOB)
# So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4
h, w = flow.shape[-2:]
h_new = int(round(h * scale_y))
w_new = int(round(w * scale_x))
flow_new = torch.zeros(size=[1, h_new, w_new], dtype=flow.dtype)
valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)
jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")
ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]
ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)
within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)
ii_valid = ii_valid[within_bounds_mask]
jj_valid = jj_valid[within_bounds_mask]
ii_valid_new = ii_valid_new[within_bounds_mask]
jj_valid_new = jj_valid_new[within_bounds_mask]
valid_flow_new = flow[:, ii_valid, jj_valid]
valid_flow_new *= scale_x
flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
valid_new[ii_valid_new, jj_valid_new] = valid_flow_mask[ii_valid, jj_valid]
return flow_new, valid_new.bool()
class Compose(torch.nn.Module):
def __init__(self, transforms: List[Callable]):
super().__init__()
self.transforms = transforms
@torch.inference_mode()
def forward(self, images, disparities, masks):
for t in self.transforms:
images, disparities, masks = t(images, disparities, masks)
return images, disparities, masks
"""
This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
from torch import Tensor
from . import functional as F, InterpolationMode
__all__ = ["StereoMatching"]
class StereoMatching(torch.nn.Module):
def __init__(
self,
*,
use_gray_scale: bool = False,
resize_size: Optional[Tuple[int, ...]],
mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
std: Tuple[float, ...] = (0.5, 0.5, 0.5),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
# pacify mypy
self.resize_size: Union[None, List]
if resize_size is not None:
self.resize_size = list(resize_size)
else:
self.resize_size = None
self.mean = list(mean)
self.std = list(std)
self.interpolation = interpolation
self.use_gray_scale = use_gray_scale
def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
def _process_image(img: PIL.Image.Image) -> Tensor:
if self.resize_size is not None:
img = F.resize(img, self.resize_size, interpolation=self.interpolation)
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
if self.use_gray_scale is True:
img = F.rgb_to_grayscale(img)
img = F.convert_image_dtype(img, torch.float)
img = F.normalize(img, mean=self.mean, std=self.std)
img = img.contiguous()
return img
left_image = _process_image(left_image)
right_image = _process_image(right_image)
return left_image, right_image
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment