Commit bf491463 authored by limm's avatar limm
Browse files

add v0.19.1 release

parent e17f5ea2
from typing import List, Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.prototype.models.depth.stereo.raft_stereo import grid_sample, make_coords_grid
def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
"""Function to create a 2D Gaussian kernel."""
x = torch.arange(kernel_size, dtype=torch.float32)
y = torch.arange(kernel_size, dtype=torch.float32)
x = x - (kernel_size - 1) / 2
y = y - (kernel_size - 1) / 2
x, y = torch.meshgrid(x, y)
grid = (x**2 + y**2) / (2 * sigma**2)
kernel = torch.exp(-grid)
kernel = kernel / kernel.sum()
return kernel
def _sequence_loss_fn(
flow_preds: List[Tensor],
flow_gt: Tensor,
valid_flow_mask: Optional[Tensor],
gamma: Tensor,
max_flow: int = 256,
exclude_large: bool = False,
weights: Optional[Tensor] = None,
):
"""Loss function defined over sequence of flow predictions"""
torch._assert(
gamma < 1,
"sequence_loss: `gamma` must be lower than 1, but got {}".format(gamma),
)
if exclude_large:
# exclude invalid pixels and extremely large diplacements
flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
else:
valid_flow_mask = flow_norm < max_flow
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask.unsqueeze(1)
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
abs_diff = (flow_preds - flow_gt).abs()
if valid_flow_mask is not None:
abs_diff = abs_diff * valid_flow_mask.unsqueeze(0)
abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0]
# allocating on CPU and moving to device during run-time can force
# an unwanted GPU synchronization that produces a large overhead
if weights is None or len(weights) != num_predictions:
weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
flow_loss = (abs_diff * weights).sum()
return flow_loss, weights
class SequenceLoss(nn.Module):
def __init__(self, gamma: float = 0.8, max_flow: int = 256, exclude_large_flows: bool = False) -> None:
"""
Args:
gamma: value for the exponential weighting of the loss across frames
max_flow: maximum flow value to exclude
exclude_large_flows: whether to exclude large flows
"""
super().__init__()
self.max_flow = max_flow
self.excluding_large = exclude_large_flows
self.register_buffer("gamma", torch.tensor([gamma]))
# cache the scale factor for the loss
self._weights = None
def forward(self, flow_preds: List[Tensor], flow_gt: Tensor, valid_flow_mask: Optional[Tensor]) -> Tensor:
"""
Args:
flow_preds: list of flow predictions of shape (batch_size, C, H, W)
flow_gt: ground truth flow of shape (batch_size, C, H, W)
valid_flow_mask: mask of valid flow pixels of shape (batch_size, H, W)
"""
loss, weights = _sequence_loss_fn(
flow_preds, flow_gt, valid_flow_mask, self.gamma, self.max_flow, self.excluding_large, self._weights
)
self._weights = weights
return loss
def set_gamma(self, gamma: float) -> None:
self.gamma.fill_(gamma)
# reset the cached scale factor
self._weights = None
def _ssim_loss_fn(
source: Tensor,
reference: Tensor,
kernel: Tensor,
eps: float = 1e-8,
c1: float = 0.01**2,
c2: float = 0.03**2,
use_padding: bool = False,
) -> Tensor:
# ref: Algorithm section: https://en.wikipedia.org/wiki/Structural_similarity
# ref: Alternative implementation: https://kornia.readthedocs.io/en/latest/_modules/kornia/metrics/ssim.html#ssim
torch._assert(
source.ndim == reference.ndim == 4,
"SSIM: `source` and `reference` must be 4-dimensional tensors",
)
torch._assert(
source.shape == reference.shape,
"SSIM: `source` and `reference` must have the same shape, but got {} and {}".format(
source.shape, reference.shape
),
)
B, C, H, W = source.shape
kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(C, 1, 1, 1)
if use_padding:
pad_size = kernel.shape[2] // 2
source = F.pad(source, (pad_size, pad_size, pad_size, pad_size), "reflect")
reference = F.pad(reference, (pad_size, pad_size, pad_size, pad_size), "reflect")
mu1 = F.conv2d(source, kernel, groups=C)
mu2 = F.conv2d(reference, kernel, groups=C)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
mu_img1_sq = F.conv2d(source.pow(2), kernel, groups=C)
mu_img2_sq = F.conv2d(reference.pow(2), kernel, groups=C)
mu_img1_mu2 = F.conv2d(source * reference, kernel, groups=C)
sigma1_sq = mu_img1_sq - mu1_sq
sigma2_sq = mu_img2_sq - mu2_sq
sigma12 = mu_img1_mu2 - mu1_mu2
numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim = numerator / (denominator + eps)
# doing 1 - ssim because we want to maximize the ssim
return 1 - ssim.mean(dim=(1, 2, 3))
class SSIM(nn.Module):
def __init__(
self,
kernel_size: int = 11,
max_val: float = 1.0,
sigma: float = 1.5,
eps: float = 1e-12,
use_padding: bool = True,
) -> None:
"""SSIM loss function.
Args:
kernel_size: size of the Gaussian kernel
max_val: constant scaling factor
sigma: sigma of the Gaussian kernel
eps: constant for division by zero
use_padding: whether to pad the input tensor such that we have a score for each pixel
"""
super().__init__()
self.kernel_size = kernel_size
self.max_val = max_val
self.sigma = sigma
gaussian_kernel = make_gaussian_kernel(kernel_size, sigma)
self.register_buffer("gaussian_kernel", gaussian_kernel)
self.c1 = (0.01 * self.max_val) ** 2
self.c2 = (0.03 * self.max_val) ** 2
self.use_padding = use_padding
self.eps = eps
def forward(self, source: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
"""
Args:
source: source image of shape (batch_size, C, H, W)
reference: reference image of shape (batch_size, C, H, W)
Returns:
SSIM loss of shape (batch_size,)
"""
return _ssim_loss_fn(
source,
reference,
kernel=self.gaussian_kernel,
c1=self.c1,
c2=self.c2,
use_padding=self.use_padding,
eps=self.eps,
)
def _smoothness_loss_fn(img_gx: Tensor, img_gy: Tensor, val_gx: Tensor, val_gy: Tensor):
# ref: https://github.com/nianticlabs/monodepth2/blob/b676244e5a1ca55564eb5d16ab521a48f823af31/layers.py#L202
torch._assert(
img_gx.ndim >= 3,
"smoothness_loss: `img_gx` must be at least 3-dimensional tensor of shape (..., C, H, W)",
)
torch._assert(
img_gx.ndim == val_gx.ndim,
"smoothness_loss: `img_gx` and `depth_gx` must have the same dimensionality, but got {} and {}".format(
img_gx.ndim, val_gx.ndim
),
)
for idx in range(img_gx.ndim):
torch._assert(
(img_gx.shape[idx] == val_gx.shape[idx] or (img_gx.shape[idx] == 1 or val_gx.shape[idx] == 1)),
"smoothness_loss: `img_gx` and `depth_gx` must have either the same shape or broadcastable shape, but got {} and {}".format(
img_gx.shape, val_gx.shape
),
)
# -3 is channel dimension
weights_x = torch.exp(-torch.mean(torch.abs(val_gx), axis=-3, keepdim=True))
weights_y = torch.exp(-torch.mean(torch.abs(val_gy), axis=-3, keepdim=True))
smoothness_x = img_gx * weights_x
smoothness_y = img_gy * weights_y
smoothness = (torch.abs(smoothness_x) + torch.abs(smoothness_y)).mean(axis=(-3, -2, -1))
return smoothness
class SmoothnessLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def _x_gradient(self, img: Tensor) -> Tensor:
if img.ndim > 4:
original_shape = img.shape
is_reshaped = True
img = img.reshape(-1, *original_shape[-3:])
else:
is_reshaped = False
padded = F.pad(img, (0, 1, 0, 0), mode="replicate")
grad = padded[..., :, :-1] - padded[..., :, 1:]
if is_reshaped:
grad = grad.reshape(original_shape)
return grad
def _y_gradient(self, x: torch.Tensor) -> torch.Tensor:
if x.ndim > 4:
original_shape = x.shape
is_reshaped = True
x = x.reshape(-1, *original_shape[-3:])
else:
is_reshaped = False
padded = F.pad(x, (0, 0, 0, 1), mode="replicate")
grad = padded[..., :-1, :] - padded[..., 1:, :]
if is_reshaped:
grad = grad.reshape(original_shape)
return grad
def forward(self, images: Tensor, vals: Tensor) -> Tensor:
"""
Args:
images: tensor of shape (D1, D2, ..., DN, C, H, W)
vals: tensor of shape (D1, D2, ..., DN, 1, H, W)
Returns:
smoothness loss of shape (D1, D2, ..., DN)
"""
img_gx = self._x_gradient(images)
img_gy = self._y_gradient(images)
val_gx = self._x_gradient(vals)
val_gy = self._y_gradient(vals)
return _smoothness_loss_fn(img_gx, img_gy, val_gx, val_gy)
def _flow_sequence_consistency_loss_fn(
flow_preds: List[Tensor],
gamma: float = 0.8,
resize_factor: float = 0.25,
rescale_factor: float = 0.25,
rescale_mode: str = "bilinear",
weights: Optional[Tensor] = None,
):
"""Loss function defined over sequence of flow predictions"""
# Simplified version of ref: https://arxiv.org/pdf/2006.11242.pdf
# In the original paper, an additional refinement network is used to refine a flow prediction.
# Each step performed by the recurrent module in Raft or CREStereo is a refinement step using a delta_flow update.
# which should be consistent with the previous step. In this implementation, we simplify the overall loss
# term and ignore left-right consistency loss or photometric loss which can be treated separately.
torch._assert(
rescale_factor <= 1.0,
"sequence_consistency_loss: `rescale_factor` must be less than or equal to 1, but got {}".format(
rescale_factor
),
)
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
N, B, C, H, W = flow_preds.shape
# rescale flow predictions to account for bilinear upsampling artifacts
if rescale_factor:
flow_preds = (
F.interpolate(
flow_preds.view(N * B, C, H, W), scale_factor=resize_factor, mode=rescale_mode, align_corners=True
)
) * rescale_factor
flow_preds = torch.stack(torch.chunk(flow_preds, N, dim=0), dim=0)
# force the next prediction to be similar to the previous prediction
abs_diff = (flow_preds[1:] - flow_preds[:-1]).square()
abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0] - 1 # because we are comparing differences
if weights is None or len(weights) != num_predictions:
weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
flow_loss = (abs_diff * weights).sum()
return flow_loss, weights
class FlowSequenceConsistencyLoss(nn.Module):
def __init__(
self,
gamma: float = 0.8,
resize_factor: float = 0.25,
rescale_factor: float = 0.25,
rescale_mode: str = "bilinear",
) -> None:
super().__init__()
self.gamma = gamma
self.resize_factor = resize_factor
self.rescale_factor = rescale_factor
self.rescale_mode = rescale_mode
self._weights = None
def forward(self, flow_preds: List[Tensor]) -> Tensor:
"""
Args:
flow_preds: list of tensors of shape (batch_size, C, H, W)
Returns:
sequence consistency loss of shape (batch_size,)
"""
loss, weights = _flow_sequence_consistency_loss_fn(
flow_preds,
gamma=self.gamma,
resize_factor=self.resize_factor,
rescale_factor=self.rescale_factor,
rescale_mode=self.rescale_mode,
weights=self._weights,
)
self._weights = weights
return loss
def set_gamma(self, gamma: float) -> None:
self.gamma.fill_(gamma)
# reset the cached scale factor
self._weights = None
def _psnr_loss_fn(source: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
torch._assert(
source.shape == target.shape,
"psnr_loss: source and target must have the same shape, but got {} and {}".format(source.shape, target.shape),
)
# ref https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
return 10 * torch.log10(max_val**2 / ((source - target).pow(2).mean(axis=(-3, -2, -1))))
class PSNRLoss(nn.Module):
def __init__(self, max_val: float = 256) -> None:
"""
Args:
max_val: maximum value of the input tensor. This refers to the maximum domain value of the input tensor.
"""
super().__init__()
self.max_val = max_val
def forward(self, source: Tensor, target: Tensor) -> Tensor:
"""
Args:
source: tensor of shape (D1, D2, ..., DN, C, H, W)
target: tensor of shape (D1, D2, ..., DN, C, H, W)
Returns:
psnr loss of shape (D1, D2, ..., DN)
"""
# multiply by -1 as we want to maximize the psnr
return -1 * _psnr_loss_fn(source, target, self.max_val)
class FlowPhotoMetricLoss(nn.Module):
def __init__(
self,
ssim_weight: float = 0.85,
ssim_window_size: int = 11,
ssim_max_val: float = 1.0,
ssim_sigma: float = 1.5,
ssim_eps: float = 1e-12,
ssim_use_padding: bool = True,
max_displacement_ratio: float = 0.15,
) -> None:
super().__init__()
self._ssim_loss = SSIM(
kernel_size=ssim_window_size,
max_val=ssim_max_val,
sigma=ssim_sigma,
eps=ssim_eps,
use_padding=ssim_use_padding,
)
self._L1_weight = 1 - ssim_weight
self._SSIM_weight = ssim_weight
self._max_displacement_ratio = max_displacement_ratio
def forward(
self,
source: Tensor,
reference: Tensor,
flow_pred: Tensor,
valid_mask: Optional[Tensor] = None,
):
"""
Args:
source: tensor of shape (B, C, H, W)
reference: tensor of shape (B, C, H, W)
flow_pred: tensor of shape (B, 2, H, W)
valid_mask: tensor of shape (B, H, W) or None
Returns:
photometric loss of shape
"""
torch._assert(
source.ndim == 4,
"FlowPhotoMetricLoss: source must have 4 dimensions, but got {}".format(source.ndim),
)
torch._assert(
reference.ndim == source.ndim,
"FlowPhotoMetricLoss: source and other must have the same number of dimensions, but got {} and {}".format(
source.ndim, reference.ndim
),
)
torch._assert(
flow_pred.shape[1] == 2,
"FlowPhotoMetricLoss: flow_pred must have 2 channels, but got {}".format(flow_pred.shape[1]),
)
torch._assert(
flow_pred.ndim == 4,
"FlowPhotoMetricLoss: flow_pred must have 4 dimensions, but got {}".format(flow_pred.ndim),
)
B, C, H, W = source.shape
flow_channels = flow_pred.shape[1]
max_displacements = []
for dim in range(flow_channels):
shape_index = -1 - dim
max_displacements.append(int(self._max_displacement_ratio * source.shape[shape_index]))
# mask out all pixels that have larger flow than the max flow allowed
max_flow_mask = torch.logical_and(
*[flow_pred[:, dim, :, :] < max_displacements[dim] for dim in range(flow_channels)]
)
if valid_mask is not None:
valid_mask = torch.logical_and(valid_mask, max_flow_mask).unsqueeze(1)
else:
valid_mask = max_flow_mask.unsqueeze(1)
grid = make_coords_grid(B, H, W, device=str(source.device))
resampled_grids = grid - flow_pred
resampled_grids = resampled_grids.permute(0, 2, 3, 1)
resampled_source = grid_sample(reference, resampled_grids, mode="bilinear")
# compute SSIM loss
ssim_loss = self._ssim_loss(resampled_source * valid_mask, source * valid_mask)
l1_loss = (resampled_source * valid_mask - source * valid_mask).abs().mean(axis=(-3, -2, -1))
loss = self._L1_weight * l1_loss + self._SSIM_weight * ssim_loss
return loss.mean()
from typing import Dict, List, Optional, Tuple
from torch import Tensor
AVAILABLE_METRICS = ["mae", "rmse", "epe", "bad1", "bad2", "epe", "1px", "3px", "5px", "fl-all", "relepe"]
def compute_metrics(
flow_pred: Tensor, flow_gt: Tensor, valid_flow_mask: Optional[Tensor], metrics: List[str]
) -> Tuple[Dict[str, float], int]:
for m in metrics:
if m not in AVAILABLE_METRICS:
raise ValueError(f"Invalid metric: {m}. Valid metrics are: {AVAILABLE_METRICS}")
metrics_dict = {}
pixels_diffs = (flow_pred - flow_gt).abs()
# there is no Y flow in Stereo Matching, therefore flow.abs() = flow.pow(2).sum(dim=1).sqrt()
flow_norm = flow_gt.abs()
if valid_flow_mask is not None:
valid_flow_mask = valid_flow_mask.unsqueeze(1)
pixels_diffs = pixels_diffs[valid_flow_mask]
flow_norm = flow_norm[valid_flow_mask]
num_pixels = pixels_diffs.numel()
if "bad1" in metrics:
metrics_dict["bad1"] = (pixels_diffs > 1).float().mean().item()
if "bad2" in metrics:
metrics_dict["bad2"] = (pixels_diffs > 2).float().mean().item()
if "mae" in metrics:
metrics_dict["mae"] = pixels_diffs.mean().item()
if "rmse" in metrics:
metrics_dict["rmse"] = pixels_diffs.pow(2).mean().sqrt().item()
if "epe" in metrics:
metrics_dict["epe"] = pixels_diffs.mean().item()
if "1px" in metrics:
metrics_dict["1px"] = (pixels_diffs < 1).float().mean().item()
if "3px" in metrics:
metrics_dict["3px"] = (pixels_diffs < 3).float().mean().item()
if "5px" in metrics:
metrics_dict["5px"] = (pixels_diffs < 5).float().mean().item()
if "fl-all" in metrics:
metrics_dict["fl-all"] = ((pixels_diffs < 3) & ((pixels_diffs / flow_norm) < 0.05)).float().mean().item() * 100
if "relepe" in metrics:
metrics_dict["relepe"] = (pixels_diffs / flow_norm).mean().item()
return metrics_dict, num_pixels
import torch
def freeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
def unfreeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.train()
import torch.nn.functional as F
class InputPadder:
"""Pads images such that dimensions are divisible by 8"""
# TODO: Ideally, this should be part of the eval transforms preset, instead
# of being part of the validation code. It's not obvious what a good
# solution would be, because we need to unpad the predicted flows according
# to the input images' size, and in some datasets (Kitti) images can have
# variable sizes.
def __init__(self, dims, mode="sintel"):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == "sintel":
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
import os
from typing import List
import numpy as np
import torch
from torch import Tensor
from torchvision.utils import make_grid
@torch.no_grad()
def make_disparity_image(disparity: Tensor):
# normalize image to [0, 1]
disparity = disparity.detach().cpu()
disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min())
return disparity
@torch.no_grad()
def make_disparity_image_pairs(disparity: Tensor, image: Tensor):
disparity = make_disparity_image(disparity)
# image is in [-1, 1], bring it to [0, 1]
image = image.detach().cpu()
image = image * 0.5 + 0.5
return disparity, image
@torch.no_grad()
def make_disparity_sequence(disparities: List[Tensor]):
# convert each disparity to [0, 1]
for idx, disparity_batch in enumerate(disparities):
disparities[idx] = torch.stack(list(map(make_disparity_image, disparity_batch)))
# make the list into a batch
disparity_sequences = torch.stack(disparities)
return disparity_sequences
@torch.no_grad()
def make_pair_grid(*inputs, orientation="horizontal"):
# make a grid of images with the outputs and references side by side
if orientation == "horizontal":
# interleave the outputs and references
canvas = torch.zeros_like(inputs[0])
canvas = torch.cat([canvas] * len(inputs), dim=0)
size = len(inputs)
for idx, inp in enumerate(inputs):
canvas[idx::size, ...] = inp
grid = make_grid(canvas, nrow=len(inputs), padding=16, normalize=True, scale_each=True)
elif orientation == "vertical":
# interleave the outputs and references
canvas = torch.cat(inputs, dim=0)
size = len(inputs)
for idx, inp in enumerate(inputs):
canvas[idx::size, ...] = inp
grid = make_grid(canvas, nrow=len(inputs[0]), padding=16, normalize=True, scale_each=True)
else:
raise ValueError("Unknown orientation: {}".format(orientation))
return grid
@torch.no_grad()
def make_training_sample_grid(
left_images: Tensor,
right_images: Tensor,
disparities: Tensor,
masks: Tensor,
predictions: List[Tensor],
) -> np.ndarray:
# detach images and renormalize to [0, 1]
images_left = left_images.detach().cpu() * 0.5 + 0.5
images_right = right_images.detach().cpu() * 0.5 + 0.5
# detach the disparties and predictions
disparities = disparities.detach().cpu()
predictions = predictions[-1].detach().cpu()
# keep only the first channel of pixels, and repeat it 3 times
disparities = disparities[:, :1, ...].repeat(1, 3, 1, 1)
predictions = predictions[:, :1, ...].repeat(1, 3, 1, 1)
# unsqueeze and repeat the masks
masks = masks.detach().cpu().unsqueeze(1).repeat(1, 3, 1, 1)
# make a grid that will self normalize across the batch
pred_grid = make_pair_grid(images_left, images_right, masks, disparities, predictions, orientation="horizontal")
pred_grid = pred_grid.permute(1, 2, 0).numpy()
pred_grid = (pred_grid * 255).astype(np.uint8)
return pred_grid
@torch.no_grad()
def make_disparity_sequence_grid(predictions: List[Tensor], disparities: Tensor) -> np.ndarray:
# right most we will be adding the ground truth
seq_len = len(predictions) + 1
predictions = list(map(lambda x: x[:, :1, :, :].detach().cpu(), predictions + [disparities]))
sequence = make_disparity_sequence(predictions)
# swap axes to have the in the correct order for each batch sample
sequence = torch.swapaxes(sequence, 0, 1).contiguous().reshape(-1, 1, disparities.shape[-2], disparities.shape[-1])
sequence = make_grid(sequence, nrow=seq_len, padding=16, normalize=True, scale_each=True)
sequence = sequence.permute(1, 2, 0).numpy()
sequence = (sequence * 255).astype(np.uint8)
return sequence
@torch.no_grad()
def make_prediction_image_side_to_side(
predictions: Tensor, disparities: Tensor, valid_mask: Tensor, save_path: str, prefix: str
) -> None:
import matplotlib.pyplot as plt
# normalize the predictions and disparities in [0, 1]
predictions = (predictions - predictions.min()) / (predictions.max() - predictions.min())
disparities = (disparities - disparities.min()) / (disparities.max() - disparities.min())
predictions = predictions * valid_mask
disparities = disparities * valid_mask
predictions = predictions.detach().cpu()
disparities = disparities.detach().cpu()
for idx, (pred, gt) in enumerate(zip(predictions, disparities)):
pred = pred.permute(1, 2, 0).numpy()
gt = gt.permute(1, 2, 0).numpy()
# plot pred and gt side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(pred)
ax[0].set_title("Prediction")
ax[1].imshow(gt)
ax[1].set_title("Ground Truth")
save_name = os.path.join(save_path, "{}_{}.png".format(prefix, idx))
plt.savefig(save_name)
plt.close()
...@@ -22,43 +22,50 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs. ...@@ -22,43 +22,50 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs.
### Faster R-CNN ResNet-50 FPN ### Faster R-CNN ResNet-50 FPN
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\ --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
``` ```
### Faster R-CNN MobileNetV3-Large FPN ### Faster R-CNN MobileNetV3-Large FPN
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\ --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
``` ```
### Faster R-CNN MobileNetV3-Large 320 FPN ### Faster R-CNN MobileNetV3-Large 320 FPN
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\ --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```
### FCOS ResNet-50 FPN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fcos_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1
``` ```
### RetinaNet ### RetinaNet
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco --model retinanet_resnet50_fpn --epochs 26\ --dataset coco --model retinanet_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
``` ```
### SSD300 VGG16 ### SSD300 VGG16
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco --model ssd300_vgg16 --epochs 120\ --dataset coco --model ssd300_vgg16 --epochs 120\
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\ --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
--weight-decay 0.0005 --data-augmentation ssd --weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES
``` ```
### SSDlite320 MobileNetV3-Large ### SSDlite320 MobileNetV3-Large
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\ --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\ --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
--weight-decay 0.00004 --data-augmentation ssdlite --weight-decay 0.00004 --data-augmentation ssdlite
...@@ -67,16 +74,15 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ ...@@ -67,16 +74,15 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
### Mask R-CNN ### Mask R-CNN
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco --model maskrcnn_resnet50_fpn --epochs 26\ --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
``` ```
### Keypoint R-CNN ### Keypoint R-CNN
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ torchrun --nproc_per_node=8 train.py\
--dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\ --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
--lr-steps 36 43 --aspect-ratio-group-factor 3 --lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
``` ```
import json
import tempfile
import numpy as np
import copy import copy
import time import io
import torch from contextlib import redirect_stdout
import torch._six
from pycocotools.cocoeval import COCOeval import numpy as np
from pycocotools.coco import COCO
import pycocotools.mask as mask_util import pycocotools.mask as mask_util
import torch
from collections import defaultdict
import utils import utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
class CocoEvaluator(object): class CocoEvaluator:
def __init__(self, coco_gt, iou_types): def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple)) if not isinstance(iou_types, (list, tuple)):
raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
coco_gt = copy.deepcopy(coco_gt) coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt self.coco_gt = coco_gt
...@@ -36,7 +31,8 @@ class CocoEvaluator(object): ...@@ -36,7 +31,8 @@ class CocoEvaluator(object):
for iou_type in self.iou_types: for iou_type in self.iou_types:
results = self.prepare(predictions, iou_type) results = self.prepare(predictions, iou_type)
coco_dt = loadRes(self.coco_gt, results) if results else COCO() with redirect_stdout(io.StringIO()):
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
coco_eval = self.coco_eval[iou_type] coco_eval = self.coco_eval[iou_type]
coco_eval.cocoDt = coco_dt coco_eval.cocoDt = coco_dt
...@@ -56,18 +52,17 @@ class CocoEvaluator(object): ...@@ -56,18 +52,17 @@ class CocoEvaluator(object):
def summarize(self): def summarize(self):
for iou_type, coco_eval in self.coco_eval.items(): for iou_type, coco_eval in self.coco_eval.items():
print("IoU metric: {}".format(iou_type)) print(f"IoU metric: {iou_type}")
coco_eval.summarize() coco_eval.summarize()
def prepare(self, predictions, iou_type): def prepare(self, predictions, iou_type):
if iou_type == "bbox": if iou_type == "bbox":
return self.prepare_for_coco_detection(predictions) return self.prepare_for_coco_detection(predictions)
elif iou_type == "segm": if iou_type == "segm":
return self.prepare_for_coco_segmentation(predictions) return self.prepare_for_coco_segmentation(predictions)
elif iou_type == "keypoints": if iou_type == "keypoints":
return self.prepare_for_coco_keypoint(predictions) return self.prepare_for_coco_keypoint(predictions)
else: raise ValueError(f"Unknown iou type {iou_type}")
raise ValueError("Unknown iou type {}".format(iou_type))
def prepare_for_coco_detection(self, predictions): def prepare_for_coco_detection(self, predictions):
coco_results = [] coco_results = []
...@@ -109,8 +104,7 @@ class CocoEvaluator(object): ...@@ -109,8 +104,7 @@ class CocoEvaluator(object):
labels = prediction["labels"].tolist() labels = prediction["labels"].tolist()
rles = [ rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
for mask in masks
] ]
for rle in rles: for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8") rle["counts"] = rle["counts"].decode("utf-8")
...@@ -146,7 +140,7 @@ class CocoEvaluator(object): ...@@ -146,7 +140,7 @@ class CocoEvaluator(object):
{ {
"image_id": original_id, "image_id": original_id,
"category_id": labels[k], "category_id": labels[k],
'keypoints': keypoint, "keypoints": keypoint,
"score": scores[k], "score": scores[k],
} }
for k, keypoint in enumerate(keypoints) for k, keypoint in enumerate(keypoints)
...@@ -192,161 +186,7 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs): ...@@ -192,161 +186,7 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
coco_eval._paramsEval = copy.deepcopy(coco_eval.params) coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
################################################################# def evaluate(imgs):
# From pycocotools, just removed the prints and fixed with redirect_stdout(io.StringIO()):
# a Python3 bug about unicode not defined imgs.evaluate()
################################################################# return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
# Ideally, pycocotools wouldn't have hard-coded prints
# so that we could avoid copy-pasting those two functions
def createIndex(self):
# create index
# print('creating index...')
anns, cats, imgs = {}, {}, {}
imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann
if 'images' in self.dataset:
for img in self.dataset['images']:
imgs[img['id']] = img
if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat
if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
catToImgs[ann['category_id']].append(ann['image_id'])
# print('index created!')
# create class members
self.anns = anns
self.imgToAnns = imgToAnns
self.catToImgs = catToImgs
self.imgs = imgs
self.cats = cats
maskUtils = mask_util
def loadRes(self, resFile):
"""
Load result file and return a result api object.
Args:
self (obj): coco object with ground truth annotations
resFile (str): file name of result file
Returns:
res (obj): result api object
"""
res = COCO()
res.dataset['images'] = [img for img in self.dataset['images']]
# print('Loading and preparing results...')
# tic = time.time()
if isinstance(resFile, torch._six.string_classes):
anns = json.load(open(resFile))
elif type(resFile) == np.ndarray:
anns = self.loadNumpyAnnotations(resFile)
else:
anns = resFile
assert type(anns) == list, 'results in not an array of objects'
annsImgIds = [ann['image_id'] for ann in anns]
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
'Results do not correspond to current coco set'
if 'caption' in anns[0]:
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
for id, ann in enumerate(anns):
ann['id'] = id + 1
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
bb = ann['bbox']
x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
if 'segmentation' not in ann:
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
ann['area'] = bb[2] * bb[3]
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'segmentation' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
# now only support compressed RLE format as segmentation results
ann['area'] = maskUtils.area(ann['segmentation'])
if 'bbox' not in ann:
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'keypoints' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
s = ann['keypoints']
x = s[0::3]
y = s[1::3]
x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y)
ann['area'] = (x2 - x1) * (y2 - y1)
ann['id'] = id + 1
ann['bbox'] = [x1, y1, x2 - x1, y2 - y1]
# print('DONE (t={:0.2f}s)'.format(time.time()- tic))
res.dataset['annotations'] = anns
createIndex(res)
return res
def evaluate(self):
'''
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
'''
# tic = time.time()
# print('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if p.useSegm is not None:
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
# print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params = p
self._prepare()
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]
if p.iouType == 'segm' or p.iouType == 'bbox':
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
self.ious = {
(imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds
for catId in catIds}
evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
evalImgs = [
evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
# this is NOT in the pycocotools code, but could be done outside
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
self._paramsEval = copy.deepcopy(self.params)
# toc = time.time()
# print('DONE (t={:0.2f}s).'.format(toc-tic))
return p.imgIds, evalImgs
#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
import copy
import os import os
from PIL import Image
import torch import torch
import torch.utils.data import torch.utils.data
import torchvision import torchvision
import transforms as T
from pycocotools import mask as coco_mask from pycocotools import mask as coco_mask
from pycocotools.coco import COCO from pycocotools.coco import COCO
import transforms as T
class FilterAndRemapCocoCategories(object):
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap
def __call__(self, image, target):
anno = target["annotations"]
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
target["annotations"] = anno
return image, target
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
target["annotations"] = anno
return image, target
def convert_coco_poly_to_mask(segmentations, height, width): def convert_coco_poly_to_mask(segmentations, height, width):
masks = [] masks = []
...@@ -47,16 +25,15 @@ def convert_coco_poly_to_mask(segmentations, height, width): ...@@ -47,16 +25,15 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks return masks
class ConvertCocoPolysToMask(object): class ConvertCocoPolysToMask:
def __call__(self, image, target): def __call__(self, image, target):
w, h = image.size w, h = image.size
image_id = target["image_id"] image_id = target["image_id"]
image_id = torch.tensor([image_id])
anno = target["annotations"] anno = target["annotations"]
anno = [obj for obj in anno if obj['iscrowd'] == 0] anno = [obj for obj in anno if obj["iscrowd"] == 0]
boxes = [obj["bbox"] for obj in anno] boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing # guard against no boxes via resizing
...@@ -119,7 +96,7 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None): ...@@ -119,7 +96,7 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
# if all boxes have close to zero area, there is no annotation # if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno): if _has_only_empty_bbox(anno):
return False return False
# keypoints task have a slight different critera for considering # keypoints task have a slight different criteria for considering
# if an annotation is valid # if an annotation is valid
if "keypoints" not in anno[0]: if "keypoints" not in anno[0]:
return True return True
...@@ -129,7 +106,6 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None): ...@@ -129,7 +106,6 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
return True return True
return False return False
assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = [] ids = []
for ds_idx, img_id in enumerate(dataset.ids): for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
...@@ -147,55 +123,56 @@ def convert_to_coco_api(ds): ...@@ -147,55 +123,56 @@ def convert_to_coco_api(ds):
coco_ds = COCO() coco_ds = COCO()
# annotation IDs need to start at 1, not 0, see torchvision issue #1530 # annotation IDs need to start at 1, not 0, see torchvision issue #1530
ann_id = 1 ann_id = 1
dataset = {'images': [], 'categories': [], 'annotations': []} dataset = {"images": [], "categories": [], "annotations": []}
categories = set() categories = set()
for img_idx in range(len(ds)): for img_idx in range(len(ds)):
# find better way to get target # find better way to get target
# targets = ds.get_annotations(img_idx) # targets = ds.get_annotations(img_idx)
img, targets = ds[img_idx] img, targets = ds[img_idx]
image_id = targets["image_id"].item() image_id = targets["image_id"]
img_dict = {} img_dict = {}
img_dict['id'] = image_id img_dict["id"] = image_id
img_dict['height'] = img.shape[-2] img_dict["height"] = img.shape[-2]
img_dict['width'] = img.shape[-1] img_dict["width"] = img.shape[-1]
dataset['images'].append(img_dict) dataset["images"].append(img_dict)
bboxes = targets["boxes"] bboxes = targets["boxes"].clone()
bboxes[:, 2:] -= bboxes[:, :2] bboxes[:, 2:] -= bboxes[:, :2]
bboxes = bboxes.tolist() bboxes = bboxes.tolist()
labels = targets['labels'].tolist() labels = targets["labels"].tolist()
areas = targets['area'].tolist() areas = targets["area"].tolist()
iscrowd = targets['iscrowd'].tolist() iscrowd = targets["iscrowd"].tolist()
if 'masks' in targets: if "masks" in targets:
masks = targets['masks'] masks = targets["masks"]
# make masks Fortran contiguous for coco_mask # make masks Fortran contiguous for coco_mask
masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
if 'keypoints' in targets: if "keypoints" in targets:
keypoints = targets['keypoints'] keypoints = targets["keypoints"]
keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
num_objs = len(bboxes) num_objs = len(bboxes)
for i in range(num_objs): for i in range(num_objs):
ann = {} ann = {}
ann['image_id'] = image_id ann["image_id"] = image_id
ann['bbox'] = bboxes[i] ann["bbox"] = bboxes[i]
ann['category_id'] = labels[i] ann["category_id"] = labels[i]
categories.add(labels[i]) categories.add(labels[i])
ann['area'] = areas[i] ann["area"] = areas[i]
ann['iscrowd'] = iscrowd[i] ann["iscrowd"] = iscrowd[i]
ann['id'] = ann_id ann["id"] = ann_id
if 'masks' in targets: if "masks" in targets:
ann["segmentation"] = coco_mask.encode(masks[i].numpy()) ann["segmentation"] = coco_mask.encode(masks[i].numpy())
if 'keypoints' in targets: if "keypoints" in targets:
ann['keypoints'] = keypoints[i] ann["keypoints"] = keypoints[i]
ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
dataset['annotations'].append(ann) dataset["annotations"].append(ann)
ann_id += 1 ann_id += 1
dataset['categories'] = [{'id': i} for i in sorted(categories)] dataset["categories"] = [{"id": i} for i in sorted(categories)]
coco_ds.dataset = dataset coco_ds.dataset = dataset
coco_ds.createIndex() coco_ds.createIndex()
return coco_ds return coco_ds
def get_coco_api_from_dataset(dataset): def get_coco_api_from_dataset(dataset):
# FIXME: This is... awful?
for _ in range(10): for _ in range(10):
if isinstance(dataset, torchvision.datasets.CocoDetection): if isinstance(dataset, torchvision.datasets.CocoDetection):
break break
...@@ -208,11 +185,11 @@ def get_coco_api_from_dataset(dataset): ...@@ -208,11 +185,11 @@ def get_coco_api_from_dataset(dataset):
class CocoDetection(torchvision.datasets.CocoDetection): class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms): def __init__(self, img_folder, ann_file, transforms):
super(CocoDetection, self).__init__(img_folder, ann_file) super().__init__(img_folder, ann_file)
self._transforms = transforms self._transforms = transforms
def __getitem__(self, idx): def __getitem__(self, idx):
img, target = super(CocoDetection, self).__getitem__(idx) img, target = super().__getitem__(idx)
image_id = self.ids[idx] image_id = self.ids[idx]
target = dict(image_id=image_id, annotations=target) target = dict(image_id=image_id, annotations=target)
if self._transforms is not None: if self._transforms is not None:
...@@ -220,7 +197,7 @@ class CocoDetection(torchvision.datasets.CocoDetection): ...@@ -220,7 +197,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
return img, target return img, target
def get_coco(root, image_set, transforms, mode='instances'): def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
anno_file_template = "{}_{}2017.json" anno_file_template = "{}_{}2017.json"
PATHS = { PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
...@@ -228,17 +205,26 @@ def get_coco(root, image_set, transforms, mode='instances'): ...@@ -228,17 +205,26 @@ def get_coco(root, image_set, transforms, mode='instances'):
# "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
} }
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)
img_folder, ann_file = PATHS[image_set] img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder) img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file) ann_file = os.path.join(root, ann_file)
dataset = CocoDetection(img_folder, ann_file, transforms=transforms) if use_v2:
from torchvision.datasets import wrap_dataset_for_transforms_v2
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
target_keys = ["boxes", "labels", "image_id"]
if with_masks:
target_keys += ["masks"]
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
else:
# TODO: handle with_masks for V1?
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)
dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
if image_set == "train": if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset) dataset = _coco_remove_images_without_annotations(dataset)
...@@ -246,7 +232,3 @@ def get_coco(root, image_set, transforms, mode='instances'): ...@@ -246,7 +232,3 @@ def get_coco(root, image_set, transforms, mode='instances'):
# dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
return dataset return dataset
def get_coco_kp(root, image_set, transforms):
return get_coco(root, image_set, transforms, mode="person_keypoints")
import math import math
import sys import sys
import time import time
import torch
import torch
import torchvision.models.detection.mask_rcnn import torchvision.models.detection.mask_rcnn
from coco_utils import get_coco_api_from_dataset
from coco_eval import CocoEvaluator
import utils import utils
from coco_eval import CocoEvaluator
from coco_utils import get_coco_api_from_dataset
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = 'Epoch: [{}]'.format(epoch) header = f"Epoch: [{epoch}]"
lr_scheduler = None lr_scheduler = None
if epoch == 0: if epoch == 0:
warmup_factor = 1. / 1000 warmup_factor = 1.0 / 1000
warmup_iters = min(1000, len(data_loader) - 1) warmup_iters = min(1000, len(data_loader) - 1)
lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=warmup_factor, total_iters=warmup_iters
)
for images, targets in metric_logger.log_every(data_loader, print_freq, header): for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images) images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets] targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
with torch.cuda.amp.autocast(enabled=scaler is not None):
loss_dict = model(images, targets) loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
losses = sum(loss for loss in loss_dict.values())
# reduce losses over all GPUs for logging purposes # reduce losses over all GPUs for logging purposes
loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced = utils.reduce_dict(loss_dict)
...@@ -38,13 +38,18 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): ...@@ -38,13 +38,18 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
loss_value = losses_reduced.item() loss_value = losses_reduced.item()
if not math.isfinite(loss_value): if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value)) print(f"Loss is {loss_value}, stopping training")
print(loss_dict_reduced) print(loss_dict_reduced)
sys.exit(1) sys.exit(1)
optimizer.zero_grad() optimizer.zero_grad()
losses.backward() if scaler is not None:
optimizer.step() scaler.scale(losses).backward()
scaler.step(optimizer)
scaler.update()
else:
losses.backward()
optimizer.step()
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
...@@ -67,7 +72,7 @@ def _get_iou_types(model): ...@@ -67,7 +72,7 @@ def _get_iou_types(model):
return iou_types return iou_types
@torch.no_grad() @torch.inference_mode()
def evaluate(model, data_loader, device): def evaluate(model, data_loader, device):
n_threads = torch.get_num_threads() n_threads = torch.get_num_threads()
# FIXME remove this and make paste_masks_in_image run on the GPU # FIXME remove this and make paste_masks_in_image run on the GPU
...@@ -75,7 +80,7 @@ def evaluate(model, data_loader, device): ...@@ -75,7 +80,7 @@ def evaluate(model, data_loader, device):
cpu_device = torch.device("cpu") cpu_device = torch.device("cpu")
model.eval() model.eval()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:' header = "Test:"
coco = get_coco_api_from_dataset(data_loader.dataset) coco = get_coco_api_from_dataset(data_loader.dataset)
iou_types = _get_iou_types(model) iou_types = _get_iou_types(model)
...@@ -92,7 +97,7 @@ def evaluate(model, data_loader, device): ...@@ -92,7 +97,7 @@ def evaluate(model, data_loader, device):
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
model_time = time.time() - model_time model_time = time.time() - model_time
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} res = {target["image_id"]: output for target, output in zip(targets, outputs)}
evaluator_time = time.time() evaluator_time = time.time()
coco_evaluator.update(res) coco_evaluator.update(res)
evaluator_time = time.time() - evaluator_time evaluator_time = time.time() - evaluator_time
......
import bisect import bisect
from collections import defaultdict
import copy import copy
from itertools import repeat, chain
import math import math
import numpy as np from collections import defaultdict
from itertools import chain, repeat
import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm
import torchvision import torchvision
from PIL import Image from PIL import Image
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm
def _repeat_to_at_least(iterable, n): def _repeat_to_at_least(iterable, n):
...@@ -34,12 +33,10 @@ class GroupedBatchSampler(BatchSampler): ...@@ -34,12 +33,10 @@ class GroupedBatchSampler(BatchSampler):
0, i.e. they must be in the range [0, num_groups). 0, i.e. they must be in the range [0, num_groups).
batch_size (int): Size of mini-batch. batch_size (int): Size of mini-batch.
""" """
def __init__(self, sampler, group_ids, batch_size): def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler): if not isinstance(sampler, Sampler):
raise ValueError( raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
"sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler self.sampler = sampler
self.group_ids = group_ids self.group_ids = group_ids
self.batch_size = batch_size self.batch_size = batch_size
...@@ -66,10 +63,9 @@ class GroupedBatchSampler(BatchSampler): ...@@ -66,10 +63,9 @@ class GroupedBatchSampler(BatchSampler):
expected_num_batches = len(self) expected_num_batches = len(self)
num_remaining = expected_num_batches - num_batches num_remaining = expected_num_batches - num_batches
if num_remaining > 0: if num_remaining > 0:
# for the remaining batches, take first the buffers with largest number # for the remaining batches, take first the buffers with the largest number
# of elements # of elements
for group_id, _ in sorted(buffer_per_group.items(), for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
key=lambda x: len(x[1]), reverse=True):
remaining = self.batch_size - len(buffer_per_group[group_id]) remaining = self.batch_size - len(buffer_per_group[group_id])
samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining) samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
buffer_per_group[group_id].extend(samples_from_group_id[:remaining]) buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
...@@ -85,10 +81,12 @@ class GroupedBatchSampler(BatchSampler): ...@@ -85,10 +81,12 @@ class GroupedBatchSampler(BatchSampler):
def _compute_aspect_ratios_slow(dataset, indices=None): def _compute_aspect_ratios_slow(dataset, indices=None):
print("Your dataset doesn't support the fast path for " print(
"computing the aspect ratios, so will iterate over " "Your dataset doesn't support the fast path for "
"the full dataset and load every image instead. " "computing the aspect ratios, so will iterate over "
"This might take some time...") "the full dataset and load every image instead. "
"This might take some time..."
)
if indices is None: if indices is None:
indices = range(len(dataset)) indices = range(len(dataset))
...@@ -104,9 +102,12 @@ def _compute_aspect_ratios_slow(dataset, indices=None): ...@@ -104,9 +102,12 @@ def _compute_aspect_ratios_slow(dataset, indices=None):
sampler = SubsetSampler(indices) sampler = SubsetSampler(indices)
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, sampler=sampler, dataset,
batch_size=1,
sampler=sampler,
num_workers=14, # you might want to increase it for faster processing num_workers=14, # you might want to increase it for faster processing
collate_fn=lambda x: x[0]) collate_fn=lambda x: x[0],
)
aspect_ratios = [] aspect_ratios = []
with tqdm(total=len(dataset)) as pbar: with tqdm(total=len(dataset)) as pbar:
for _i, (img, _) in enumerate(data_loader): for _i, (img, _) in enumerate(data_loader):
...@@ -190,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0): ...@@ -190,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0):
# count number of elements per group # count number of elements per group
counts = np.unique(groups, return_counts=True)[1] counts = np.unique(groups, return_counts=True)[1]
fbins = [0] + bins + [np.inf] fbins = [0] + bins + [np.inf]
print("Using {} as bins for aspect ratio quantization".format(fbins)) print(f"Using {fbins} as bins for aspect ratio quantization")
print("Count of instances per bin: {}".format(counts)) print(f"Count of instances per bin: {counts}")
return groups return groups
import transforms as T from collections import defaultdict
import torch
import transforms as reference_transforms
def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.transforms.v2
import torchvision.tv_tensors
return torchvision.transforms.v2, torchvision.tv_tensors
else:
return reference_transforms, None
class DetectionPresetTrain: class DetectionPresetTrain:
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)): # Note: this transform assumes that the input to forward() are always PIL
if data_augmentation == 'hflip': # images, regardless of the backend parameter.
self.transforms = T.Compose([ def __init__(
self,
*,
data_augmentation,
hflip_prob=0.5,
mean=(123.0, 117.0, 104.0),
backend="pil",
use_v2=False,
):
T, tv_tensors = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "tv_tensor":
transforms.append(T.ToImage())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
if data_augmentation == "hflip":
transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
elif data_augmentation == "lsj":
transforms += [
T.ScaleJitter(target_size=(1024, 1024), antialias=True),
# TODO: FixedSizeCrop below doesn't work on tensors!
reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "multiscale":
transforms += [
T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
T.RandomHorizontalFlip(p=hflip_prob), T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(), ]
]) elif data_augmentation == "ssd":
elif data_augmentation == 'ssd': fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
self.transforms = T.Compose([ transforms += [
T.RandomPhotometricDistort(), T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)), T.RandomZoomOut(fill=fill),
T.RandomIoUCrop(), T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob), T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(), ]
]) elif data_augmentation == "ssdlite":
elif data_augmentation == 'ssdlite': transforms += [
self.transforms = T.Compose([
T.RandomIoUCrop(), T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob), T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(), ]
])
else: else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
if backend == "pil":
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2:
transforms += [
T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBoxes(),
T.ToPureTensor(),
]
self.transforms = T.Compose(transforms)
def __call__(self, img, target): def __call__(self, img, target):
return self.transforms(img, target) return self.transforms(img, target)
class DetectionPresetEval: class DetectionPresetEval:
def __init__(self): def __init__(self, backend="pil", use_v2=False):
self.transforms = T.ToTensor() T, _ = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "tv_tensor":
transforms += [T.ToImage()]
else:
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2:
transforms += [T.ToPureTensor()]
self.transforms = T.Compose(transforms)
def __call__(self, img, target): def __call__(self, img, target):
return self.transforms(img, target) return self.transforms(img, target)
...@@ -21,74 +21,125 @@ import datetime ...@@ -21,74 +21,125 @@ import datetime
import os import os
import time import time
import presets
import torch import torch
import torch.utils.data import torch.utils.data
import torchvision import torchvision
import torchvision.models.detection import torchvision.models.detection
import torchvision.models.detection.mask_rcnn import torchvision.models.detection.mask_rcnn
from coco_utils import get_coco, get_coco_kp
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from engine import train_one_epoch, evaluate
import presets
import utils import utils
from coco_utils import get_coco
from engine import evaluate, train_one_epoch
def get_dataset(name, image_set, transform, data_path): from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
paths = { from torchvision.transforms import InterpolationMode
"coco": (data_path, get_coco, 91), from transforms import SimpleCopyPaste
"coco_kp": (data_path, get_coco_kp, 2)
}
p, ds_fn, num_classes = paths[name] def copypaste_collate_fn(batch):
copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
ds = ds_fn(p, image_set=image_set, transforms=transform) return copypaste(*utils.collate_fn(batch))
def get_dataset(is_train, args):
image_set = "train" if is_train else "val"
num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
with_masks = "mask" in args.model
ds = get_coco(
root=args.data_path,
image_set=image_set,
transforms=get_transform(is_train, args),
mode=mode,
use_v2=args.use_v2,
with_masks=with_masks,
)
return ds, num_classes return ds, num_classes
def get_transform(train, data_augmentation): def get_transform(is_train, args):
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval() if is_train:
return presets.DetectionPresetTrain(
data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
return lambda img, target: (trans(img), target)
else:
return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)
def get_args_parser(add_help=True): def get_args_parser(add_help=True):
import argparse import argparse
parser = argparse.ArgumentParser(description='PyTorch Detection Training', add_help=add_help)
parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
parser.add_argument('--dataset', default='coco', help='dataset') parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model') parser.add_argument(
parser.add_argument('--device', default='cuda', help='device') "--dataset",
parser.add_argument('-b', '--batch-size', default=2, type=int, default="coco",
help='images per gpu, the total batch size is $NGPU x batch_size') type=str,
parser.add_argument('--epochs', default=26, type=int, metavar='N', help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
help='number of total epochs to run') )
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
help='number of data loading workers (default: 4)') parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument('--lr', default=0.02, type=float, parser.add_argument(
help='initial learning rate, 0.02 is the default value for training ' "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
'on 8 gpus and 2 images_per_gpu') )
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
help='momentum') parser.add_argument(
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
metavar='W', help='weight decay (default: 1e-4)', )
dest='weight_decay') parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)') parser.add_argument(
parser.add_argument('--lr-step-size', default=8, type=int, "--lr",
help='decrease lr every step-size epochs (multisteplr scheduler only)') default=0.02,
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, type=float,
help='decrease lr every step-size epochs (multisteplr scheduler only)') help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
parser.add_argument('--lr-gamma', default=0.1, type=float, )
help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)') parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument('--print-freq', default=20, type=int, help='print frequency') parser.add_argument(
parser.add_argument('--output-dir', default='.', help='path where to save') "--wd",
parser.add_argument('--resume', default='', help='resume from checkpoint') "--weight-decay",
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') default=1e-4,
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int) type=float,
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn') metavar="W",
parser.add_argument('--trainable-backbone-layers', default=None, type=int, help="weight decay (default: 1e-4)",
help='number of trainable layers of backbone') dest="weight_decay",
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)') )
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
)
parser.add_argument(
"--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
)
parser.add_argument(
"--lr-steps",
default=[16, 22],
nargs="+",
type=int,
help="decrease lr every step-size epochs (multisteplr scheduler only)",
)
parser.add_argument(
"--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
)
parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
parser.add_argument("--aspect-ratio-group-factor", default=3, type=int)
parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
parser.add_argument(
"--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
)
parser.add_argument(
"--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
)
parser.add_argument( parser.add_argument(
"--sync-bn", "--sync-bn",
dest="sync_bn", dest="sync_bn",
...@@ -101,22 +152,43 @@ def get_args_parser(add_help=True): ...@@ -101,22 +152,43 @@ def get_args_parser(add_help=True):
help="Only test the model", help="Only test the model",
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--pretrained", "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
) )
# distributed training parameters # distributed training parameters
parser.add_argument('--world-size', default=1, type=int, parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
help='number of distributed processes') parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
# Use CopyPaste augmentation training parameter
parser.add_argument(
"--use-copypaste",
action="store_true",
help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
)
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser return parser
def main(args): def main(args):
if args.backend.lower() == "tv_tensor" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the tv_tensor backend.")
if args.dataset not in ("coco", "coco_kp"):
raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
if "keypoint" in args.model and args.dataset != "coco_kp":
raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
if args.dataset == "coco_kp" and args.use_v2:
raise ValueError("KeyPoint detection doesn't support V2 transforms yet")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -125,17 +197,19 @@ def main(args): ...@@ -125,17 +197,19 @@ def main(args):
device = torch.device(args.device) device = torch.device(args.device)
if args.use_deterministic_algorithms:
torch.use_deterministic_algorithms(True)
# Data loading code # Data loading code
print("Loading data") print("Loading data")
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args.data_augmentation), dataset, num_classes = get_dataset(is_train=True, args=args)
args.data_path) dataset_test, _ = get_dataset(is_train=False, args=args)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)
print("Creating data loaders") print("Creating data loaders")
if args.distributed: if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
else: else:
train_sampler = torch.utils.data.RandomSampler(dataset) train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test) test_sampler = torch.utils.data.SequentialSampler(dataset_test)
...@@ -144,27 +218,33 @@ def main(args): ...@@ -144,27 +218,33 @@ def main(args):
group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
else: else:
train_batch_sampler = torch.utils.data.BatchSampler( train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
train_sampler, args.batch_size, drop_last=True)
train_collate_fn = utils.collate_fn
if args.use_copypaste:
if args.data_augmentation != "lsj":
raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")
train_collate_fn = copypaste_collate_fn
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
collate_fn=utils.collate_fn) )
data_loader_test = torch.utils.data.DataLoader( data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1, dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
sampler=test_sampler, num_workers=args.workers, )
collate_fn=utils.collate_fn)
print("Creating model") print("Creating model")
kwargs = { kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
"trainable_backbone_layers": args.trainable_backbone_layers if args.data_augmentation in ["multiscale", "lsj"]:
} kwargs["_skip_resize"] = True
if "rcnn" in args.model: if "rcnn" in args.model:
if args.rpn_score_thresh is not None: if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained, model = torchvision.models.get_model(
**kwargs) args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
...@@ -174,27 +254,50 @@ def main(args): ...@@ -174,27 +254,50 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module model_without_ddp = model.module
params = [p for p in model.parameters() if p.requires_grad] if args.norm_weight_decay is None:
optimizer = torch.optim.SGD( parameters = [p for p in model.parameters() if p.requires_grad]
params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
scaler = torch.cuda.amp.GradScaler() if args.amp else None
args.lr_scheduler = args.lr_scheduler.lower() args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == 'multisteplr': if args.lr_scheduler == "multisteplr":
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
elif args.lr_scheduler == 'cosineannealinglr': elif args.lr_scheduler == "cosineannealinglr":
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
else: else:
raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR " raise RuntimeError(
"are supported.".format(args.lr_scheduler)) f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
)
if args.resume: if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint['model']) model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint['epoch'] + 1 args.start_epoch = checkpoint["epoch"] + 1
if args.amp:
scaler.load_state_dict(checkpoint["scaler"])
if args.test_only: if args.test_only:
torch.backends.cudnn.deterministic = True
evaluate(model, data_loader_test, device=device) evaluate(model, data_loader_test, device=device)
return return
...@@ -203,29 +306,27 @@ def main(args): ...@@ -203,29 +306,27 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
lr_scheduler.step() lr_scheduler.step()
if args.output_dir: if args.output_dir:
checkpoint = { checkpoint = {
'model': model_without_ddp.state_dict(), "model": model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(), "optimizer": optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(), "lr_scheduler": lr_scheduler.state_dict(),
'args': args, "args": args,
'epoch': epoch "epoch": epoch,
} }
utils.save_on_master( if args.amp:
checkpoint, checkpoint["scaler"] = scaler.state_dict()
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master( utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
# evaluate after every epoch # evaluate after every epoch
evaluate(model, data_loader_test, device=device) evaluate(model, data_loader_test, device=device)
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str)) print(f"Training time {total_time_str}")
if __name__ == "__main__": if __name__ == "__main__":
......
from typing import Dict, List, Optional, Tuple, Union
import torch import torch
import torchvision import torchvision
from torch import nn, Tensor from torch import nn, Tensor
from torchvision.transforms import functional as F from torchvision import ops
from torchvision.transforms import transforms as T from torchvision.transforms import functional as F, InterpolationMode, transforms as T
from typing import List, Tuple, Dict, Optional
def _flip_coco_person_keypoints(kps, width): def _flip_coco_person_keypoints(kps, width):
...@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width): ...@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width):
return flipped_data return flipped_data
class Compose(object): class Compose:
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
...@@ -28,12 +28,13 @@ class Compose(object): ...@@ -28,12 +28,13 @@ class Compose(object):
class RandomHorizontalFlip(T.RandomHorizontalFlip): class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torch.rand(1) < self.p: if torch.rand(1) < self.p:
image = F.hflip(image) image = F.hflip(image)
if target is not None: if target is not None:
width, _ = F._get_image_size(image) _, _, width = F.get_dimensions(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]] target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target: if "masks" in target:
target["masks"] = target["masks"].flip(-1) target["masks"] = target["masks"].flip(-1)
...@@ -44,16 +45,39 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip): ...@@ -44,16 +45,39 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
return image, target return image, target
class ToTensor(nn.Module): class PILToTensor(nn.Module):
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
image = F.to_tensor(image) ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.pil_to_tensor(image)
return image, target
class ToDtype(nn.Module):
def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
super().__init__()
self.dtype = dtype
self.scale = scale
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target return image, target
class RandomIoUCrop(nn.Module): class RandomIoUCrop(nn.Module):
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5, def __init__(
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40): self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__() super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale self.min_scale = min_scale
...@@ -65,18 +89,19 @@ class RandomIoUCrop(nn.Module): ...@@ -65,18 +89,19 @@ class RandomIoUCrop(nn.Module):
self.options = sampler_options self.options = sampler_options
self.trials = trials self.trials = trials
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if target is None: if target is None:
raise ValueError("The targets can't be None for this transform.") raise ValueError("The targets can't be None for this transform.")
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
orig_w, orig_h = F._get_image_size(image) _, orig_h, orig_w = F.get_dimensions(image)
while True: while True:
# sample an option # sample an option
...@@ -112,8 +137,9 @@ class RandomIoUCrop(nn.Module): ...@@ -112,8 +137,9 @@ class RandomIoUCrop(nn.Module):
# check at least 1 box with jaccard limitations # check at least 1 box with jaccard limitations
boxes = target["boxes"][is_within_crop_area] boxes = target["boxes"][is_within_crop_area]
ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]], ious = torchvision.ops.boxes.box_iou(
dtype=boxes.dtype, device=boxes.device)) boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
)
if ious.max() < min_jaccard_overlap: if ious.max() < min_jaccard_overlap:
continue continue
...@@ -130,14 +156,16 @@ class RandomIoUCrop(nn.Module): ...@@ -130,14 +156,16 @@ class RandomIoUCrop(nn.Module):
class RandomZoomOut(nn.Module): class RandomZoomOut(nn.Module):
def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5): def __init__(
self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
):
super().__init__() super().__init__()
if fill is None: if fill is None:
fill = [0., 0., 0.] fill = [0.0, 0.0, 0.0]
self.fill = fill self.fill = fill
self.side_range = side_range self.side_range = side_range
if side_range[0] < 1. or side_range[0] > side_range[1]: if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError("Invalid canvas side range provided {}.".format(side_range)) raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p self.p = p
@torch.jit.unused @torch.jit.unused
...@@ -146,18 +174,19 @@ class RandomZoomOut(nn.Module): ...@@ -146,18 +174,19 @@ class RandomZoomOut(nn.Module):
# We fake the type to make it work on JIT # We fake the type to make it work on JIT
return tuple(int(x) for x in self.fill) if is_pil else 0 return tuple(int(x) for x in self.fill) if is_pil else 0
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
if torch.rand(1) < self.p: if torch.rand(1) >= self.p:
return image, target return image, target
orig_w, orig_h = F._get_image_size(image) _, orig_h, orig_w = F.get_dimensions(image)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r) canvas_width = int(orig_w * r)
...@@ -176,9 +205,11 @@ class RandomZoomOut(nn.Module): ...@@ -176,9 +205,11 @@ class RandomZoomOut(nn.Module):
image = F.pad(image, [left, top, right, bottom], fill=fill) image = F.pad(image, [left, top, right, bottom], fill=fill)
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
# PyTorch's pad supports only integers on fill. So we need to overwrite the colour
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1) v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \ image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
image[..., :, (left + orig_w):] = v ..., :, (left + orig_w) :
] = v
if target is not None: if target is not None:
target["boxes"][:, 0::2] += left target["boxes"][:, 0::2] += left
...@@ -188,8 +219,14 @@ class RandomZoomOut(nn.Module): ...@@ -188,8 +219,14 @@ class RandomZoomOut(nn.Module):
class RandomPhotometricDistort(nn.Module): class RandomPhotometricDistort(nn.Module):
def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5), def __init__(
hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5): self,
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
brightness: Tuple[float, float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__() super().__init__()
self._brightness = T.ColorJitter(brightness=brightness) self._brightness = T.ColorJitter(brightness=brightness)
self._contrast = T.ColorJitter(contrast=contrast) self._contrast = T.ColorJitter(contrast=contrast)
...@@ -197,11 +234,12 @@ class RandomPhotometricDistort(nn.Module): ...@@ -197,11 +234,12 @@ class RandomPhotometricDistort(nn.Module):
self._saturation = T.ColorJitter(saturation=saturation) self._saturation = T.ColorJitter(saturation=saturation)
self.p = p self.p = p
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
...@@ -226,14 +264,338 @@ class RandomPhotometricDistort(nn.Module): ...@@ -226,14 +264,338 @@ class RandomPhotometricDistort(nn.Module):
image = self._contrast(image) image = self._contrast(image)
if r[6] < self.p: if r[6] < self.p:
channels = F._get_image_num_channels(image) channels, _, _ = F.get_dimensions(image)
permutation = torch.randperm(channels) permutation = torch.randperm(channels)
is_pil = F._is_pil_image(image) is_pil = F._is_pil_image(image)
if is_pil: if is_pil:
image = F.to_tensor(image) image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
image = image[..., permutation, :, :] image = image[..., permutation, :, :]
if is_pil: if is_pil:
image = F.to_pil_image(image) image = F.to_pil_image(image)
return image, target return image, target
class ScaleJitter(nn.Module):
"""Randomly resizes the image and its bounding boxes within the specified scale range.
The class implements the Scale Jitter augmentation as described in the paper
`"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
Args:
target_size (tuple of ints): The target size for the transform provided in (height, weight) format.
scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
range a <= scale <= b.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
"""
def __init__(
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias=True,
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = interpolation
self.antialias = antialias
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
_, orig_height, orig_width = F.get_dimensions(image)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)
if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"],
[new_height, new_width],
interpolation=InterpolationMode.NEAREST,
antialias=self.antialias,
)
return image, target
class FixedSizeCrop(nn.Module):
def __init__(self, size, fill=0, padding_mode="constant"):
super().__init__()
size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.padding_mode = padding_mode
def _pad(self, img, target, padding):
# Taken from the functional_tensor.py pad
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
padding = [pad_left, pad_top, pad_right, pad_bottom]
img = F.pad(img, padding, self.fill, self.padding_mode)
if target is not None:
target["boxes"][:, 0::2] += pad_left
target["boxes"][:, 1::2] += pad_top
if "masks" in target:
target["masks"] = F.pad(target["masks"], padding, 0, "constant")
return img, target
def _crop(self, img, target, top, left, height, width):
img = F.crop(img, top, left, height, width)
if target is not None:
boxes = target["boxes"]
boxes[:, 0::2] -= left
boxes[:, 1::2] -= top
boxes[:, 0::2].clamp_(min=0, max=width)
boxes[:, 1::2].clamp_(min=0, max=height)
is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
target["boxes"] = boxes[is_valid]
target["labels"] = target["labels"][is_valid]
if "masks" in target:
target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)
return img, target
def forward(self, img, target=None):
_, height, width = F.get_dimensions(img)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
if new_height != height or new_width != width:
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)
r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)
img, target = self._crop(img, target, top, left, new_height, new_width)
pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
if pad_bottom != 0 or pad_right != 0:
img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
return img, target
class RandomShortestSize(nn.Module):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
_, orig_height, orig_width = F.get_dimensions(image)
min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
)
return image, target
def _copy_paste(
image: torch.Tensor,
target: Dict[str, Tensor],
paste_image: torch.Tensor,
paste_target: Dict[str, Tensor],
blending: bool = True,
resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
# Random paste targets selection:
num_masks = len(paste_target["masks"])
if num_masks < 1:
# Such degerante case with num_masks=0 can happen with LSJ
# Let's just return (image, target)
return image, target
# We have to please torch script by explicitly specifying dtype as torch.long
random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
random_selection = torch.unique(random_selection).to(torch.long)
paste_masks = paste_target["masks"][random_selection]
paste_boxes = paste_target["boxes"][random_selection]
paste_labels = paste_target["labels"][random_selection]
masks = target["masks"]
# We resize source and paste data if they have different sizes
# This is something we introduced here as originally the algorithm works
# on equal-sized data (for example, coming from LSJ data augmentations)
size1 = image.shape[-2:]
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
# resize bboxes:
ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
paste_alpha_mask = paste_masks.sum(dim=0) > 0
if blending:
paste_alpha_mask = F.gaussian_blur(
paste_alpha_mask.unsqueeze(0),
kernel_size=(5, 5),
sigma=[
2.0,
],
)
# Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
# Copy-paste masks:
masks = masks * (~paste_alpha_mask)
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]
# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}
out_target["masks"] = torch.cat([masks, paste_masks])
# Copy-paste boxes and labels
boxes = ops.masks_to_boxes(masks)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])
# Update additional optional keys: area and iscrowd if exist
if "area" in target:
out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
if "iscrowd" in target and "iscrowd" in paste_target:
# target['iscrowd'] size can be differ from mask size (non_all_zero_masks)
# For example, if previous transforms geometrically modifies masks/boxes/labels but
# does not update "iscrowd"
if len(target["iscrowd"]) == len(non_all_zero_masks):
iscrowd = target["iscrowd"][non_all_zero_masks]
paste_iscrowd = paste_target["iscrowd"][random_selection]
out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
# Check for degenerated boxes and remove them
boxes = out_target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)
out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]
if "area" in out_target:
out_target["area"] = out_target["area"][valid_targets]
if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
return image, out_target
class SimpleCopyPaste(torch.nn.Module):
def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
super().__init__()
self.resize_interpolation = resize_interpolation
self.blending = blending
def forward(
self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
torch._assert(
isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
"images should be a list of tensors",
)
torch._assert(
isinstance(targets, (list, tuple)) and len(images) == len(targets),
"targets should be a list of the same size as images",
)
for target in targets:
# Can not check for instance type dict with inside torch.jit.script
# torch._assert(isinstance(target, dict), "targets item should be a dict")
for k in ["masks", "boxes", "labels"]:
torch._assert(k in target, f"Key {k} should be present in targets")
torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")
# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
# paste_images = [t2, t3, ..., tN, t1]
# FYI: in TF they mix data on the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]
output_images: List[torch.Tensor] = []
output_targets: List[Dict[str, Tensor]] = []
for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
output_image, output_data = _copy_paste(
image,
target,
paste_image,
paste_target,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
)
output_images.append(output_image)
output_targets.append(output_data)
return output_images, output_targets
def __repr__(self) -> str:
s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
return s
from collections import defaultdict, deque
import datetime import datetime
import errno import errno
import os import os
import time import time
from collections import defaultdict, deque
import torch import torch
import torch.distributed as dist import torch.distributed as dist
class SmoothedValue(object): class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
""" """
...@@ -32,7 +32,7 @@ class SmoothedValue(object): ...@@ -32,7 +32,7 @@ class SmoothedValue(object):
""" """
if not is_dist_avail_and_initialized(): if not is_dist_avail_and_initialized():
return return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier() dist.barrier()
dist.all_reduce(t) dist.all_reduce(t)
t = t.tolist() t = t.tolist()
...@@ -63,11 +63,8 @@ class SmoothedValue(object): ...@@ -63,11 +63,8 @@ class SmoothedValue(object):
def __str__(self): def __str__(self):
return self.fmt.format( return self.fmt.format(
median=self.median, median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
avg=self.avg, )
global_avg=self.global_avg,
max=self.max,
value=self.value)
def all_gather(data): def all_gather(data):
...@@ -98,7 +95,7 @@ def reduce_dict(input_dict, average=True): ...@@ -98,7 +95,7 @@ def reduce_dict(input_dict, average=True):
world_size = get_world_size() world_size = get_world_size()
if world_size < 2: if world_size < 2:
return input_dict return input_dict
with torch.no_grad(): with torch.inference_mode():
names = [] names = []
values = [] values = []
# sort the keys so that they are consistent across processes # sort the keys so that they are consistent across processes
...@@ -113,7 +110,7 @@ def reduce_dict(input_dict, average=True): ...@@ -113,7 +110,7 @@ def reduce_dict(input_dict, average=True):
return reduced_dict return reduced_dict
class MetricLogger(object): class MetricLogger:
def __init__(self, delimiter="\t"): def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue) self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter self.delimiter = delimiter
...@@ -130,15 +127,12 @@ class MetricLogger(object): ...@@ -130,15 +127,12 @@ class MetricLogger(object):
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format( raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
type(self).__name__, attr))
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
for name, meter in self.meters.items(): for name, meter in self.meters.items():
loss_str.append( loss_str.append(f"{name}: {str(meter)}")
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str) return self.delimiter.join(loss_str)
def synchronize_between_processes(self): def synchronize_between_processes(self):
...@@ -151,31 +145,28 @@ class MetricLogger(object): ...@@ -151,31 +145,28 @@ class MetricLogger(object):
def log_every(self, iterable, print_freq, header=None): def log_every(self, iterable, print_freq, header=None):
i = 0 i = 0
if not header: if not header:
header = '' header = ""
start_time = time.time() start_time = time.time()
end = time.time() end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}') iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ':' + str(len(str(len(iterable)))) + 'd' space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available(): if torch.cuda.is_available():
log_msg = self.delimiter.join([ log_msg = self.delimiter.join(
header, [
'[{0' + space_fmt + '}/{1}]', header,
'eta: {eta}', "[{0" + space_fmt + "}/{1}]",
'{meters}', "eta: {eta}",
'time: {time}', "{meters}",
'data: {data}', "time: {time}",
'max mem: {memory:.0f}' "data: {data}",
]) "max mem: {memory:.0f}",
]
)
else: else:
log_msg = self.delimiter.join([ log_msg = self.delimiter.join(
header, [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
'[{0' + space_fmt + '}/{1}]', )
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0 MB = 1024.0 * 1024.0
for obj in iterable: for obj in iterable:
data_time.update(time.time() - end) data_time.update(time.time() - end)
...@@ -185,39 +176,34 @@ class MetricLogger(object): ...@@ -185,39 +176,34 @@ class MetricLogger(object):
eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available(): if torch.cuda.is_available():
print(log_msg.format( print(
i, len(iterable), eta=eta_string, log_msg.format(
meters=str(self), i,
time=str(iter_time), data=str(data_time), len(iterable),
memory=torch.cuda.max_memory_allocated() / MB)) eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else: else:
print(log_msg.format( print(
i, len(iterable), eta=eta_string, log_msg.format(
meters=str(self), i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
time=str(iter_time), data=str(data_time))) )
)
i += 1 i += 1
end = time.time() end = time.time()
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format( print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
header, total_time_str, total_time / len(iterable)))
def collate_fn(batch): def collate_fn(batch):
return tuple(zip(*batch)) return tuple(zip(*batch))
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
def f(x):
if x >= warmup_iters:
return 1
alpha = float(x) / warmup_iters
return warmup_factor * (1 - alpha) + alpha
return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
def mkdir(path): def mkdir(path):
try: try:
os.makedirs(path) os.makedirs(path)
...@@ -231,10 +217,11 @@ def setup_for_distributed(is_master): ...@@ -231,10 +217,11 @@ def setup_for_distributed(is_master):
This function disables printing when not in master process This function disables printing when not in master process
""" """
import builtins as __builtin__ import builtins as __builtin__
builtin_print = __builtin__.print builtin_print = __builtin__.print
def print(*args, **kwargs): def print(*args, **kwargs):
force = kwargs.pop('force', False) force = kwargs.pop("force", False)
if is_master or force: if is_master or force:
builtin_print(*args, **kwargs) builtin_print(*args, **kwargs)
...@@ -271,25 +258,25 @@ def save_on_master(*args, **kwargs): ...@@ -271,25 +258,25 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args): def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"]) args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE']) args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ['LOCAL_RANK']) args.gpu = int(os.environ["LOCAL_RANK"])
elif 'SLURM_PROCID' in os.environ: elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ['SLURM_PROCID']) args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count() args.gpu = args.rank % torch.cuda.device_count()
else: else:
print('Not using distributed mode') print("Not using distributed mode")
args.distributed = False args.distributed = False
return return
args.distributed = True args.distributed = True
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl' args.dist_backend = "nccl"
print('| distributed init (rank {}): {}'.format( print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
args.rank, args.dist_url), flush=True) torch.distributed.init_process_group(
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
world_size=args.world_size, rank=args.rank) )
torch.distributed.barrier() torch.distributed.barrier()
setup_for_distributed(args.rank == 0) setup_for_distributed(args.rank == 0)
# Optical flow reference training scripts
This folder contains reference training scripts for optical flow.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.
### RAFT Large
The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
rest of the hyper-parameters are exactly the same as the original RAFT training
recipe from https://github.com/princeton-vl/RAFT. The original recipe trains for
100000 updates (or steps) on each dataset - this corresponds to about 72 and 20
epochs on Chairs and Things respectively:
```
num_epochs = ceil(num_steps / number_of_steps_per_epoch)
= ceil(num_steps / (num_samples / effective_batch_size))
```
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_chairs \
--model raft_large \
--train-dataset chairs \
--batch-size 2 \
--lr 0.0004 \
--weight-decay 0.0001 \
--epochs 72 \
--output-dir $chairs_dir
```
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model raft_large \
--train-dataset things \
--batch-size 2 \
--lr 0.000125 \
--weight-decay 0.0001 \
--epochs 20 \
--freeze-batch-norm \
--output-dir $things_dir\
--resume $chairs_dir/$name_chairs.pth
```
### Evaluation
```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2
```
This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
final pass of Sintel-train. Results may vary slightly depending on the batch
size and the number of GPUs. For the most accurate results use 1 GPU and
`--batch-size 1`:
```
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
```
You can also evaluate on Kitti train:
```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2
Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679
```
import torch
import transforms as T
class OpticalFlowPresetEval(torch.nn.Module):
def __init__(self):
super().__init__()
self.transforms = T.Compose(
[
T.PILToTensor(),
T.ConvertImageDtype(torch.float32),
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
T.ValidateModelInput(),
]
)
def forward(self, img1, img2, flow, valid):
return self.transforms(img1, img2, flow, valid)
class OpticalFlowPresetTrain(torch.nn.Module):
def __init__(
self,
*,
# RandomResizeAndCrop params
crop_size,
min_scale=-0.2,
max_scale=0.5,
stretch_prob=0.8,
# AsymmetricColorJitter params
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.5 / 3.14,
# Random[H,V]Flip params
asymmetric_jitter_prob=0.2,
do_flip=True,
):
super().__init__()
transforms = [
T.PILToTensor(),
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.RandomResizeAndCrop(
crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
),
]
if do_flip:
transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]
transforms += [
T.ConvertImageDtype(torch.float32),
T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
T.RandomErasing(max_erase=2),
T.MakeValidFlowMask(),
T.ValidateModelInput(),
]
self.transforms = T.Compose(transforms)
def forward(self, img1, img2, flow, valid):
return self.transforms(img1, img2, flow, valid)
import argparse
import warnings
from math import ceil
from pathlib import Path
import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
def get_train_dataset(stage, dataset_root):
if stage == "chairs":
transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
elif stage == "things":
transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
elif stage == "sintel_SKH": # S + K + H as from paper
crop_size = (368, 768)
transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)
things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)
kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)
hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)
# As future improvement, we could probably be using a distributed sampler here
# The distribution is S(.71), T(.135), K(.135), H(.02)
return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
elif stage == "kitti":
transforms = OpticalFlowPresetTrain(
# resize and crop params
crop_size=(288, 960),
min_scale=-0.2,
max_scale=0.4,
stretch_prob=0,
# flip params
do_flip=False,
# jitter params
brightness=0.3,
contrast=0.3,
saturation=0.3,
hue=0.3 / 3.14,
asymmetric_jitter_prob=0,
)
return KittiFlow(root=dataset_root, split="train", transforms=transforms)
else:
raise ValueError(f"Unknown stage {stage}")
@torch.no_grad()
def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp, and process the rest on a single worker.
"""
batch_size = batch_size or args.batch_size
device = torch.device(args.device)
model.eval()
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
else:
sampler = torch.utils.data.SequentialSampler(val_dataset)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
batch_size=batch_size,
pin_memory=True,
num_workers=args.workers,
)
num_flow_updates = num_flow_updates or args.num_flow_updates
def inner_loop(blob):
if blob[0].dim() == 3:
# input is not batched, so we add an extra dim for consistency
blob = [x[None, :, :, :] if x is not None else None for x in blob]
image1, image2, flow_gt = blob[:3]
valid_flow_mask = None if len(blob) == 3 else blob[-1]
image1, image2 = image1.to(device), image2.to(device)
padder = utils.InputPadder(image1.shape, mode=padder_mode)
image1, image2 = padder.pad(image1, image2)
flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
flow_pred = flow_predictions[-1]
flow_pred = padder.unpad(flow_pred).cpu()
metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)
# We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
# per-pixel epe: average epe of all pixels of all images
# per-image epe: average epe on each image independently, then average over images
for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper
logger.meters[name].update(metrics[name], n=num_pixels_tot)
logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)
logger = utils.MetricLogger()
for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
num_processed_samples = 0
for blob in logger.log_every(val_loader, header=header, print_freq=None):
inner_loop(blob)
num_processed_samples += blob[0].shape[0] # batch size
if args.distributed:
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print(
f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
"Going to process the remaining samples individually, if any."
)
if args.rank == 0: # we only need to process the rest on a single worker
for i in range(num_processed_samples, len(val_dataset)):
inner_loop(val_dataset[i])
logger.synchronize_between_processes()
print(header, logger)
def evaluate(model, args):
val_datasets = args.val_dataset or []
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
def preprocessing(img1, img2, flow, valid_flow_mask):
img1, img2 = trans(img1, img2)
if flow is not None and not isinstance(flow, torch.Tensor):
flow = torch.from_numpy(flow)
if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor):
valid_flow_mask = torch.from_numpy(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
else:
preprocessing = OpticalFlowPresetEval()
for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes, so we need to individually pad them, we can't batch.
# see comment in InputPadder
if args.batch_size != 1 and (not args.distributed or args.rank == 0):
warnings.warn(
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
_evaluate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
_evaluate(
model,
args,
val_dataset,
num_flow_updates=32,
padder_mode="sintel",
header=f"Sintel val {pass_name}",
)
else:
warnings.warn(f"Can't validate on {val_dataset}, skipping.")
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
device = torch.device(args.device)
for data_blob in logger.log_every(train_loader):
optimizer.zero_grad()
image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)
metrics.pop("f1")
logger.update(loss=loss, **metrics)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
optimizer.step()
scheduler.step()
def main(args):
utils.setup_ddp(args)
args.test_only = args.train_dataset is None
if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)
if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True
model = torchvision.models.get_model(args.model, weights=args.weights)
if args.distributed:
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
model_without_ddp = model.module
else:
model.to(device)
model_without_ddp = model
if args.resume is not None:
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
if args.test_only:
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, args)
return
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
epochs=args.epochs,
steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)
if args.resume is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
else:
args.start_epoch = 0
torch.backends.cudnn.benchmark = True
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
else:
sampler = torch.utils.data.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.workers,
)
logger = utils.MetricLogger()
done = False
for epoch in range(args.start_epoch, args.epochs):
print(f"EPOCH {epoch}")
if args.distributed:
# needed on distributed mode, otherwise the data loading order would be the same for all epochs
sampler.set_epoch(epoch)
train_one_epoch(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
logger=logger,
args=args,
)
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {epoch} done. ", logger)
if not args.distributed or args.rank == 0:
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"epoch": epoch,
"args": args,
}
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
if epoch % args.val_freq == 0 or done:
evaluate(model, args)
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
parser.add_argument(
"--name",
default="raft",
type=str,
help="The name of the experiment - determines the name of the files where weights are saved.",
)
parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
parser.add_argument(
"--resume",
type=str,
help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
)
parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
parser.add_argument(
"--train-dataset",
type=str,
help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
)
parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")
parser.add_argument(
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)
parser.add_argument(
"--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
)
# TODO: resume and weights should be in an exclusive arg group
parser.add_argument(
"--num_flow_updates",
type=int,
default=12,
help="number of updates (or 'iters') in the update operator of the model.",
)
parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")
parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")
parser.add_argument(
"--dataset-root",
help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
required=True,
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
main(args)
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, img1, img2, flow, valid_flow_mask):
if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
raise TypeError("This method expects the tensors img1, img2 and flow of be of dtype torch.float32.")
if img1.shape != img2.shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = img1.shape[-2:]
if flow is not None and flow.shape != (2, h, w):
raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
if valid_flow_mask is not None:
if valid_flow_mask.shape != (h, w):
raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
if valid_flow_mask.dtype != torch.bool:
raise TypeError("valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")
return img1, img2, flow, valid_flow_mask
class MakeValidFlowMask(torch.nn.Module):
# This transform generates a valid_flow_mask if it doesn't exist.
# The flow is considered valid if ||flow||_inf < threshold
# This is a noop for Kitti and HD1K which already come with a built-in flow mask.
def __init__(self, threshold=1000):
super().__init__()
self.threshold = threshold
def forward(self, img1, img2, flow, valid_flow_mask):
if flow is not None and valid_flow_mask is None:
valid_flow_mask = (flow.abs() < self.threshold).all(axis=0)
return img1, img2, flow, valid_flow_mask
class ConvertImageDtype(torch.nn.Module):
def __init__(self, dtype):
super().__init__()
self.dtype = dtype
def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.convert_image_dtype(img1, dtype=self.dtype)
img2 = F.convert_image_dtype(img2, dtype=self.dtype)
img1 = img1.contiguous()
img2 = img2.contiguous()
return img1, img2, flow, valid_flow_mask
class Normalize(torch.nn.Module):
def __init__(self, mean, std):
super().__init__()
self.mean = mean
self.std = std
def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.normalize(img1, mean=self.mean, std=self.std)
img2 = F.normalize(img2, mean=self.mean, std=self.std)
return img1, img2, flow, valid_flow_mask
class PILToTensor(torch.nn.Module):
# Converts all inputs to tensors
# Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
# for consistency with the rest, e.g. the segmentation reference.
def forward(self, img1, img2, flow, valid_flow_mask):
img1 = F.pil_to_tensor(img1)
img2 = F.pil_to_tensor(img2)
if flow is not None:
flow = torch.from_numpy(flow)
if valid_flow_mask is not None:
valid_flow_mask = torch.from_numpy(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
class AsymmetricColorJitter(T.ColorJitter):
# p determines the proba of doing asymmertric vs symmetric color jittering
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
self.p = p
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) < self.p:
# asymmetric: different transform for img1 and img2
img1 = super().forward(img1)
img2 = super().forward(img2)
else:
# symmetric: same transform for img1 and img2
batch = torch.stack([img1, img2])
batch = super().forward(batch)
img1, img2 = batch[0], batch[1]
return img1, img2, flow, valid_flow_mask
class RandomErasing(T.RandomErasing):
# This only erases img2, and with an extra max_erase param
# This max_erase is needed because in the RAFT training ref does:
# 0 erasing with .5 proba
# 1 erase with .25 proba
# 2 erase with .25 proba
# and there's no accurate way to achieve this otherwise.
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
self.max_erase = max_erase
if self.max_erase <= 0:
raise ValueError("max_raise should be greater than 0")
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask
for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value])
img2 = F.erase(img2, x, y, h, w, v, self.inplace)
return img1, img2, flow, valid_flow_mask
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask
img1 = F.hflip(img1)
img2 = F.hflip(img2)
flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
if valid_flow_mask is not None:
valid_flow_mask = F.hflip(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
class RandomVerticalFlip(T.RandomVerticalFlip):
def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
return img1, img2, flow, valid_flow_mask
img1 = F.vflip(img1)
img2 = F.vflip(img2)
flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
if valid_flow_mask is not None:
valid_flow_mask = F.vflip(valid_flow_mask)
return img1, img2, flow, valid_flow_mask
class RandomResizeAndCrop(torch.nn.Module):
# This transform will resize the input with a given proba, and then crop it.
# These are the reversed operations of the built-in RandomResizedCrop,
# although the order of the operations doesn't matter too much: resizing a
# crop would give the same result as cropping a resized image, up to
# interpolation artifact at the borders of the output.
#
# The reason we don't rely on RandomResizedCrop is because of a significant
# difference in the parametrization of both transforms, in particular,
# because of the way the random parameters are sampled in both transforms,
# which leads to fairly different results (and different epe). For more details see
# https://github.com/pytorch/vision/pull/5026/files#r762932579
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
super().__init__()
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.stretch_prob = stretch_prob
self.resize_prob = 0.8
self.max_stretch = 0.2
def forward(self, img1, img2, flow, valid_flow_mask):
# randomly sample scale
h, w = img1.shape[-2:]
# Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
# It shouldn't matter much
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
scale_x = scale
scale_y = scale
if torch.rand(1) < self.stretch_prob:
scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
scale_x = max(scale_x, min_scale)
scale_y = max(scale_y, min_scale)
new_h, new_w = round(h * scale_y), round(w * scale_x)
if torch.rand(1).item() < self.resize_prob:
# rescale the images
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the OF models with antialias=True?
img1 = F.resize(img1, size=(new_h, new_w), antialias=False)
img2 = F.resize(img2, size=(new_h, new_w), antialias=False)
if valid_flow_mask is None:
flow = F.resize(flow, size=(new_h, new_w))
flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
else:
flow, valid_flow_mask = self._resize_sparse_flow(
flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
)
# Note: For sparse datasets (Kitti), the original code uses a "margin"
# See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
# We don't, not sure if it matters much
y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()
img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1])
img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1])
flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1])
if valid_flow_mask is not None:
valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1])
return img1, img2, flow, valid_flow_mask
def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
# This resizes both the flow and the valid_flow_mask mask (which is assumed to be reasonably sparse)
# There are as-many non-zero values in the original flow as in the resized flow (up to OOB)
# So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4
h, w = flow.shape[-2:]
h_new = int(round(h * scale_y))
w_new = int(round(w * scale_x))
flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)
jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")
ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]
ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)
within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)
ii_valid = ii_valid[within_bounds_mask]
jj_valid = jj_valid[within_bounds_mask]
ii_valid_new = ii_valid_new[within_bounds_mask]
jj_valid_new = jj_valid_new[within_bounds_mask]
valid_flow_new = flow[:, ii_valid, jj_valid]
valid_flow_new[0] *= scale_x
valid_flow_new[1] *= scale_y
flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
valid_new[ii_valid_new, jj_valid_new] = 1
return flow_new, valid_new
class Compose(torch.nn.Module):
def __init__(self, transforms):
super().__init__()
self.transforms = transforms
def forward(self, img1, img2, flow, valid_flow_mask):
for t in self.transforms:
img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
return img1, img2, flow, valid_flow_mask
import datetime
import os
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
import torch.nn.functional as F
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, **kwargs):
self.meters[name] = SmoothedValue(**kwargs)
def log_every(self, iterable, print_freq=5, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if print_freq is not None and i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
flow_norm = (flow_gt**2).sum(dim=1).sqrt()
if valid_flow_mask is not None:
epe = epe[valid_flow_mask]
flow_norm = flow_norm[valid_flow_mask]
relative_epe = epe / flow_norm
metrics = {
"epe": epe.mean().item(),
"1px": (epe < 1).float().mean().item(),
"3px": (epe < 3).float().mean().item(),
"5px": (epe < 5).float().mean().item(),
"f1": ((epe > 3) & (relative_epe > 0.05)).float().mean().item() * 100,
}
return metrics, epe.numel()
def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
"""Loss function defined over sequence of flow predictions"""
if gamma > 1:
raise ValueError(f"Gamma should be < 1, got {gamma}.")
# exclude invalid pixels and extremely large diplacements
flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
valid_flow_mask = valid_flow_mask[:, None, :, :]
flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
abs_diff = (flow_preds - flow_gt).abs()
abs_diff = (abs_diff * valid_flow_mask).mean(axis=(1, 2, 3, 4))
num_predictions = flow_preds.shape[0]
weights = gamma ** torch.arange(num_predictions - 1, -1, -1).to(flow_gt.device)
flow_loss = (abs_diff * weights).sum()
return flow_loss
class InputPadder:
"""Pads images such that dimensions are divisible by 8"""
# TODO: Ideally, this should be part of the eval transforms preset, instead
# of being part of the validation code. It's not obvious what a good
# solution would be, because we need to unpad the predicted flows according
# to the input images' size, and in some datasets (Kitti) images can have
# variable sizes.
def __init__(self, dims, mode="sintel"):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == "sintel":
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
def _redefine_print(is_main):
"""disables printing when not in main process"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_main or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def setup_ddp(args):
# Set the local_rank, rank, and world_size values as args fields
# This is done differently depending on how we're running the script. We
# currently support either torchrun or the custom run_with_submitit.py
# If you're confused (like I was), this might help a bit
# https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
if all(key in os.environ for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")):
# if we're here, the script was called with torchrun. Otherwise,
# these args will be set already by the run_with_submitit script
args.local_rank = int(os.environ["LOCAL_RANK"])
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
elif "gpu" in args:
# if we're here, the script was called by run_with_submitit.py
args.local_rank = args.gpu
else:
print("Not using distributed mode!")
args.distributed = False
args.world_size = 1
return
args.distributed = True
_redefine_print(is_main=(args.rank == 0))
torch.cuda.set_device(args.local_rank)
dist.init_process_group(
backend="nccl",
rank=args.rank,
world_size=args.world_size,
init_method=args.dist_url,
)
torch.distributed.barrier()
def reduce_across_processes(val):
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
def freeze_batch_norm(model):
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
# Semantic segmentation reference training scripts # Semantic segmentation reference training scripts
This folder contains reference training scripts for semantic segmentation. This folder contains reference training scripts for semantic segmentation.
They serve as a log of how to train specific models, as provide baseline They serve as a log of how to train specific models and provide baseline
training and evaluation scripts to quickly bootstrap research. training and evaluation scripts to quickly bootstrap research.
All models have been trained on 8x V100 GPUs. All models have been trained on 8x V100 GPUs.
...@@ -14,30 +14,30 @@ You must modify the following flags: ...@@ -14,30 +14,30 @@ You must modify the following flags:
## fcn_resnet50 ## fcn_resnet50
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
``` ```
## fcn_resnet101 ## fcn_resnet101
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
``` ```
## deeplabv3_resnet50 ## deeplabv3_resnet50
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
``` ```
## deeplabv3_resnet101 ## deeplabv3_resnet101
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
``` ```
## deeplabv3_mobilenet_v3_large ## deeplabv3_mobilenet_v3_large
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
``` ```
## lraspp_mobilenet_v3_large ## lraspp_mobilenet_v3_large
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
``` ```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment