Unverified Commit 01ffb3ae authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add RAFT model for optical flow (#5022)

parent 9b57de6c
......@@ -7,7 +7,7 @@ Models and pre-trained weights
The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection and video classification.
keypoint detection, video classification, and optical flow.
.. note ::
Backward compatibility is guaranteed for loading a serialized
......@@ -798,3 +798,16 @@ ResNet (2+1)D
:template: function.rst
torchvision.models.video.r2plus1d_18
Optical flow
============
Raft
----
.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.optical_flow.raft_large
torchvision.models.optical_flow.raft_small
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -93,7 +93,7 @@ def _get_expected_file(name=None):
return expected_file
def _assert_expected(output, name, prec):
def _assert_expected(output, name, prec=None, atol=None, rtol=None):
"""Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
pickable with `torch.save`. This file
......@@ -110,10 +110,11 @@ def _assert_expected(output, name, prec):
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
if binary_size > MAX_PICKLE_SIZE:
raise RuntimeError(f"The output for {filename}, is larger than 50kb")
raise RuntimeError(f"The output for {filename}, is larger than 50kb - got {binary_size}kb")
else:
expected = torch.load(expected_file)
rtol = atol = prec
rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally
atol = atol or prec
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
......@@ -818,5 +819,33 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
@needs_cuda
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small))
@pytest.mark.parametrize("scripted", (False, True))
def test_raft(model_builder, scripted):
torch.manual_seed(0)
# We need very small images, otherwise the pickle size would exceed the 50KB
# As a resut we need to override the correlation pyramid to not downsample
# too much, otherwise we would get nan values (effective H and W would be
# reduced to 1)
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)
model = model_builder(corr_block=corr_block).eval().to("cuda")
if scripted:
model = torch.jit.script(model)
bs = 1
img1 = torch.rand(bs, 3, 80, 72).cuda()
img2 = torch.rand(bs, 3, 80, 72).cuda()
preds = model(img1, img2)
flow_pred = preds[-1]
# Tolerance is fairly high, but there are 2 * H * W outputs to check
# The .pkl were generated on the AWS cluter, on the CI it looks like the resuts are slightly different
_assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1)
if __name__ == "__main__":
pytest.main([__file__])
......@@ -12,6 +12,7 @@ from .efficientnet import *
from .regnet import *
from . import detection
from . import feature_extraction
from . import optical_flow
from . import quantization
from . import segmentation
from . import video
from .raft import RAFT, raft_large, raft_small
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
"""Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
h, w = img.shape[-2:]
xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (w - 1) - 1
ygrid = 2 * ygrid / (h - 1) - 1
normalized_grid = torch.cat([xgrid, ygrid], dim=-1)
return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
def make_coords_grid(batch_size: int, h: int, w: int):
coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch_size, 1, 1, 1)
def upsample_flow(flow, up_mask: Optional[Tensor] = None):
"""Upsample flow by a factor of 8.
If up_mask is None we just interpolate.
If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
"""
batch_size, _, h, w = flow.shape
new_h, new_w = h * 8, w * 8
if up_mask is None:
return 8 * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)
up_mask = up_mask.view(batch_size, 1, 9, 8, 8, h, w)
up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1
upsampled_flow = F.unfold(8 * flow, kernel_size=3, padding=1).view(batch_size, 2, 9, 1, 1, h, w)
upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)
return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, 2, new_h, new_w)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment