Unverified Commit 0fcfaa13 authored by Ponku, committed by GitHub

Add stereo preset transforms (#6549)



* Added transforms for Stereo Matching

* changed implicit Y scaling to 0.

* Addressed some comments

* addressed type hint

* Added random interpolation strategy

* Aligned crop get params

* fixed bug in RandomErase

* Addressed scaling and typos

* Addressed occlusion typo

* Changed parameter order in F.erase

* fixed random erase

* Added inference preset transform for stereo matching

* added contiguous reshape to output tensors

* Addressed comments

* Modified the transform preset to use Tuple[int, int]

* addressed nits

* added grayscale transform, align resize -> mask

* changed max disparity default behaviour

* added fixed resize, changed masking in sparse flow masking

* update to align with argparse

* changed default mask in asymmetric pairs

* moved grayscale order

* changed grayscale API to accept the tensor variant

* mypy fix

* changed resize specs

* addressed nits

* added type hints

* mypy fix

* mypy fix

* mypy fix
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
parent 2c1022e3
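For orientation before the diff: a minimal sketch of how the two new reference presets might be instantiated. The argument values are illustrative rather than recipe defaults, and the `presets` import path assumes the reference-script layout:

from presets import StereoMatchingEvalPreset, StereoMatchingTrainPreset

# Evaluation: convert to tensor, optionally resize, normalize, build a validity mask.
eval_preset = StereoMatchingEvalPreset(resize_size=(384, 768), max_disparity=256.0)

# Training: the same base plus photometric and spatial augmentations; all
# arguments are keyword-only, and resize_size and crop_size have no defaults.
train_preset = StereoMatchingTrainPreset(resize_size=None, crop_size=(384, 768))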
from typing import Optional, Tuple, Union
import torch
import transforms as T
class StereoMatchingEvalPreset(torch.nn.Module):
def __init__(
self,
mean: float = 0.5,
std: float = 0.5,
resize_size: Optional[Tuple[int, ...]] = None,
max_disparity: Optional[float] = None,
interpolation_type: str = "bilinear",
use_grayscale: bool = False,
) -> None:
super().__init__()
transforms = [
T.ToTensor(),
T.ConvertImageDtype(torch.float32),
]
if use_grayscale:
transforms.append(T.ConvertToGrayscale())
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))
transforms.extend(
[
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity=max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
def forward(self, images, disparities, masks):
return self.transforms(images, disparities, masks)
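# Illustrative usage, not part of the diff. Assumption: as in the optical-flow
# reference transforms this preset mirrors, each argument is a (left, right)
# pair, with None permitted for the right-view disparity and mask:
#
#   preset = StereoMatchingEvalPreset(resize_size=(384, 768), max_disparity=256.0)
#   images, disparities, masks = preset((left, right), (disp, None), (valid, None))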
class StereoMatchingTrainPreset(torch.nn.Module):
def __init__(
self,
*,
resize_size: Optional[Tuple[int, ...]],
resize_interpolation_type: str = "bilinear",
# RandomRescaleAndCrop params
crop_size: Tuple[int, int],
rescale_prob: float = 1.0,
scaling_type: str = "exponential",
scale_range: Tuple[float, float] = (-0.2, 0.5),
scale_interpolation_type: str = "bilinear",
# convert to grayscale
use_grayscale: bool = False,
# normalization params
mean: float = 0.5,
std: float = 0.5,
# processing device
gpu_transforms: bool = False,
# masking
max_disparity: Optional[int] = 256,
# SpatialShift params
spatial_shift_prob: float = 0.5,
spatial_shift_max_angle: float = 0.5,
spatial_shift_max_displacement: float = 0.5,
spatial_shift_interpolation_type: str = "bilinear",
# AsymmetricColorJitter params
gamma_range: Tuple[float, float] = (0.8, 1.2),
brightness: Union[float, Tuple[float, float]] = (0.8, 1.2),
contrast: Union[float, Tuple[float, float]] = (0.8, 1.2),
saturation: Union[float, Tuple[float, float]] = 0.0,
hue: Union[float, Tuple[float, float]] = 0.0,
asymmetric_jitter_prob: float = 1.0,
# RandomHorizontalFlip
horizontal_flip_prob: float = 0.5,
# RandomOcclusion
occlusion_prob: float = 0.0,
occlusion_px_range: Tuple[int, int] = (50, 100),
# RandomErase
erase_prob: float = 0.0,
erase_px_range: Tuple[int, int] = (50, 100),
erase_num_repeats: int = 1,
) -> None:
if scaling_type not in ["linear", "exponential"]:
raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")
super().__init__()
transforms = [T.ToTensor()]
# when fixing the size across multiple datasets, we ensure that the
# same resize is applied to all datasets before cropping
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))
if gpu_transforms:
transforms.append(T.ToGPU())
# color handling
color_transforms = [
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
]
if use_grayscale:
color_transforms.append(T.ConvertToGrayscale())
transforms.extend(color_transforms)
transforms.extend(
[
T.RandomSpatialShift(
p=spatial_shift_prob,
max_angle=spatial_shift_max_angle,
max_px_shift=spatial_shift_max_displacement,
interpolation_type=spatial_shift_interpolation_type,
),
T.ConvertImageDtype(torch.float32),
T.RandomRescaleAndCrop(
crop_size=crop_size,
scale_range=scale_range,
rescale_prob=rescale_prob,
scaling_type=scaling_type,
interpolation_type=scale_interpolation_type,
),
T.RandomHorizontalFlip(horizontal_flip_prob),
# occlusion after flip, otherwise we're occluding the reference image
T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
def forward(self, images, disparities, masks):
return self.transforms(images, disparities, masks)
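For context on the `scaling_type` flag validated in the constructor above: in the RAFT-style optical-flow reference transforms that this preset mirrors, "exponential" draws the rescale factor as 2**u with u sampled uniformly from scale_range (which is why the (-0.2, 0.5) default has a negative lower bound), while "linear" uses the sampled value directly. A hedged sketch of that sampling, assuming RandomRescaleAndCrop in the collapsed transforms.py diff behaves the same way:

from typing import Tuple

import torch

def sample_rescale_factor(scaling_type: str, scale_range: Tuple[float, float]) -> float:
    # Assumption: mirrors the optical-flow reference implementation; the actual
    # RandomRescaleAndCrop may differ in detail.
    u = torch.empty(1).uniform_(*scale_range).item()
    return 2.0**u if scaling_type == "exponential" else u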
"""
This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
from torch import Tensor
from . import functional as F, InterpolationMode
__all__ = ["StereoMatching"]
class StereoMatching(torch.nn.Module):
def __init__(
self,
*,
use_gray_scale: bool = False,
resize_size: Optional[Tuple[int, ...]],
mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
std: Tuple[float, ...] = (0.5, 0.5, 0.5),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
# pacify mypy
self.resize_size: Union[None, List]
if resize_size is not None:
self.resize_size = list(resize_size)
else:
self.resize_size = None
self.mean = list(mean)
self.std = list(std)
self.interpolation = interpolation
self.use_gray_scale = use_gray_scale
def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
def _process_image(img: PIL.Image.Image) -> Tensor:
if self.resize_size is not None:
img = F.resize(img, self.resize_size, interpolation=self.interpolation)
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
if self.use_gray_scale:
img = F.rgb_to_grayscale(img)
img = F.convert_image_dtype(img, torch.float)
img = F.normalize(img, mean=self.mean, std=self.std)
img = img.contiguous()
return img
left_image = _process_image(left_image)
right_image = _process_image(right_image)
return left_image, right_image
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
)
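A short usage sketch of the weights preset; the module path and file names are assumptions for illustration (the class lives in a private `_presets` module, as the docstring above warns):

import PIL.Image

from torchvision.transforms._presets import StereoMatching  # assumed module path

preset = StereoMatching(resize_size=(384, 768))
left = PIL.Image.open("left.png")  # illustrative inputs
right = PIL.Image.open("right.png")
left_t, right_t = preset(left, right)  # normalized float tensors, each (C, H, W)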