Commit bf491463 authored by limm

add v0.19.1 release

parent e17f5ea2
from typing import List, Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.prototype.models.depth.stereo.raft_stereo import grid_sample, make_coords_grid
def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
"""Function to create a 2D Gaussian kernel."""
x = torch.arange(kernel_size, dtype=torch.float32)
y = torch.arange(kernel_size, dtype=torch.float32)
x = x - (kernel_size - 1) / 2
y = y - (kernel_size - 1) / 2
x, y = torch.meshgrid(x, y, indexing="ij")
grid = (x**2 + y**2) / (2 * sigma**2)
kernel = torch.exp(-grid)
kernel = kernel / kernel.sum()
return kernel
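# Illustrative note (sketch): the kernel is normalized to sum to 1, e.g.
# make_gaussian_kernel(5, sigma=1.0).sum() is ~1.0, which is what the grouped
# convolutions in _ssim_loss_fn below rely on.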
def _sequence_loss_fn(
flow_preds: List[Tensor],
flow_gt: Tensor,
valid_flow_mask: Optional[Tensor],
gamma: Tensor,
max_flow: int = 256,
exclude_large: bool = False,
weights: Optional[Tensor] = None,
):
"""Loss function defined over sequence of flow predictions"""
torch._assert(
gamma < 1,
"sequence_loss: `gamma` must be lower than 1, but got {}".format(gamma),
)
if exclude_large:
# exclude invalid pixels and extremely large displacements
flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
else:
valid_flow_mask = flow_norm < max_flow
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask.unsqueeze(1)
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
abs_diff = (flow_preds - flow_gt).abs()
if valid_flow_mask is not None:
abs_diff = abs_diff * valid_flow_mask.unsqueeze(0)
abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0]
# allocating on CPU and moving to device during run-time can force
# an unwanted GPU synchronization that produces a large overhead
if weights is None or len(weights) != num_predictions:
weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
flow_loss = (abs_diff * weights).sum()
return flow_loss, weights
class SequenceLoss(nn.Module):
def __init__(self, gamma: float = 0.8, max_flow: int = 256, exclude_large_flows: bool = False) -> None:
"""
Args:
gamma: value for the exponential weighting of the loss across frames
max_flow: threshold on the ground-truth flow magnitude; pixels above it are excluded when exclude_large_flows is True
exclude_large_flows: whether to exclude pixels with large ground-truth flow from the loss
"""
super().__init__()
self.max_flow = max_flow
self.excluding_large = exclude_large_flows
self.register_buffer("gamma", torch.tensor([gamma]))
# cache the scale factor for the loss
self._weights = None
def forward(self, flow_preds: List[Tensor], flow_gt: Tensor, valid_flow_mask: Optional[Tensor]) -> Tensor:
"""
Args:
flow_preds: list of flow predictions of shape (batch_size, C, H, W)
flow_gt: ground truth flow of shape (batch_size, C, H, W)
valid_flow_mask: mask of valid flow pixels of shape (batch_size, H, W)
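Example (illustrative sketch; shapes and values are arbitrary):
>>> loss_fn = SequenceLoss(gamma=0.8)
>>> preds = [torch.rand(2, 2, 16, 16) for _ in range(3)]
>>> loss = loss_fn(preds, torch.rand(2, 2, 16, 16), None)
>>> loss.ndim  # scalar, summed over the exponentially weighted sequence
0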
"""
loss, weights = _sequence_loss_fn(
flow_preds, flow_gt, valid_flow_mask, self.gamma, self.max_flow, self.excluding_large, self._weights
)
self._weights = weights
return loss
def set_gamma(self, gamma: float) -> None:
self.gamma.fill_(gamma)
# reset the cached scale factor
self._weights = None
def _ssim_loss_fn(
source: Tensor,
reference: Tensor,
kernel: Tensor,
eps: float = 1e-8,
c1: float = 0.01**2,
c2: float = 0.03**2,
use_padding: bool = False,
) -> Tensor:
# ref: Algorithm section: https://en.wikipedia.org/wiki/Structural_similarity
# ref: Alternative implementation: https://kornia.readthedocs.io/en/latest/_modules/kornia/metrics/ssim.html#ssim
torch._assert(
source.ndim == reference.ndim == 4,
"SSIM: `source` and `reference` must be 4-dimensional tensors",
)
torch._assert(
source.shape == reference.shape,
"SSIM: `source` and `reference` must have the same shape, but got {} and {}".format(
source.shape, reference.shape
),
)
B, C, H, W = source.shape
kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(C, 1, 1, 1)
if use_padding:
pad_size = kernel.shape[2] // 2
source = F.pad(source, (pad_size, pad_size, pad_size, pad_size), "reflect")
reference = F.pad(reference, (pad_size, pad_size, pad_size, pad_size), "reflect")
mu1 = F.conv2d(source, kernel, groups=C)
mu2 = F.conv2d(reference, kernel, groups=C)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
mu_img1_sq = F.conv2d(source.pow(2), kernel, groups=C)
mu_img2_sq = F.conv2d(reference.pow(2), kernel, groups=C)
mu_img1_mu2 = F.conv2d(source * reference, kernel, groups=C)
sigma1_sq = mu_img1_sq - mu1_sq
sigma2_sq = mu_img2_sq - mu2_sq
sigma12 = mu_img1_mu2 - mu1_mu2
numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim = numerator / (denominator + eps)
# doing 1 - ssim because we want to maximize the ssim
return 1 - ssim.mean(dim=(1, 2, 3))
class SSIM(nn.Module):
def __init__(
self,
kernel_size: int = 11,
max_val: float = 1.0,
sigma: float = 1.5,
eps: float = 1e-12,
use_padding: bool = True,
) -> None:
"""SSIM loss function.
Args:
kernel_size: size of the Gaussian kernel
max_val: dynamic range of the input images, used to scale the stability constants c1 and c2
sigma: sigma of the Gaussian kernel
eps: small constant added to the denominator to avoid division by zero
use_padding: whether to pad the input tensor such that we have a score for each pixel
"""
super().__init__()
self.kernel_size = kernel_size
self.max_val = max_val
self.sigma = sigma
gaussian_kernel = make_gaussian_kernel(kernel_size, sigma)
self.register_buffer("gaussian_kernel", gaussian_kernel)
self.c1 = (0.01 * self.max_val) ** 2
self.c2 = (0.03 * self.max_val) ** 2
self.use_padding = use_padding
self.eps = eps
def forward(self, source: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
"""
Args:
source: source image of shape (batch_size, C, H, W)
reference: reference image of shape (batch_size, C, H, W)
Returns:
SSIM loss of shape (batch_size,)
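Example (illustrative sketch; inputs are arbitrary random images):
>>> ssim = SSIM(kernel_size=7)
>>> out = ssim(torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32))
>>> tuple(out.shape)
(2,)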
"""
return _ssim_loss_fn(
source,
reference,
kernel=self.gaussian_kernel,
c1=self.c1,
c2=self.c2,
use_padding=self.use_padding,
eps=self.eps,
)
def _smoothness_loss_fn(img_gx: Tensor, img_gy: Tensor, val_gx: Tensor, val_gy: Tensor):
# ref: https://github.com/nianticlabs/monodepth2/blob/b676244e5a1ca55564eb5d16ab521a48f823af31/layers.py#L202
torch._assert(
img_gx.ndim >= 3,
"smoothness_loss: `img_gx` must be at least 3-dimensional tensor of shape (..., C, H, W)",
)
torch._assert(
img_gx.ndim == val_gx.ndim,
"smoothness_loss: `img_gx` and `depth_gx` must have the same dimensionality, but got {} and {}".format(
img_gx.ndim, val_gx.ndim
),
)
for idx in range(img_gx.ndim):
torch._assert(
(img_gx.shape[idx] == val_gx.shape[idx] or (img_gx.shape[idx] == 1 or val_gx.shape[idx] == 1)),
"smoothness_loss: `img_gx` and `depth_gx` must have either the same shape or broadcastable shape, but got {} and {}".format(
img_gx.shape, val_gx.shape
),
)
# -3 is channel dimension
weights_x = torch.exp(-torch.mean(torch.abs(val_gx), axis=-3, keepdim=True))
weights_y = torch.exp(-torch.mean(torch.abs(val_gy), axis=-3, keepdim=True))
smoothness_x = img_gx * weights_x
smoothness_y = img_gy * weights_y
smoothness = (torch.abs(smoothness_x) + torch.abs(smoothness_y)).mean(axis=(-3, -2, -1))
return smoothness
class SmoothnessLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def _x_gradient(self, img: Tensor) -> Tensor:
if img.ndim > 4:
original_shape = img.shape
is_reshaped = True
img = img.reshape(-1, *original_shape[-3:])
else:
is_reshaped = False
padded = F.pad(img, (0, 1, 0, 0), mode="replicate")
grad = padded[..., :, :-1] - padded[..., :, 1:]
if is_reshaped:
grad = grad.reshape(original_shape)
return grad
def _y_gradient(self, x: torch.Tensor) -> torch.Tensor:
if x.ndim > 4:
original_shape = x.shape
is_reshaped = True
x = x.reshape(-1, *original_shape[-3:])
else:
is_reshaped = False
padded = F.pad(x, (0, 0, 0, 1), mode="replicate")
grad = padded[..., :-1, :] - padded[..., 1:, :]
if is_reshaped:
grad = grad.reshape(original_shape)
return grad
def forward(self, images: Tensor, vals: Tensor) -> Tensor:
"""
Args:
images: tensor of shape (D1, D2, ..., DN, C, H, W)
vals: tensor of shape (D1, D2, ..., DN, 1, H, W)
Returns:
smoothness loss of shape (D1, D2, ..., DN)
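Example (illustrative sketch; inputs are arbitrary random tensors):
>>> smoothness = SmoothnessLoss()
>>> out = smoothness(torch.rand(4, 3, 32, 32), torch.rand(4, 1, 32, 32))
>>> tuple(out.shape)
(4,)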
"""
img_gx = self._x_gradient(images)
img_gy = self._y_gradient(images)
val_gx = self._x_gradient(vals)
val_gy = self._y_gradient(vals)
return _smoothness_loss_fn(img_gx, img_gy, val_gx, val_gy)
def _flow_sequence_consistency_loss_fn(
flow_preds: List[Tensor],
gamma: float = 0.8,
resize_factor: float = 0.25,
rescale_factor: float = 0.25,
rescale_mode: str = "bilinear",
weights: Optional[Tensor] = None,
):
"""Loss function defined over sequence of flow predictions"""
# Simplified version of ref: https://arxiv.org/pdf/2006.11242.pdf
# In the original paper, an additional refinement network is used to refine a flow prediction.
# Each step performed by the recurrent module in Raft or CREStereo is a refinement step using a delta_flow update,
# which should be consistent with the previous step. In this implementation, we simplify the overall loss
# term and ignore left-right consistency loss or photometric loss which can be treated separately.
torch._assert(
rescale_factor <= 1.0,
"sequence_consistency_loss: `rescale_factor` must be less than or equal to 1, but got {}".format(
rescale_factor
),
)
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
N, B, C, H, W = flow_preds.shape
# rescale flow predictions to account for bilinear upsampling artifacts
if rescale_factor:
flow_preds = (
F.interpolate(
flow_preds.view(N * B, C, H, W), scale_factor=resize_factor, mode=rescale_mode, align_corners=True
)
) * rescale_factor
flow_preds = torch.stack(torch.chunk(flow_preds, N, dim=0), dim=0)
# force the next prediction to be similar to the previous prediction
abs_diff = (flow_preds[1:] - flow_preds[:-1]).square()
abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0] - 1 # because we are comparing differences
if weights is None or len(weights) != num_predictions:
weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
flow_loss = (abs_diff * weights).sum()
return flow_loss, weights
class FlowSequenceConsistencyLoss(nn.Module):
def __init__(
self,
gamma: float = 0.8,
resize_factor: float = 0.25,
rescale_factor: float = 0.25,
rescale_mode: str = "bilinear",
) -> None:
super().__init__()
self.gamma = gamma
self.resize_factor = resize_factor
self.rescale_factor = rescale_factor
self.rescale_mode = rescale_mode
self._weights = None
def forward(self, flow_preds: List[Tensor]) -> Tensor:
"""
Args:
flow_preds: list of tensors of shape (batch_size, C, H, W)
Returns:
scalar sequence consistency loss, summed over the weighted differences between consecutive predictions
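Example (illustrative sketch; shapes are arbitrary):
>>> consistency = FlowSequenceConsistencyLoss()
>>> preds = [torch.rand(2, 2, 64, 64) for _ in range(4)]
>>> consistency(preds).ndim
0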
"""
loss, weights = _flow_sequence_consistency_loss_fn(
flow_preds,
gamma=self.gamma,
resize_factor=self.resize_factor,
rescale_factor=self.rescale_factor,
rescale_mode=self.rescale_mode,
weights=self._weights,
)
self._weights = weights
return loss
def set_gamma(self, gamma: float) -> None:
self.gamma = gamma
# reset the cached scale factor
self._weights = None
def _psnr_loss_fn(source: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
torch._assert(
source.shape == target.shape,
"psnr_loss: source and target must have the same shape, but got {} and {}".format(source.shape, target.shape),
)
# ref https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
return 10 * torch.log10(max_val**2 / ((source - target).pow(2).mean(axis=(-3, -2, -1))))
class PSNRLoss(nn.Module):
def __init__(self, max_val: float = 256) -> None:
"""
Args:
max_val: maximum value of the input domain, i.e. the peak signal value (e.g. 255 for 8-bit images).
"""
super().__init__()
self.max_val = max_val
def forward(self, source: Tensor, target: Tensor) -> Tensor:
"""
Args:
source: tensor of shape (D1, D2, ..., DN, C, H, W)
target: tensor of shape (D1, D2, ..., DN, C, H, W)
Returns:
psnr loss of shape (D1, D2, ..., DN)
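Example (illustrative sketch; inputs are arbitrary):
>>> psnr = PSNRLoss(max_val=1.0)
>>> x = torch.rand(2, 3, 16, 16)
>>> tuple(psnr(x, x * 0.5).shape)
(2,)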
"""
# multiply by -1 as we want to maximize the psnr
return -1 * _psnr_loss_fn(source, target, self.max_val)
class FlowPhotoMetricLoss(nn.Module):
def __init__(
self,
ssim_weight: float = 0.85,
ssim_window_size: int = 11,
ssim_max_val: float = 1.0,
ssim_sigma: float = 1.5,
ssim_eps: float = 1e-12,
ssim_use_padding: bool = True,
max_displacement_ratio: float = 0.15,
) -> None:
super().__init__()
self._ssim_loss = SSIM(
kernel_size=ssim_window_size,
max_val=ssim_max_val,
sigma=ssim_sigma,
eps=ssim_eps,
use_padding=ssim_use_padding,
)
self._L1_weight = 1 - ssim_weight
self._SSIM_weight = ssim_weight
self._max_displacement_ratio = max_displacement_ratio
def forward(
self,
source: Tensor,
reference: Tensor,
flow_pred: Tensor,
valid_mask: Optional[Tensor] = None,
):
"""
Args:
source: tensor of shape (B, C, H, W)
reference: tensor of shape (B, C, H, W)
flow_pred: tensor of shape (B, 2, H, W)
valid_mask: tensor of shape (B, H, W) or None
Returns:
scalar photometric loss, averaged over the batch
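Example (illustrative sketch with random inputs of arbitrary shape):
>>> photo_loss = FlowPhotoMetricLoss()
>>> left = torch.rand(2, 3, 64, 64)
>>> right = torch.rand(2, 3, 64, 64)
>>> flow = torch.rand(2, 2, 64, 64)
>>> photo_loss(left, right, flow).ndim
0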
"""
torch._assert(
source.ndim == 4,
"FlowPhotoMetricLoss: source must have 4 dimensions, but got {}".format(source.ndim),
)
torch._assert(
reference.ndim == source.ndim,
"FlowPhotoMetricLoss: source and other must have the same number of dimensions, but got {} and {}".format(
source.ndim, reference.ndim
),
)
torch._assert(
flow_pred.shape[1] == 2,
"FlowPhotoMetricLoss: flow_pred must have 2 channels, but got {}".format(flow_pred.shape[1]),
)
torch._assert(
flow_pred.ndim == 4,
"FlowPhotoMetricLoss: flow_pred must have 4 dimensions, but got {}".format(flow_pred.ndim),
)
B, C, H, W = source.shape
flow_channels = flow_pred.shape[1]
max_displacements = []
for dim in range(flow_channels):
shape_index = -1 - dim
max_displacements.append(int(self._max_displacement_ratio * source.shape[shape_index]))
# mask out all pixels that have larger flow than the max flow allowed
max_flow_mask = torch.logical_and(
*[flow_pred[:, dim, :, :] < max_displacements[dim] for dim in range(flow_channels)]
)
if valid_mask is not None:
valid_mask = torch.logical_and(valid_mask, max_flow_mask).unsqueeze(1)
else:
valid_mask = max_flow_mask.unsqueeze(1)
grid = make_coords_grid(B, H, W, device=str(source.device))
resampled_grids = grid - flow_pred
resampled_grids = resampled_grids.permute(0, 2, 3, 1)
resampled_source = grid_sample(reference, resampled_grids, mode="bilinear")
# compute SSIM loss
ssim_loss = self._ssim_loss(resampled_source * valid_mask, source * valid_mask)
l1_loss = (resampled_source * valid_mask - source * valid_mask).abs().mean(axis=(-3, -2, -1))
loss = self._L1_weight * l1_loss + self._SSIM_weight * ssim_loss
return loss.mean()
from typing import Dict, List, Optional, Tuple
from torch import Tensor
AVAILABLE_METRICS = ["mae", "rmse", "epe", "bad1", "bad2", "1px", "3px", "5px", "fl-all", "relepe"]
def compute_metrics(
flow_pred: Tensor, flow_gt: Tensor, valid_flow_mask: Optional[Tensor], metrics: List[str]
) -> Tuple[Dict[str, float], int]:
for m in metrics:
if m not in AVAILABLE_METRICS:
raise ValueError(f"Invalid metric: {m}. Valid metrics are: {AVAILABLE_METRICS}")
metrics_dict = {}
pixels_diffs = (flow_pred - flow_gt).abs()
# there is no Y flow in Stereo Matching, therefore flow.abs() = flow.pow(2).sum(dim=1).sqrt()
flow_norm = flow_gt.abs()
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask.unsqueeze(1)
pixels_diffs = pixels_diffs[valid_flow_mask]
flow_norm = flow_norm[valid_flow_mask]
num_pixels = pixels_diffs.numel()
if "bad1" in metrics:
metrics_dict["bad1"] = (pixels_diffs > 1).float().mean().item()
if "bad2" in metrics:
metrics_dict["bad2"] = (pixels_diffs > 2).float().mean().item()
if "mae" in metrics:
metrics_dict["mae"] = pixels_diffs.mean().item()
if "rmse" in metrics:
metrics_dict["rmse"] = pixels_diffs.pow(2).mean().sqrt().item()
if "epe" in metrics:
metrics_dict["epe"] = pixels_diffs.mean().item()
if "1px" in metrics:
metrics_dict["1px"] = (pixels_diffs < 1).float().mean().item()
if "3px" in metrics:
metrics_dict["3px"] = (pixels_diffs < 3).float().mean().item()
if "5px" in metrics:
metrics_dict["5px"] = (pixels_diffs < 5).float().mean().item()
if "fl-all" in metrics:
metrics_dict["fl-all"] = ((pixels_diffs < 3) & ((pixels_diffs / flow_norm) < 0.05)).float().mean().item() * 100
if "relepe" in metrics:
metrics_dict["relepe"] = (pixels_diffs / flow_norm).mean().item()
return metrics_dict, num_pixels
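# Usage sketch (illustrative; tensor names are hypothetical): for stereo matching the
# inputs are (B, 1, H, W) disparity maps and a (B, H, W) validity mask, e.g.
#   metrics, num_px = compute_metrics(pred, gt, valid_mask, ["mae", "epe", "1px", "bad2"])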
import torch
def freeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
def unfreeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.train()
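# Usage sketch (illustrative): freeze the BatchNorm running statistics while
# fine-tuning, then restore the default behaviour afterwards, e.g.
#   freeze_batch_norm(model)    # before the fine-tuning loop
#   unfreeze_batch_norm(model)  # once regular training resumes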
import torch.nn.functional as F
class InputPadder:
"""Pads images such that dimensions are divisible by 8"""
# TODO: Ideally, this should be part of the eval transforms preset, instead
# of being part of the validation code. It's not obvious what a good
# solution would be, because we need to unpad the predicted flows according
# to the input images' size, and in some datasets (Kitti) images can have
# variable sizes.
def __init__(self, dims, mode="sintel"):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == "sintel":
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
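# Usage sketch (illustrative; variable names are hypothetical):
#   padder = InputPadder(img1.shape, mode="sintel")
#   img1_padded, img2_padded = padder.pad(img1, img2)  # H and W become multiples of 8
#   flow = padder.unpad(flow_padded)                   # crop back to the original size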
import os
from typing import List
import numpy as np
import torch
from torch import Tensor
from torchvision.utils import make_grid
@torch.no_grad()
def make_disparity_image(disparity: Tensor):
# normalize image to [0, 1]
disparity = disparity.detach().cpu()
disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min())
return disparity
@torch.no_grad()
def make_disparity_image_pairs(disparity: Tensor, image: Tensor):
disparity = make_disparity_image(disparity)
# image is in [-1, 1], bring it to [0, 1]
image = image.detach().cpu()
image = image * 0.5 + 0.5
return disparity, image
@torch.no_grad()
def make_disparity_sequence(disparities: List[Tensor]):
# convert each disparity to [0, 1]
for idx, disparity_batch in enumerate(disparities):
disparities[idx] = torch.stack(list(map(make_disparity_image, disparity_batch)))
# make the list into a batch
disparity_sequences = torch.stack(disparities)
return disparity_sequences
@torch.no_grad()
def make_pair_grid(*inputs, orientation="horizontal"):
# make a grid of images with the outputs and references side by side
if orientation == "horizontal":
# interleave the outputs and references
canvas = torch.zeros_like(inputs[0])
canvas = torch.cat([canvas] * len(inputs), dim=0)
size = len(inputs)
for idx, inp in enumerate(inputs):
canvas[idx::size, ...] = inp
grid = make_grid(canvas, nrow=len(inputs), padding=16, normalize=True, scale_each=True)
elif orientation == "vertical":
# interleave the outputs and references
canvas = torch.cat(inputs, dim=0)
size = len(inputs)
for idx, inp in enumerate(inputs):
canvas[idx::size, ...] = inp
grid = make_grid(canvas, nrow=len(inputs[0]), padding=16, normalize=True, scale_each=True)
else:
raise ValueError("Unknown orientation: {}".format(orientation))
return grid
@torch.no_grad()
def make_training_sample_grid(
left_images: Tensor,
right_images: Tensor,
disparities: Tensor,
masks: Tensor,
predictions: List[Tensor],
) -> np.ndarray:
# detach images and renormalize to [0, 1]
images_left = left_images.detach().cpu() * 0.5 + 0.5
images_right = right_images.detach().cpu() * 0.5 + 0.5
# detach the disparities and predictions
disparities = disparities.detach().cpu()
predictions = predictions[-1].detach().cpu()
# keep only the first channel of pixels, and repeat it 3 times
disparities = disparities[:, :1, ...].repeat(1, 3, 1, 1)
predictions = predictions[:, :1, ...].repeat(1, 3, 1, 1)
# unsqueeze and repeat the masks
masks = masks.detach().cpu().unsqueeze(1).repeat(1, 3, 1, 1)
# make a grid that will self normalize across the batch
pred_grid = make_pair_grid(images_left, images_right, masks, disparities, predictions, orientation="horizontal")
pred_grid = pred_grid.permute(1, 2, 0).numpy()
pred_grid = (pred_grid * 255).astype(np.uint8)
return pred_grid
@torch.no_grad()
def make_disparity_sequence_grid(predictions: List[Tensor], disparities: Tensor) -> np.ndarray:
# the right-most column will contain the ground truth
seq_len = len(predictions) + 1
predictions = list(map(lambda x: x[:, :1, :, :].detach().cpu(), predictions + [disparities]))
sequence = make_disparity_sequence(predictions)
# swap axes to have them in the correct order for each batch sample
sequence = torch.swapaxes(sequence, 0, 1).contiguous().reshape(-1, 1, disparities.shape[-2], disparities.shape[-1])
sequence = make_grid(sequence, nrow=seq_len, padding=16, normalize=True, scale_each=True)
sequence = sequence.permute(1, 2, 0).numpy()
sequence = (sequence * 255).astype(np.uint8)
return sequence
@torch.no_grad()
def make_prediction_image_side_to_side(
predictions: Tensor, disparities: Tensor, valid_mask: Tensor, save_path: str, prefix: str
) -> None:
import matplotlib.pyplot as plt
# normalize the predictions and disparities in [0, 1]
predictions = (predictions - predictions.min()) / (predictions.max() - predictions.min())
disparities = (disparities - disparities.min()) / (disparities.max() - disparities.min())
predictions = predictions * valid_mask
disparities = disparities * valid_mask
predictions = predictions.detach().cpu()
disparities = disparities.detach().cpu()
for idx, (pred, gt) in enumerate(zip(predictions, disparities)):
pred = pred.permute(1, 2, 0).numpy()
gt = gt.permute(1, 2, 0).numpy()
# plot pred and gt side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(pred)
ax[0].set_title("Prediction")
ax[1].imshow(gt)
ax[1].set_title("Ground Truth")
save_name = os.path.join(save_path, "{}_{}.png".format(prefix, idx))
plt.savefig(save_name)
plt.close()
@@ -22,43 +22,50 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs.
### Faster R-CNN ResNet-50 FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
### Faster R-CNN MobileNetV3-Large FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```
### Faster R-CNN MobileNetV3-Large 320 FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```
### FCOS ResNet-50 FPN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fcos_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
### RetinaNet
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco --model retinanet_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
### SSD300 VGG16
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco --model ssd300_vgg16 --epochs 120\
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
--weight-decay 0.0005 --data-augmentation ssd
--weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES
```
### SSDlite320 MobileNetV3-Large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
--weight-decay 0.00004 --data-augmentation ssdlite
@@ -67,16 +74,15 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
### Mask R-CNN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
### Keypoint R-CNN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
torchrun --nproc_per_node=8 train.py\
--dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
--lr-steps 36 43 --aspect-ratio-group-factor 3
--lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
import json
import tempfile
import numpy as np
import copy
import time
import torch
import torch._six
import io
from contextlib import redirect_stdout
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import numpy as np
import pycocotools.mask as mask_util
from collections import defaultdict
import torch
import utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
class CocoEvaluator(object):
class CocoEvaluator:
def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple))
if not isinstance(iou_types, (list, tuple)):
raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt
@@ -36,7 +31,8 @@ class CocoEvaluator(object):
for iou_type in self.iou_types:
results = self.prepare(predictions, iou_type)
coco_dt = loadRes(self.coco_gt, results) if results else COCO()
with redirect_stdout(io.StringIO()):
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
coco_eval = self.coco_eval[iou_type]
coco_eval.cocoDt = coco_dt
@@ -56,18 +52,17 @@ class CocoEvaluator(object):
def summarize(self):
for iou_type, coco_eval in self.coco_eval.items():
print("IoU metric: {}".format(iou_type))
print(f"IoU metric: {iou_type}")
coco_eval.summarize()
def prepare(self, predictions, iou_type):
if iou_type == "bbox":
return self.prepare_for_coco_detection(predictions)
elif iou_type == "segm":
if iou_type == "segm":
return self.prepare_for_coco_segmentation(predictions)
elif iou_type == "keypoints":
if iou_type == "keypoints":
return self.prepare_for_coco_keypoint(predictions)
else:
raise ValueError("Unknown iou type {}".format(iou_type))
raise ValueError(f"Unknown iou type {iou_type}")
def prepare_for_coco_detection(self, predictions):
coco_results = []
@@ -109,8 +104,7 @@ class CocoEvaluator(object):
labels = prediction["labels"].tolist()
rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
for mask in masks
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
@@ -146,7 +140,7 @@ class CocoEvaluator(object):
{
"image_id": original_id,
"category_id": labels[k],
'keypoints': keypoint,
"keypoints": keypoint,
"score": scores[k],
}
for k, keypoint in enumerate(keypoints)
@@ -192,161 +186,7 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################
# Ideally, pycocotools wouldn't have hard-coded prints
# so that we could avoid copy-pasting those two functions
def createIndex(self):
# create index
# print('creating index...')
anns, cats, imgs = {}, {}, {}
imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann
if 'images' in self.dataset:
for img in self.dataset['images']:
imgs[img['id']] = img
if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat
if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
catToImgs[ann['category_id']].append(ann['image_id'])
# print('index created!')
# create class members
self.anns = anns
self.imgToAnns = imgToAnns
self.catToImgs = catToImgs
self.imgs = imgs
self.cats = cats
maskUtils = mask_util
def loadRes(self, resFile):
"""
Load result file and return a result api object.
Args:
self (obj): coco object with ground truth annotations
resFile (str): file name of result file
Returns:
res (obj): result api object
"""
res = COCO()
res.dataset['images'] = [img for img in self.dataset['images']]
# print('Loading and preparing results...')
# tic = time.time()
if isinstance(resFile, torch._six.string_classes):
anns = json.load(open(resFile))
elif type(resFile) == np.ndarray:
anns = self.loadNumpyAnnotations(resFile)
else:
anns = resFile
assert type(anns) == list, 'results in not an array of objects'
annsImgIds = [ann['image_id'] for ann in anns]
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
'Results do not correspond to current coco set'
if 'caption' in anns[0]:
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
for id, ann in enumerate(anns):
ann['id'] = id + 1
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
bb = ann['bbox']
x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
if 'segmentation' not in ann:
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
ann['area'] = bb[2] * bb[3]
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'segmentation' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
# now only support compressed RLE format as segmentation results
ann['area'] = maskUtils.area(ann['segmentation'])
if 'bbox' not in ann:
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'keypoints' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
s = ann['keypoints']
x = s[0::3]
y = s[1::3]
x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y)
ann['area'] = (x2 - x1) * (y2 - y1)
ann['id'] = id + 1
ann['bbox'] = [x1, y1, x2 - x1, y2 - y1]
# print('DONE (t={:0.2f}s)'.format(time.time()- tic))
res.dataset['annotations'] = anns
createIndex(res)
return res
def evaluate(self):
'''
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
'''
# tic = time.time()
# print('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if p.useSegm is not None:
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
# print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params = p
self._prepare()
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]
if p.iouType == 'segm' or p.iouType == 'bbox':
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
self.ious = {
(imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds
for catId in catIds}
evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
evalImgs = [
evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
# this is NOT in the pycocotools code, but could be done outside
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
self._paramsEval = copy.deepcopy(self.params)
# toc = time.time()
# print('DONE (t={:0.2f}s).'.format(toc-tic))
return p.imgIds, evalImgs
#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
def evaluate(imgs):
with redirect_stdout(io.StringIO()):
imgs.evaluate()
return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
import copy
import os
from PIL import Image
import torch
import torch.utils.data
import torchvision
import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
import transforms as T
class FilterAndRemapCocoCategories(object):
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap
def __call__(self, image, target):
anno = target["annotations"]
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
target["annotations"] = anno
return image, target
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
target["annotations"] = anno
return image, target
def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
@@ -47,16 +25,15 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks
class ConvertCocoPolysToMask(object):
class ConvertCocoPolysToMask:
def __call__(self, image, target):
w, h = image.size
image_id = target["image_id"]
image_id = torch.tensor([image_id])
anno = target["annotations"]
anno = [obj for obj in anno if obj['iscrowd'] == 0]
anno = [obj for obj in anno if obj["iscrowd"] == 0]
boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing
@@ -119,7 +96,7 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
# keypoints task have a slight different critera for considering
# keypoints task have a slight different criteria for considering
# if an annotation is valid
if "keypoints" not in anno[0]:
return True
@@ -129,7 +106,6 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
return True
return False
assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -147,55 +123,56 @@ def convert_to_coco_api(ds):
coco_ds = COCO()
# annotation IDs need to start at 1, not 0, see torchvision issue #1530
ann_id = 1
dataset = {'images': [], 'categories': [], 'annotations': []}
dataset = {"images": [], "categories": [], "annotations": []}
categories = set()
for img_idx in range(len(ds)):
# find better way to get target
# targets = ds.get_annotations(img_idx)
img, targets = ds[img_idx]
image_id = targets["image_id"].item()
image_id = targets["image_id"]
img_dict = {}
img_dict['id'] = image_id
img_dict['height'] = img.shape[-2]
img_dict['width'] = img.shape[-1]
dataset['images'].append(img_dict)
bboxes = targets["boxes"]
img_dict["id"] = image_id
img_dict["height"] = img.shape[-2]
img_dict["width"] = img.shape[-1]
dataset["images"].append(img_dict)
bboxes = targets["boxes"].clone()
bboxes[:, 2:] -= bboxes[:, :2]
bboxes = bboxes.tolist()
labels = targets['labels'].tolist()
areas = targets['area'].tolist()
iscrowd = targets['iscrowd'].tolist()
if 'masks' in targets:
masks = targets['masks']
labels = targets["labels"].tolist()
areas = targets["area"].tolist()
iscrowd = targets["iscrowd"].tolist()
if "masks" in targets:
masks = targets["masks"]
# make masks Fortran contiguous for coco_mask
masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
if 'keypoints' in targets:
keypoints = targets['keypoints']
if "keypoints" in targets:
keypoints = targets["keypoints"]
keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
num_objs = len(bboxes)
for i in range(num_objs):
ann = {}
ann['image_id'] = image_id
ann['bbox'] = bboxes[i]
ann['category_id'] = labels[i]
ann["image_id"] = image_id
ann["bbox"] = bboxes[i]
ann["category_id"] = labels[i]
categories.add(labels[i])
ann['area'] = areas[i]
ann['iscrowd'] = iscrowd[i]
ann['id'] = ann_id
if 'masks' in targets:
ann["area"] = areas[i]
ann["iscrowd"] = iscrowd[i]
ann["id"] = ann_id
if "masks" in targets:
ann["segmentation"] = coco_mask.encode(masks[i].numpy())
if 'keypoints' in targets:
ann['keypoints'] = keypoints[i]
ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
dataset['annotations'].append(ann)
if "keypoints" in targets:
ann["keypoints"] = keypoints[i]
ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
dataset["annotations"].append(ann)
ann_id += 1
dataset['categories'] = [{'id': i} for i in sorted(categories)]
dataset["categories"] = [{"id": i} for i in sorted(categories)]
coco_ds.dataset = dataset
coco_ds.createIndex()
return coco_ds
def get_coco_api_from_dataset(dataset):
# FIXME: This is... awful?
for _ in range(10):
if isinstance(dataset, torchvision.datasets.CocoDetection):
break
@@ -208,11 +185,11 @@ def get_coco_api_from_dataset(dataset):
class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms):
super(CocoDetection, self).__init__(img_folder, ann_file)
super().__init__(img_folder, ann_file)
self._transforms = transforms
def __getitem__(self, idx):
img, target = super(CocoDetection, self).__getitem__(idx)
img, target = super().__getitem__(idx)
image_id = self.ids[idx]
target = dict(image_id=image_id, annotations=target)
if self._transforms is not None:
@@ -220,7 +197,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
return img, target
def get_coco(root, image_set, transforms, mode='instances'):
def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
anno_file_template = "{}_{}2017.json"
PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -228,16 +205,25 @@ def get_coco(root, image_set, transforms, mode='instances'):
# "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
}
t = [ConvertCocoPolysToMask()]
img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)
if use_v2:
from torchvision.datasets import wrap_dataset_for_transforms_v2
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
target_keys = ["boxes", "labels", "image_id"]
if with_masks:
target_keys += ["masks"]
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
else:
# TODO: handle with_masks for V1?
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)
img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)
dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
if image_set == "train":
@@ -246,7 +232,3 @@ def get_coco(root, image_set, transforms, mode='instances'):
# dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
return dataset
def get_coco_kp(root, image_set, transforms):
return get_coco(root, image_set, transforms, mode="person_keypoints")
import math
import sys
import time
import torch
import torch
import torchvision.models.detection.mask_rcnn
from coco_utils import get_coco_api_from_dataset
from coco_eval import CocoEvaluator
import utils
from coco_eval import CocoEvaluator
from coco_utils import get_coco_api_from_dataset
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = f"Epoch: [{epoch}]"
lr_scheduler = None
if epoch == 0:
warmup_factor = 1. / 1000
warmup_factor = 1.0 / 1000
warmup_iters = min(1000, len(data_loader) - 1)
lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=warmup_factor, total_iters=warmup_iters
)
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
with torch.cuda.amp.autocast(enabled=scaler is not None):
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
# reduce losses over all GPUs for logging purposes
@@ -38,11 +38,16 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
loss_value = losses_reduced.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
print(f"Loss is {loss_value}, stopping training")
print(loss_dict_reduced)
sys.exit(1)
optimizer.zero_grad()
if scaler is not None:
scaler.scale(losses).backward()
scaler.step(optimizer)
scaler.update()
else:
losses.backward()
optimizer.step()
@@ -67,7 +72,7 @@ def _get_iou_types(model):
return iou_types
@torch.no_grad()
@torch.inference_mode()
def evaluate(model, data_loader, device):
n_threads = torch.get_num_threads()
# FIXME remove this and make paste_masks_in_image run on the GPU
@@ -75,7 +80,7 @@ def evaluate(model, data_loader, device):
cpu_device = torch.device("cpu")
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
header = "Test:"
coco = get_coco_api_from_dataset(data_loader.dataset)
iou_types = _get_iou_types(model)
@@ -92,7 +97,7 @@ def evaluate(model, data_loader, device):
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
model_time = time.time() - model_time
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
res = {target["image_id"]: output for target, output in zip(targets, outputs)}
evaluator_time = time.time()
coco_evaluator.update(res)
evaluator_time = time.time() - evaluator_time
import bisect
from collections import defaultdict
import copy
from itertools import repeat, chain
import math
import numpy as np
from collections import defaultdict
from itertools import chain, repeat
import numpy as np
import torch
import torch.utils.data
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm
import torchvision
from PIL import Image
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm
def _repeat_to_at_least(iterable, n):
@@ -34,12 +33,10 @@ class GroupedBatchSampler(BatchSampler):
0, i.e. they must be in the range [0, num_groups).
batch_size (int): Size of mini-batch.
"""
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
self.sampler = sampler
self.group_ids = group_ids
self.batch_size = batch_size
@@ -66,10 +63,9 @@ class GroupedBatchSampler(BatchSampler):
expected_num_batches = len(self)
num_remaining = expected_num_batches - num_batches
if num_remaining > 0:
# for the remaining batches, take first the buffers with largest number
# for the remaining batches, take first the buffers with the largest number
# of elements
for group_id, _ in sorted(buffer_per_group.items(),
key=lambda x: len(x[1]), reverse=True):
for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
remaining = self.batch_size - len(buffer_per_group[group_id])
samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
@@ -85,10 +81,12 @@
def _compute_aspect_ratios_slow(dataset, indices=None):
print("Your dataset doesn't support the fast path for "
print(
"Your dataset doesn't support the fast path for "
"computing the aspect ratios, so will iterate over "
"the full dataset and load every image instead. "
"This might take some time...")
"This might take some time..."
)
if indices is None:
indices = range(len(dataset))
@@ -104,9 +102,12 @@ def _compute_aspect_ratios_slow(dataset, indices=None):
sampler = SubsetSampler(indices)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, sampler=sampler,
dataset,
batch_size=1,
sampler=sampler,
num_workers=14, # you might want to increase it for faster processing
collate_fn=lambda x: x[0])
collate_fn=lambda x: x[0],
)
aspect_ratios = []
with tqdm(total=len(dataset)) as pbar:
for _i, (img, _) in enumerate(data_loader):
@@ -190,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0):
# count number of elements per group
counts = np.unique(groups, return_counts=True)[1]
fbins = [0] + bins + [np.inf]
print("Using {} as bins for aspect ratio quantization".format(fbins))
print("Count of instances per bin: {}".format(counts))
print(f"Using {fbins} as bins for aspect ratio quantization")
print(f"Count of instances per bin: {counts}")
return groups
import transforms as T
from collections import defaultdict
import torch
import transforms as reference_transforms
def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.transforms.v2
import torchvision.tv_tensors
return torchvision.transforms.v2, torchvision.tv_tensors
else:
return reference_transforms, None
class DetectionPresetTrain:
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
if data_augmentation == 'hflip':
self.transforms = T.Compose([
# Note: this transform assumes that the input to forward() are always PIL
# images, regardless of the backend parameter.
def __init__(
self,
*,
data_augmentation,
hflip_prob=0.5,
mean=(123.0, 117.0, 104.0),
backend="pil",
use_v2=False,
):
T, tv_tensors = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "tv_tensor":
transforms.append(T.ToImage())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
if data_augmentation == "hflip":
transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
elif data_augmentation == "lsj":
transforms += [
T.ScaleJitter(target_size=(1024, 1024), antialias=True),
# TODO: FixedSizeCrop below doesn't work on tensors!
reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "multiscale":
transforms += [
T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
elif data_augmentation == 'ssd':
self.transforms = T.Compose([
]
elif data_augmentation == "ssd":
fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
transforms += [
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomZoomOut(fill=fill),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
elif data_augmentation == 'ssdlite':
self.transforms = T.Compose([
]
elif data_augmentation == "ssdlite":
transforms += [
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
]
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
if backend == "pil":
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2:
transforms += [
T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBoxes(),
T.ToPureTensor(),
]
self.transforms = T.Compose(transforms)
def __call__(self, img, target):
return self.transforms(img, target)
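# Usage sketch (illustrative; argument values are hypothetical):
#   preset = DetectionPresetTrain(data_augmentation="hflip", backend="pil", use_v2=False)
#   img, target = preset(pil_image, {"boxes": boxes, "labels": labels})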
class DetectionPresetEval:
def __init__(self):
self.transforms = T.ToTensor()
def __init__(self, backend="pil", use_v2=False):
T, _ = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "tv_tensor":
transforms += [T.ToImage()]
else:
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2:
transforms += [T.ToPureTensor()]
self.transforms = T.Compose(transforms)
def __call__(self, img, target):
return self.transforms(img, target)
@@ -21,74 +21,125 @@ import datetime
import os
import time
import presets
import torch
import torch.utils.data
import torchvision
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn
from coco_utils import get_coco, get_coco_kp
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from engine import train_one_epoch, evaluate
import presets
import utils
def get_dataset(name, image_set, transform, data_path):
paths = {
"coco": (data_path, get_coco, 91),
"coco_kp": (data_path, get_coco_kp, 2)
}
p, ds_fn, num_classes = paths[name]
ds = ds_fn(p, image_set=image_set, transforms=transform)
from coco_utils import get_coco
from engine import evaluate, train_one_epoch
from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
from torchvision.transforms import InterpolationMode
from transforms import SimpleCopyPaste
def copypaste_collate_fn(batch):
copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
return copypaste(*utils.collate_fn(batch))
def get_dataset(is_train, args):
image_set = "train" if is_train else "val"
num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
with_masks = "mask" in args.model
ds = get_coco(
root=args.data_path,
image_set=image_set,
transforms=get_transform(is_train, args),
mode=mode,
use_v2=args.use_v2,
with_masks=with_masks,
)
return ds, num_classes
def get_transform(train, data_augmentation):
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()
def get_transform(is_train, args):
if is_train:
return presets.DetectionPresetTrain(
data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
return lambda img, target: (trans(img), target)
else:
return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description='PyTorch Detection Training', add_help=add_help)
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
parser.add_argument('--dataset', default='coco', help='dataset')
parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('-b', '--batch-size', default=2, type=int,
help='images per gpu, the total batch size is $NGPU x batch_size')
parser.add_argument('--epochs', default=26, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--lr', default=0.02, type=float,
help='initial learning rate, 0.02 is the default value for training '
'on 8 gpus and 2 images_per_gpu')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
parser.add_argument('--lr-step-size', default=8, type=int,
help='decrease lr every step-size epochs (multisteplr scheduler only)')
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
help='decrease lr every step-size epochs (multisteplr scheduler only)')
parser.add_argument('--lr-gamma', default=0.1, type=float,
help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')
parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
help='number of trainable layers of backbone')
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)
parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
parser.add_argument(
"--dataset",
default="coco",
type=str,
help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
)
parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
"-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
)
parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument(
"-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
)
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument(
"--lr",
default=0.02,
type=float,
help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
)
parser.add_argument(
"--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
)
parser.add_argument(
"--lr-steps",
default=[16, 22],
nargs="+",
type=int,
help="decrease lr every step-size epochs (multisteplr scheduler only)",
)
parser.add_argument(
"--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
)
parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
parser.add_argument("--aspect-ratio-group-factor", default=3, type=int)
parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
parser.add_argument(
"--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
)
parser.add_argument(
"--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
)
parser.add_argument(
"--sync-bn",
dest="sync_bn",
......@@ -101,22 +152,43 @@ def get_args_parser(add_help=True):
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
# Use CopyPaste augmentation training parameter
parser.add_argument(
"--use-copypaste",
action="store_true",
help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
)
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser
def main(args):
if args.backend.lower() == "tv_tensor" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the tv_tensor backend.")
if args.dataset not in ("coco", "coco_kp"):
raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
if "keypoint" in args.model and args.dataset != "coco_kp":
raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
if args.dataset == "coco_kp" and args.use_v2:
raise ValueError("KeyPoint detection doesn't support V2 transforms yet")
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -125,17 +197,19 @@ def main(args):
device = torch.device(args.device)
if args.use_deterministic_algorithms:
torch.use_deterministic_algorithms(True)
# Data loading code
print("Loading data")
dataset, num_classes = get_dataset(is_train=True, args=args)
dataset_test, _ = get_dataset(is_train=False, args=args)
print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
......@@ -144,27 +218,33 @@ def main(args):
group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
else:
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
train_collate_fn = utils.collate_fn
if args.use_copypaste:
if args.data_augmentation != "lsj":
raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")
train_collate_fn = copypaste_collate_fn
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)
print("Creating model")
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
if args.data_augmentation in ["multiscale", "lsj"]:
kwargs["_skip_resize"] = True
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.get_model(
args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......@@ -174,27 +254,50 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.norm_weight_decay is None:
parameters = [p for p in model.parameters() if p.requires_grad]
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
scaler = torch.cuda.amp.GradScaler() if args.amp else None
args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == "multisteplr":
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
elif args.lr_scheduler == "cosineannealinglr":
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
else:
raise RuntimeError(
f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
)
if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
if args.amp:
scaler.load_state_dict(checkpoint["scaler"])
if args.test_only:
torch.backends.cudnn.deterministic = True
evaluate(model, data_loader_test, device=device)
return
......@@ -203,29 +306,27 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
lr_scheduler.step()
if args.output_dir:
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"args": args,
"epoch": epoch,
}
if args.amp:
checkpoint["scaler"] = scaler.state_dict()
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
# evaluate after every epoch
evaluate(model, data_loader_test, device=device)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")
if __name__ == "__main__":
......
from typing import Dict, List, Optional, Tuple, Union
import torch
import torchvision
from torch import nn, Tensor
from torchvision import ops
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
def _flip_coco_person_keypoints(kps, width):
......@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width):
return flipped_data
class Compose:
def __init__(self, transforms):
self.transforms = transforms
......@@ -28,12 +28,13 @@ class Compose(object):
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
_, _, width = F.get_dimensions(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
......@@ -44,16 +45,39 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
return image, target
class PILToTensor(nn.Module):
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.pil_to_tensor(image)
return image, target
class ToDtype(nn.Module):
def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
super().__init__()
self.dtype = dtype
self.scale = scale
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target
class RandomIoUCrop(nn.Module):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
......@@ -65,18 +89,19 @@ class RandomIoUCrop(nn.Module):
self.options = sampler_options
self.trials = trials
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if target is None:
raise ValueError("The targets can't be None for this transform.")
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
_, orig_h, orig_w = F.get_dimensions(image)
while True:
# sample an option
......@@ -112,8 +137,9 @@ class RandomIoUCrop(nn.Module):
# check at least 1 box with jaccard limitations
boxes = target["boxes"][is_within_crop_area]
ious = torchvision.ops.boxes.box_iou(
boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
)
if ious.max() < min_jaccard_overlap:
continue
......@@ -130,14 +156,16 @@ class RandomIoUCrop(nn.Module):
class RandomZoomOut(nn.Module):
def __init__(
self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
):
super().__init__()
if fill is None:
fill = [0.0, 0.0, 0.0]
self.fill = fill
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p
@torch.jit.unused
......@@ -146,18 +174,19 @@ class RandomZoomOut(nn.Module):
# We fake the type to make it work on JIT
return tuple(int(x) for x in self.fill) if is_pil else 0
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
if torch.rand(1) >= self.p:
return image, target
_, orig_h, orig_w = F.get_dimensions(image)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
......@@ -176,9 +205,11 @@ class RandomZoomOut(nn.Module):
image = F.pad(image, [left, top, right, bottom], fill=fill)
if isinstance(image, torch.Tensor):
# PyTorch's pad supports only integers on fill. So we need to overwrite the colour
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
..., :, (left + orig_w) :
] = v
if target is not None:
target["boxes"][:, 0::2] += left
......@@ -188,8 +219,14 @@ class RandomZoomOut(nn.Module):
class RandomPhotometricDistort(nn.Module):
def __init__(
self,
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
brightness: Tuple[float, float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__()
self._brightness = T.ColorJitter(brightness=brightness)
self._contrast = T.ColorJitter(contrast=contrast)
......@@ -197,11 +234,12 @@ class RandomPhotometricDistort(nn.Module):
self._saturation = T.ColorJitter(saturation=saturation)
self.p = p
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
......@@ -226,14 +264,338 @@ class RandomPhotometricDistort(nn.Module):
image = self._contrast(image)
if r[6] < self.p:
channels, _, _ = F.get_dimensions(image)
permutation = torch.randperm(channels)
is_pil = F._is_pil_image(image)
if is_pil:
image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
image = image[..., permutation, :, :]
if is_pil:
image = F.to_pil_image(image)
return image, target
class ScaleJitter(nn.Module):
"""Randomly resizes the image and its bounding boxes within the specified scale range.
The class implements the Scale Jitter augmentation as described in the paper
`"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
Args:
target_size (tuple of ints): The target size for the transform provided in (height, width) format.
scale_range (tuple of floats): scaling factor interval, e.g. (a, b), then scale is randomly sampled from the
range a <= scale <= b.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
"""
def __init__(
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias=True,
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = interpolation
self.antialias = antialias
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
_, orig_height, orig_width = F.get_dimensions(image)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)
if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"],
[new_height, new_width],
interpolation=InterpolationMode.NEAREST,
antialias=self.antialias,
)
return image, target
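# Hedged usage sketch (illustration only, not taken from the reference presets): for the "lsj"
# policy, ScaleJitter is typically paired with a fixed-size crop, e.g.
#     scale_jitter = ScaleJitter(target_size=(1024, 1024), scale_range=(0.1, 2.0))
#     image, target = scale_jitter(image, target)
# The (1024, 1024) target size here is an assumed value for illustration.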
class FixedSizeCrop(nn.Module):
def __init__(self, size, fill=0, padding_mode="constant"):
super().__init__()
size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.padding_mode = padding_mode
def _pad(self, img, target, padding):
# Taken from the functional_tensor.py pad
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
padding = [pad_left, pad_top, pad_right, pad_bottom]
img = F.pad(img, padding, self.fill, self.padding_mode)
if target is not None:
target["boxes"][:, 0::2] += pad_left
target["boxes"][:, 1::2] += pad_top
if "masks" in target:
target["masks"] = F.pad(target["masks"], padding, 0, "constant")
return img, target
def _crop(self, img, target, top, left, height, width):
img = F.crop(img, top, left, height, width)
if target is not None:
boxes = target["boxes"]
boxes[:, 0::2] -= left
boxes[:, 1::2] -= top
boxes[:, 0::2].clamp_(min=0, max=width)
boxes[:, 1::2].clamp_(min=0, max=height)
is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
target["boxes"] = boxes[is_valid]
target["labels"] = target["labels"][is_valid]
if "masks" in target:
target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)
return img, target
def forward(self, img, target=None):
_, height, width = F.get_dimensions(img)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
if new_height != height or new_width != width:
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)
r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)
img, target = self._crop(img, target, top, left, new_height, new_width)
pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
if pad_bottom != 0 or pad_right != 0:
img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
return img, target
class RandomShortestSize(nn.Module):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
_, orig_height, orig_width = F.get_dimensions(image)
min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
)
return image, target
def _copy_paste(
image: torch.Tensor,
target: Dict[str, Tensor],
paste_image: torch.Tensor,
paste_target: Dict[str, Tensor],
blending: bool = True,
resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
# Random paste targets selection:
num_masks = len(paste_target["masks"])
if num_masks < 1:
# Such a degenerate case with num_masks=0 can happen with LSJ
# Let's just return (image, target)
return image, target
# We have to please torch script by explicitly specifying dtype as torch.long
random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
random_selection = torch.unique(random_selection).to(torch.long)
paste_masks = paste_target["masks"][random_selection]
paste_boxes = paste_target["boxes"][random_selection]
paste_labels = paste_target["labels"][random_selection]
masks = target["masks"]
# We resize source and paste data if they have different sizes
# This is something we introduced here as originally the algorithm works
# on equal-sized data (for example, coming from LSJ data augmentations)
size1 = image.shape[-2:]
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
# resize bboxes:
ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
paste_alpha_mask = paste_masks.sum(dim=0) > 0
if blending:
paste_alpha_mask = F.gaussian_blur(
paste_alpha_mask.unsqueeze(0),
kernel_size=(5, 5),
sigma=[
2.0,
],
)
# Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
# Copy-paste masks:
masks = masks * (~paste_alpha_mask)
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]
# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}
out_target["masks"] = torch.cat([masks, paste_masks])
# Copy-paste boxes and labels
boxes = ops.masks_to_boxes(masks)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])
# Update additional optional keys: area and iscrowd if exist
if "area" in target:
out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
if "iscrowd" in target and "iscrowd" in paste_target:
# target['iscrowd'] size can differ from mask size (non_all_zero_masks),
# for example if a previous transform geometrically modifies masks/boxes/labels but
# does not update "iscrowd"
if len(target["iscrowd"]) == len(non_all_zero_masks):
iscrowd = target["iscrowd"][non_all_zero_masks]
paste_iscrowd = paste_target["iscrowd"][random_selection]
out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
# Check for degenerated boxes and remove them
boxes = out_target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)
out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]
if "area" in out_target:
out_target["area"] = out_target["area"][valid_targets]
if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
return image, out_target
class SimpleCopyPaste(torch.nn.Module):
def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
super().__init__()
self.resize_interpolation = resize_interpolation
self.blending = blending
def forward(
self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
torch._assert(
isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
"images should be a list of tensors",
)
torch._assert(
isinstance(targets, (list, tuple)) and len(images) == len(targets),
"targets should be a list of the same size as images",
)
for target in targets:
# Cannot check for instance type dict inside torch.jit.script
# torch._assert(isinstance(target, dict), "targets item should be a dict")
for k in ["masks", "boxes", "labels"]:
torch._assert(k in target, f"Key {k} should be present in targets")
torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")
# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
# paste_images = [t2, t3, ..., tN, t1]
# FYI: in TF they mix data on the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]
output_images: List[torch.Tensor] = []
output_targets: List[Dict[str, Tensor]] = []
for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
output_image, output_data = _copy_paste(
image,
target,
paste_image,
paste_target,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
)
output_images.append(output_image)
output_targets.append(output_data)
return output_images, output_targets
def __repr__(self) -> str:
s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
return s
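# Hedged sketch of how SimpleCopyPaste is typically wired into the training data loader as a
# collate_fn; the exact wrapper lives in the training script and may differ slightly:
#     copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
#     def copypaste_collate_fn(batch):
#         return copypaste(*utils.collate_fn(batch))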
import datetime
import errno
import os
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
......@@ -32,7 +32,7 @@ class SmoothedValue(object):
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
......@@ -63,11 +63,8 @@ class SmoothedValue(object):
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
def all_gather(data):
......@@ -98,7 +95,7 @@ def reduce_dict(input_dict, average=True):
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.inference_mode():
names = []
values = []
# sort the keys so that they are consistent across processes
......@@ -113,7 +110,7 @@ def reduce_dict(input_dict, average=True):
return reduced_dict
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
......@@ -130,15 +127,12 @@ class MetricLogger(object):
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
......@@ -151,31 +145,28 @@ class MetricLogger(object):
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
......@@ -185,39 +176,34 @@ class MetricLogger(object):
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
def collate_fn(batch):
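# Detection batches contain images of different sizes and per-image target dicts, so they cannot
# be stacked into tensors; zip(*batch) regroups the batch into (images, targets) tuples, which the
# detection models consume directly as lists.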
return tuple(zip(*batch))
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
def f(x):
if x >= warmup_iters:
return 1
alpha = float(x) / warmup_iters
return warmup_factor * (1 - alpha) + alpha
return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
def mkdir(path):
try:
os.makedirs(path)
......@@ -231,10 +217,11 @@ def setup_for_distributed(is_master):
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
......@@ -271,25 +258,25 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
# Optical flow reference training scripts
This folder contains reference training scripts for optical flow.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.
### RAFT Large
The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
rest of the hyper-parameters are exactly the same as the original RAFT training
recipe from https://github.com/princeton-vl/RAFT. The original recipe trains for
100000 updates (or steps) on each dataset - this corresponds to about 72 and 20
epochs on Chairs and Things respectively:
```
num_epochs = ceil(num_steps / number_of_steps_per_epoch)
= ceil(num_steps / (num_samples / effective_batch_size))
```
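As a rough sanity check, the step-to-epoch conversion can be reproduced with a few lines of Python (a hedged sketch: the ~22.2k training-pair count used for Flying Chairs below is an assumption for illustration, not something stated in this README):

```python
from math import ceil

def steps_to_epochs(num_steps: int, num_samples: int, effective_batch_size: int) -> int:
    # number_of_steps_per_epoch = num_samples / effective_batch_size
    return ceil(num_steps / (num_samples / effective_batch_size))

# 8 GPUs x batch size 2 gives an effective batch size of 16, as in the recipe above
print(steps_to_epochs(num_steps=100_000, num_samples=22_232, effective_batch_size=16))  # -> 72
```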
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_chairs \
--model raft_large \
--train-dataset chairs \
--batch-size 2 \
--lr 0.0004 \
--weight-decay 0.0001 \
--epochs 72 \
--output-dir $chairs_dir
```
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model raft_large \
--train-dataset things \
--batch-size 2 \
--lr 0.000125 \
--weight-decay 0.0001 \
--epochs 20 \
--freeze-batch-norm \
--output-dir $things_dir \
--resume $chairs_dir/$name_chairs.pth
```
### Evaluation
```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2
```
This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
final pass of Sintel-train. Results may vary slightly depending on the batch
size and the number of GPUs. For the most accurate results use 1 GPU and
`--batch-size 1`:
```
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
```
You can also evaluate on Kitti train:
```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2
Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679
```
import torch
import transforms as T
class OpticalFlowPresetEval(torch.nn.Module):
def __init__(self):
super().__init__()
self.transforms = T.Compose(
[
T.PILToTensor(),
T.ConvertImageDtype(torch.float32),
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
T.ValidateModelInput(),
]
)
def forward(self, img1, img2, flow, valid):
return self.transforms(img1, img2, flow, valid)
class OpticalFlowPresetTrain(torch.nn.Module):
def __init__(
self,
*,
# RandomResizeAndCrop params
crop_size,
min_scale=-0.2,
max_scale=0.5,
stretch_prob=0.8,
# AsymmetricColorJitter params
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.5 / 3.14,
# Random[H,V]Flip params
asymmetric_jitter_prob=0.2,
do_flip=True,
):
super().__init__()
transforms = [
T.PILToTensor(),
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.RandomResizeAndCrop(
crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
),
]
if do_flip:
transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]
transforms += [
T.ConvertImageDtype(torch.float32),
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
T.RandomErasing(max_erase=2),
T.MakeValidFlowMask(),
T.ValidateModelInput(),
]
self.transforms = T.Compose(transforms)
def forward(self, img1, img2, flow, valid):
return self.transforms(img1, img2, flow, valid)
import argparse
import warnings
from math import ceil
from pathlib import Path
import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
def get_train_dataset(stage, dataset_root):
if stage == "chairs":
transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
elif stage == "things":
transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
elif stage == "sintel_SKH": # S + K + H as from paper
crop_size = (368, 768)
transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)
things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)
kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)
hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)
# As future improvement, we could probably be using a distributed sampler here
# The distribution is S(.71), T(.135), K(.135), H(.02)
return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
elif stage == "kitti":
transforms = OpticalFlowPresetTrain(
# resize and crop params
crop_size=(288, 960),
min_scale=-0.2,
max_scale=0.4,
stretch_prob=0,
# flip params
do_flip=False,
# jitter params
brightness=0.3,
contrast=0.3,
saturation=0.3,
hue=0.3 / 3.14,
asymmetric_jitter_prob=0,
)
return KittiFlow(root=dataset_root, split="train", transforms=transforms)
else:
raise ValueError(f"Unknown stage {stage}")
@torch.no_grad()
def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp, and process the rest on a single worker.
"""
batch_size = batch_size or args.batch_size
device = torch.device(args.device)
model.eval()
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
else:
sampler = torch.utils.data.SequentialSampler(val_dataset)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
batch_size=batch_size,
pin_memory=True,
num_workers=args.workers,
)
num_flow_updates = num_flow_updates or args.num_flow_updates
def inner_loop(blob):
if blob[0].dim() == 3:
# input is not batched, so we add an extra dim for consistency
blob = [x[None, :, :, :] if x is not None else None for x in blob]
image1, image2, flow_gt = blob[:3]
valid_flow_mask = None if len(blob) == 3 else blob[-1]
image1, image2 = image1.to(device), image2.to(device)
padder = utils.InputPadder(image1.shape, mode=padder_mode)
image1, image2 = padder.pad(image1, image2)
flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
flow_pred = flow_predictions[-1]
flow_pred = padder.unpad(flow_pred).cpu()
metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)
# We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
# per-pixel epe: average epe of all pixels of all images
# per-image epe: average epe on each image independently, then average over images
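# (epe, the end-point error, is the Euclidean distance between predicted and ground-truth flow at
# a pixel; the 1px/3px/5px metrics are the fractions of pixels whose epe is below that threshold.)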
for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper
logger.meters[name].update(metrics[name], n=num_pixels_tot)
logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)
logger = utils.MetricLogger()
for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
num_processed_samples = 0
for blob in logger.log_every(val_loader, header=header, print_freq=None):
inner_loop(blob)
num_processed_samples += blob[0].shape[0] # batch size
if args.distributed:
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print(
f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
"Going to process the remaining samples individually, if any."
)
if args.rank == 0: # we only need to process the rest on a single worker
for i in range(num_processed_samples, len(val_dataset)):
inner_loop(val_dataset[i])
logger.synchronize_between_processes()
print(header, logger)
def evaluate(model, args):
val_datasets = args.val_dataset or []
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
def preprocessing(img1, img2, flow, valid_flow_mask):
img1, img2 = trans(img1, img2)
if flow is not None and not isinstance(flow, torch.Tensor):
flow = torch.from_numpy(flow)
if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor):
valid_flow_mask = torch.from_numpy(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
else:
preprocessing = OpticalFlowPresetEval()
for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes, so we need to pad them individually; we can't batch them.
# see comment in InputPadder
if args.batch_size != 1 and (not args.distributed or args.rank == 0):
warnings.warn(
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
_evaluate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
_evaluate(
model,
args,
val_dataset,
num_flow_updates=32,
padder_mode="sintel",
header=f"Sintel val {pass_name}",
)
else:
warnings.warn(f"Can't validate on {name}, skipping.")
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
device = torch.device(args.device)
for data_blob in logger.log_every(train_loader):
optimizer.zero_grad()
image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)
metrics.pop("f1")
logger.update(loss=loss, **metrics)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
optimizer.step()
scheduler.step()
def main(args):
utils.setup_ddp(args)
args.test_only = args.train_dataset is None
if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)
if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True
model = torchvision.models.get_model(args.model, weights=args.weights)
if args.distributed:
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
model_without_ddp = model.module
else:
model.to(device)
model_without_ddp = model
if args.resume is not None:
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
if args.test_only:
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, args)
return
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
epochs=args.epochs,
steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)
if args.resume is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
else:
args.start_epoch = 0
torch.backends.cudnn.benchmark = True
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
else:
sampler = torch.utils.data.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.workers,
)
logger = utils.MetricLogger()
done = False
for epoch in range(args.start_epoch, args.epochs):
print(f"EPOCH {epoch}")
if args.distributed:
# needed on distributed mode, otherwise the data loading order would be the same for all epochs
sampler.set_epoch(epoch)
train_one_epoch(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
logger=logger,
args=args,
)
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {epoch} done. ", logger)
if not args.distributed or args.rank == 0:
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"epoch": epoch,
"args": args,
}
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
if epoch % args.val_freq == 0 or done:
evaluate(model, args)
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
parser.add_argument(
"--name",
default="raft",
type=str,
help="The name of the experiment - determines the name of the files where weights are saved.",
)
parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
parser.add_argument(
"--resume",
type=str,
help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
)
parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
parser.add_argument(
"--train-dataset",
type=str,
help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
)
parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")
parser.add_argument(
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)
parser.add_argument(
"--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
)
# TODO: resume and weights should be in an exclusive arg group
parser.add_argument(
"--num_flow_updates",
type=int,
default=12,
help="number of updates (or 'iters') in the update operator of the model.",
)
parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")
parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")
parser.add_argument(
"--dataset-root",
help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
required=True,
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
main(args)
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, img1, img2, flow, valid_flow_mask):
if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
raise TypeError("This method expects the tensors img1, img2 and flow to be of dtype torch.float32.")
if img1.shape != img2.shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = img1.shape[-2:]
if flow is not None and flow.shape != (2, h, w):
raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
if valid_flow_mask is not None:
if valid_flow_mask.shape != (h, w):
raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
if valid_flow_mask.dtype != torch.bool:
raise TypeError(f"valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")
return img1, img2, flow, valid_flow_mask
class MakeValidFlowMask(torch.nn.Module):
# This transform generates a valid_flow_mask if it doesn't exist.
# The flow is considered valid if ||flow||_inf < threshold
# This is a noop for Kitti and HD1K which already come with a built-in flow mask.
def __init__(self, threshold=1000):
super().__init__()
self.threshold = threshold
def forward(self, img1, img2, flow, valid_flow_mask):
if flow is not None and valid_flow_mask is None:
valid_flow_mask = (flow.abs() < self.threshold).all(axis=0)
return img1, img2, flow, valid_flow_mask
class ConvertImageDtype(torch.nn.Module):
def __init__(self, dtype):
super().__init__()
self.dtype = dtype
def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.convert_image_dtype(img1, dtype=self.dtype)
img2 = F.convert_image_dtype(img2, dtype=self.dtype)
img1 = img1.contiguous()
img2 = img2.contiguous()
return img1, img2, flow, valid_flow_mask
class Normalize(torch.nn.Module):
def __init__(self, mean, std):
super().__init__()
self.mean = mean
self.std = std
def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.normalize(img1, mean=self.mean, std=self.std)
img2 = F.normalize(img2, mean=self.mean, std=self.std)
return img1, img2, flow, valid_flow_mask
class PILToTensor(torch.nn.Module):
# Converts all inputs to tensors
# Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
# for consistency with the rest, e.g. the segmentation reference.
def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.pil_to_tensor(img1)
img2 = F.pil_to_tensor(img2)
if flow is not None:
flow = torch.from_numpy(flow)
if valid_flow_mask is not None:
valid_flow_mask = torch.from_numpy(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
class AsymmetricColorJitter(T.ColorJitter):
# p determines the probability of doing asymmetric vs symmetric color jittering
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
self.p = p
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) < self.p:
# asymmetric: different transform for img1 and img2
img1 = super().forward(img1)
img2 = super().forward(img2)
else:
# symmetric: same transform for img1 and img2
batch = torch.stack([img1, img2])
batch = super().forward(batch)
img1, img2 = batch[0], batch[1]
return img1, img2, flow, valid_flow_mask
class RandomErasing(T.RandomErasing):
# This only erases img2, and with an extra max_erase param
# This max_erase is needed because the RAFT training ref does:
# 0 erasing with .5 proba
# 1 erase with .25 proba
# 2 erase with .25 proba
# and there's no accurate way to achieve this otherwise.
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
self.max_erase = max_erase
if self.max_erase <= 0:
raise ValueError("max_erase should be greater than 0")
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask
for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value])
img2 = F.erase(img2, x, y, h, w, v, self.inplace)
return img1, img2, flow, valid_flow_mask
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask
img1 = F.hflip(img1)
img2 = F.hflip(img2)
flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
if valid_flow_mask is not None:
valid_flow_mask = F.hflip(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
class RandomVerticalFlip(T.RandomVerticalFlip):
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask
img1 = F.vflip(img1)
img2 = F.vflip(img2)
flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
if valid_flow_mask is not None:
valid_flow_mask = F.vflip(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
class RandomResizeAndCrop(torch.nn.Module):
# This transform will resize the input with a given proba, and then crop it.
# These are the reversed operations of the built-in RandomResizedCrop,
# although the order of the operations doesn't matter too much: resizing a
# crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
#
# The reason we don't rely on RandomResizedCrop is because of a significant
# difference in the parametrization of both transforms, in particular,
# because of the way the random parameters are sampled in both transforms,
# which leads to fairly different results (and different epe). For more details see
# https://github.com/pytorch/vision/pull/5026/files#r762932579
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
super().__init__()
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.stretch_prob = stretch_prob
self.resize_prob = 0.8
self.max_stretch = 0.2
def forward(self, img1, img2, flow, valid_flow_mask):
# randomly sample scale
h, w = img1.shape[-2:]
        # Note: in the original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
# It shouldn't matter much
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
scale_x = scale
scale_y = scale
if torch.rand(1) < self.stretch_prob:
scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
scale_x = max(scale_x, min_scale)
scale_y = max(scale_y, min_scale)
new_h, new_w = round(h * scale_y), round(w * scale_x)
if torch.rand(1).item() < self.resize_prob:
# rescale the images
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the OF models with antialias=True?
img1 = F.resize(img1, size=(new_h, new_w), antialias=False)
img2 = F.resize(img2, size=(new_h, new_w), antialias=False)
if valid_flow_mask is None:
flow = F.resize(flow, size=(new_h, new_w))
flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
else:
flow, valid_flow_mask = self._resize_sparse_flow(
flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
)
# Note: For sparse datasets (Kitti), the original code uses a "margin"
# See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
# We don't, not sure if it matters much
y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()
img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1])
img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1])
flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1])
if valid_flow_mask is not None:
valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1])
return img1, img2, flow, valid_flow_mask
def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
        # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse)
        # There are as many non-zero values in the original flow as in the resized flow (up to out-of-bounds values)
        # So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4
h, w = flow.shape[-2:]
h_new = int(round(h * scale_y))
w_new = int(round(w * scale_x))
flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)
jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")
ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]
ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)
within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)
ii_valid = ii_valid[within_bounds_mask]
jj_valid = jj_valid[within_bounds_mask]
ii_valid_new = ii_valid_new[within_bounds_mask]
jj_valid_new = jj_valid_new[within_bounds_mask]
valid_flow_new = flow[:, ii_valid, jj_valid]
valid_flow_new[0] *= scale_x
valid_flow_new[1] *= scale_y
flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
valid_new[ii_valid_new, jj_valid_new] = 1
return flow_new, valid_new
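def _demo_resize_sparse_flow():
    # Illustrative sketch only (not part of the reference script); shapes and scales
    # are hypothetical. Upscaling by 2 in both directions keeps the same number of
    # valid flow vectors, spread over a 4x larger map, and scales their values.
    flow = torch.zeros(2, 8, 8)
    valid = torch.zeros(8, 8, dtype=torch.bool)
    valid[1, 2] = valid[3, 5] = valid[6, 0] = True
    flow[:, valid] = 1.0
    t = RandomResizeAndCrop(crop_size=(8, 8))
    flow_new, valid_new = t._resize_sparse_flow(flow, valid, scale_x=2.0, scale_y=2.0)
    assert flow_new.shape == (2, 16, 16)
    assert int(valid_new.sum()) == int(valid.sum())  # still 3 valid vectors
    assert flow_new[:, 2, 4].tolist() == [2.0, 2.0]  # (1, 2) -> (2, 4), values scaled by 2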
class Compose(torch.nn.Module):
def __init__(self, transforms):
super().__init__()
self.transforms = transforms
def forward(self, img1, img2, flow, valid_flow_mask):
for t in self.transforms:
img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
return img1, img2, flow, valid_flow_mask
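def _demo_compose_pipeline():
    # Illustrative sketch only; the transform choices and parameter values below are
    # hypothetical and are not the values used by the torchvision training presets.
    pipeline = Compose(
        [
            RandomHorizontalFlip(p=0.5),
            RandomResizeAndCrop(crop_size=(96, 128), min_scale=-0.2, max_scale=0.5),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    img1 = torch.rand(3, 240, 320)
    img2 = torch.rand(3, 240, 320)
    flow = torch.randn(2, 240, 320)
    img1, img2, flow, valid = pipeline(img1, img2, flow, None)
    assert img1.shape == (3, 96, 128) and flow.shape == (2, 96, 128)  # valid stays None for dense datasets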
import datetime
import os
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
import torch.nn.functional as F
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
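def _demo_smoothed_value():
    # Illustrative sketch only (not part of the reference script): track a noisy
    # series and read back window statistics vs. the global average.
    meter = SmoothedValue(window_size=4, fmt="{median:.2f} ({global_avg:.2f})")
    for v in [1.0, 2.0, 3.0, 4.0, 5.0]:
        meter.update(v)
    assert meter.global_avg == 3.0  # average over all 5 values
    assert meter.max == 5.0 and meter.value == 5.0
    print(meter)  # "3.00 (3.00)": median of the last 4 values, then the global average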
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, **kwargs):
self.meters[name] = SmoothedValue(**kwargs)
def log_every(self, iterable, print_freq=5, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if print_freq is not None and i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
flow_norm = (flow_gt**2).sum(dim=1).sqrt()
if valid_flow_mask is not None:
epe = epe[valid_flow_mask]
flow_norm = flow_norm[valid_flow_mask]
relative_epe = epe / flow_norm
metrics = {
"epe": epe.mean().item(),
"1px": (epe < 1).float().mean().item(),
"3px": (epe < 3).float().mean().item(),
"5px": (epe < 5).float().mean().item(),
"f1": ((epe > 3) & (relative_epe > 0.05)).float().mean().item() * 100,
}
return metrics, epe.numel()
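def _demo_compute_metrics():
    # Illustrative sketch only: a single pixel with a (3, 4) error has an EPE of 5,
    # every other pixel is exact.
    flow_gt = torch.zeros(1, 2, 4, 4)
    flow_gt[:, 0] += 10.0  # ground-truth flow of (10, 0) everywhere
    flow_pred = flow_gt.clone()
    flow_pred[0, 0, 0, 0] += 3.0
    flow_pred[0, 1, 0, 0] += 4.0
    metrics, num_pixels = compute_metrics(flow_pred, flow_gt)
    assert num_pixels == 16
    assert abs(metrics["epe"] - 5.0 / 16) < 1e-6  # mean EPE over the 16 pixels
    assert metrics["3px"] == 15 / 16  # fraction of pixels with EPE below 3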
def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
"""Loss function defined over sequence of flow predictions"""
    if gamma >= 1:
        raise ValueError(f"Gamma should be < 1, got {gamma}.")
    # exclude invalid pixels and extremely large displacements
flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
valid_flow_mask = valid_flow_mask[:, None, :, :]
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
abs_diff = (flow_preds - flow_gt).abs()
abs_diff = (abs_diff * valid_flow_mask).mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0]
weights = gamma ** torch.arange(num_predictions - 1, -1, -1).to(flow_gt.device)
flow_loss = (abs_diff * weights).sum()
return flow_loss
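def _demo_sequence_loss_weights():
    # Illustrative sketch only: with gamma=0.8 and 3 refinement iterations the
    # per-iteration weights are [0.64, 0.8, 1.0], so later (more refined)
    # predictions contribute more to the loss.
    flow_gt = torch.zeros(1, 2, 4, 4)
    valid = torch.ones(1, 4, 4, dtype=torch.bool)
    flow_preds = [flow_gt + 1.0] * 3  # constant 1 px error in each channel
    loss = sequence_loss(flow_preds, flow_gt, valid, gamma=0.8)
    assert abs(loss.item() - (0.64 + 0.8 + 1.0)) < 1e-5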
class InputPadder:
"""Pads images such that dimensions are divisible by 8"""
# TODO: Ideally, this should be part of the eval transforms preset, instead
# of being part of the validation code. It's not obvious what a good
# solution would be, because we need to unpad the predicted flows according
# to the input images' size, and in some datasets (Kitti) images can have
# variable sizes.
def __init__(self, dims, mode="sintel"):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == "sintel":
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
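def _demo_input_padder():
    # Illustrative sketch only: Sintel frames are 436x1024 and 436 is not divisible
    # by 8, so 2 rows are added on each side ("sintel" mode pads symmetrically).
    x = torch.rand(1, 3, 436, 1024)
    padder = InputPadder(x.shape, mode="sintel")
    (x_padded,) = padder.pad(x)
    assert x_padded.shape[-2:] == (440, 1024)  # both dimensions now divisible by 8
    assert padder.unpad(x_padded).shape == x.shape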
def _redefine_print(is_main):
"""disables printing when not in main process"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_main or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def setup_ddp(args):
# Set the local_rank, rank, and world_size values as args fields
# This is done differently depending on how we're running the script. We
# currently support either torchrun or the custom run_with_submitit.py
# If you're confused (like I was), this might help a bit
# https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
if all(key in os.environ for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")):
# if we're here, the script was called with torchrun. Otherwise,
# these args will be set already by the run_with_submitit script
args.local_rank = int(os.environ["LOCAL_RANK"])
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
elif "gpu" in args:
# if we're here, the script was called by run_with_submitit.py
args.local_rank = args.gpu
else:
print("Not using distributed mode!")
args.distributed = False
args.world_size = 1
return
args.distributed = True
_redefine_print(is_main=(args.rank == 0))
torch.cuda.set_device(args.local_rank)
dist.init_process_group(
backend="nccl",
rank=args.rank,
world_size=args.world_size,
init_method=args.dist_url,
)
torch.distributed.barrier()
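def _demo_setup_ddp_args():
    # Illustrative sketch only. When launched with e.g.
    #   torchrun --nproc_per_node=8 train.py ...
    # torchrun exports LOCAL_RANK, RANK and WORLD_SIZE, so setup_ddp takes the first
    # branch above. The hypothetical namespace below only supplies `dist_url`
    # (typically "env://"); without the torchrun environment variables, setup_ddp
    # simply falls back to non-distributed mode.
    import argparse
    args = argparse.Namespace(dist_url="env://")
    setup_ddp(args)
    return args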
def reduce_across_processes(val):
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
def freeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
# Semantic segmentation reference training scripts
This folder contains reference training scripts for semantic segmentation.
They serve as a log of how to train specific models and provide baseline
training and evaluation scripts to quickly bootstrap research.
All models have been trained on 8x V100 GPUs.
You must modify the following flags:
## fcn_resnet50
```
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
## fcn_resnet101
```
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
```
## deeplabv3_resnet50
```
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
## deeplabv3_resnet101
```
torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
```
## deeplabv3_mobilenet_v3_large
```
torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```
## lraspp_mobilenet_v3_large
```
torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```