Unverified commit 849d02bc, authored by Nicolas Hug and committed by GitHub

Add pretrained weights on Chairs and Things for raft_large (#5060)

parent d416b2c2
# Optical flow reference training scripts
This folder contains reference training scripts for optical flow.
They serve as a log of how to train specific models and provide baseline
training and evaluation scripts to quickly bootstrap research.
### RAFT Large
The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both stages used 8 A100 GPUs with a batch size of 2 per GPU (so the effective
batch size is 16). The rest of the hyper-parameters are exactly the same as in
the original RAFT training recipe from https://github.com/princeton-vl/RAFT.
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_chairs \
--model raft_large \
--train-dataset chairs \
--batch-size 2 \
--lr 0.0004 \
--weight-decay 0.0001 \
--num-steps 100000 \
--output-dir $chairs_dir
```
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model raft_large \
--train-dataset things \
--batch-size 2 \
--lr 0.000125 \
--weight-decay 0.0001 \
--num-steps 100000 \
--freeze-batch-norm \
--output-dir $things_dir \
--resume $chairs_dir/$name_chairs.pth
```
### Evaluation
```
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
```
This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
final pass of Sintel. Results may vary slightly depending on the batch size and
the number of GPUs. For the most accurate results, use 1 GPU and `--batch-size 1`:
```
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
```
@@ -3,10 +3,16 @@ import warnings
from pathlib import Path
import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small
+try:
+    from torchvision.prototype import models as PM
+    from torchvision.prototype.models import optical_flow as PMOF
+except ImportError:
+    PM = PMOF = None
def get_train_dataset(stage, dataset_root):
@@ -125,6 +131,13 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
def validate(model, args):
val_datasets = args.val_dataset or []
+    if args.weights:
+        weights = PM.get_weight(args.weights)
+        preprocessing = weights.transforms()
+    else:
+        preprocessing = OpticalFlowPresetEval()
for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes so we need to individually pad them, we can't batch.
@@ -134,14 +147,14 @@ def validate(model, args):
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)
-            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
+            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
_validate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
-                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
+                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
_validate(
model,
@@ -187,7 +200,11 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
def main(args):
utils.setup_ddp(args)
-    model = raft_small() if args.small else raft_large()
+    if args.weights:
+        model = PMOF.__dict__[args.model](weights=args.weights)
+    else:
+        model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
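The builder is now resolved by name from the module namespace instead of the old `--small` toggle. A standalone sketch of the same dispatch pattern:
```
import torchvision.models.optical_flow as OF

# The `__dict__` lookup is equivalent to getattr(OF, "raft_large").
builder = OF.__dict__["raft_large"]
model = builder(pretrained=False)
```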
@@ -306,7 +323,12 @@ def get_args_parser(add_help=True):
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)
parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
parser.add_argument(
"--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
)
# TODO: resume, pretrained, and weights should be in an exclusive arg group
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument(
"--num_flow_updates",
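The string passed to `--weights` is resolved by `PM.get_weight` in the hunk above. A hedged sketch of that lookup, assuming the prototype API of this period:
```
from torchvision.prototype import models as PM

# e.g. --weights Raft_Large_Weights.C_T_V2 on the command line
weights = PM.get_weight("Raft_Large_Weights.C_T_V2")
preprocessing = weights.transforms()  # the eval preset bundled with the weights
```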
@@ -91,7 +91,8 @@ def test_naming_conventions(model_fn):
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video),
+    + TM.get_models_from_module(models.video)
+    + TM.get_models_from_module(models.optical_flow),
)
def test_schema_meta_validation(model_fn):
classification_fields = ["size", "categories", "acc@1", "acc@5"]
@@ -102,6 +103,7 @@ def test_schema_meta_validation(model_fn):
"quantization": classification_fields + ["backend", "quantization", "unquantized"],
"segmentation": ["categories", "mIoU", "acc"],
"video": classification_fields,
"optical_flow": [],
}
module_name = model_fn.__module__.split(".")[-2]
fields = set(defaults["all"] + defaults[module_name])
@@ -201,13 +203,18 @@ def test_old_vs_new_factory(model_fn, dev):
if module_name == "detection":
x = [x]
if module_name == "optical_flow":
args = [x, x] # RAFT model requires img1, img2 as input
else:
args = [x]
# compare with new model builder parameterized in the old fashion way
try:
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
except ModuleNotFoundError:
pytest.skip(f"Model '{model_name}' not available in both modules.")
-    torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
+    torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)
def test_smoke():
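As the comment in the test notes, optical-flow models consume an image pair rather than a single batch. A toy illustration with random tensors:
```
import torch
from torchvision.models.optical_flow import raft_small

model = raft_small().eval()
img1 = torch.rand(1, 3, 128, 128)
img2 = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    preds = model(img1, img2)  # two positional inputs, unlike classification models
```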
@@ -8,6 +8,7 @@ from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.ops import ConvNormActivation
from ..._internally_replaced_utils import load_state_dict_from_url
from ...utils import _log_api_usage_once
from ._utils import grid_sample, make_coords_grid, upsample_flow
@@ -19,6 +20,9 @@ __all__ = (
)
+_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
class ResidualBlock(nn.Module):
"""Slightly modified Residual block with extra relu and biases."""
@@ -474,8 +478,8 @@ class RAFT(nn.Module):
hidden_state = torch.tanh(hidden_state)
context = F.relu(context)
-        coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
-        coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda()
+        coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
+        coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
flow_predictions = []
for _ in range(num_flow_updates):
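Replacing the hard-coded `.cuda()` with `.to(fmap1.device)` makes the coordinate grids follow the device of the feature maps, so inference no longer requires a GPU. The general pattern, shown with a hypothetical helper:
```
import torch

def zeros_like_device(ref: torch.Tensor, *shape) -> torch.Tensor:
    # Hypothetical helper: allocate auxiliary tensors on the device of the
    # tensor they will be combined with, instead of hard-coding .cuda().
    return torch.zeros(*shape, device=ref.device)

fmap = torch.rand(1, 256, 16, 16)            # e.g. a CPU feature map
grid = zeros_like_device(fmap, 1, 2, 16, 16)
assert grid.device == fmap.device
```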
@@ -496,6 +500,9 @@
def _raft(
*,
+    arch=None,
+    pretrained=False,
+    progress=False,
# Feature encoder
feature_encoder_layers,
feature_encoder_block,
@@ -560,7 +567,7 @@
multiplier=0.25, # See comment in MaskPredictor about this
)
-    return RAFT(
+    model = RAFT(
feature_encoder=feature_encoder,
context_encoder=context_encoder,
corr_block=corr_block,
@@ -568,6 +575,11 @@
mask_predictor=mask_predictor,
**kwargs, # not really needed, all params should be consumed by now
)
+    if pretrained:
+        state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
+        model.load_state_dict(state_dict)
+
+    return model
def raft_large(*, pretrained=False, progress=True, **kwargs):
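The loading logic added to `_raft` follows the usual torchvision checkpoint pattern. A condensed standalone sketch (using `torch.hub.load_state_dict_from_url`, which backs the internal helper):
```
import torch.nn as nn
from torch.hub import load_state_dict_from_url

URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}

def maybe_load_pretrained(model: nn.Module, arch: str, progress: bool = True) -> nn.Module:
    # Hypothetical free-standing version of the `if pretrained:` branch above.
    state_dict = load_state_dict_from_url(URLS[arch], progress=progress)
    model.load_state_dict(state_dict)
    return model
```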
@@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model.
"""
-    if pretrained:
-        raise ValueError("No checkpoint is available for raft_large")
    return _raft(
+        arch="raft_large",
+        pretrained=pretrained,
+        progress=progress,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_block=ResidualBlock,
@@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
nn.Module: The model.
"""
-    if pretrained:
-        raise ValueError("No checkpoint is available for raft_small")
    return _raft(
+        arch="raft_small",
+        pretrained=pretrained,
+        progress=progress,
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
feature_encoder_block=BottleneckBlock,
@@ -4,12 +4,11 @@ from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
-# from torchvision.prototype.transforms import RaftEval
+from torchvision.prototype.transforms import RaftEval
from torchvision.transforms.functional import InterpolationMode
from .._api import WeightsEnum
-# from .._api import Weights
+from .._api import Weights
from .._utils import handle_legacy_interface
@@ -22,17 +21,33 @@ __all__ = (
)
_COMMON_META = {"interpolation": InterpolationMode.BILINEAR}
class Raft_Large_Weights(WeightsEnum):
-    pass
-    # C_T_V1 = Weights(
-    #     # Chairs + Things
-    #     url="",
-    #     transforms=RaftEval,
-    #     meta={
-    #         "recipe": "",
-    #         "epe": -1234,
-    #     },
-    # )
+    C_T_V1 = Weights(
+        # Chairs + Things, ported from original paper repo (raft-things.pth)
+        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
+        transforms=RaftEval,
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "sintel_train_cleanpass_epe": 1.4411,
+            "sintel_train_finalpass_epe": 2.7894,
+        },
+    )
+
+    C_T_V2 = Weights(
+        # Chairs + Things
+        url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
+        transforms=RaftEval,
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
+            "sintel_train_cleanpass_epe": 1.3822,
+            "sintel_train_finalpass_epe": 2.7161,
+        },
+    )
# C_T_SKHT_V1 = Weights(
# # Chairs + Things + Sintel fine-tuning, i.e.:
@@ -59,7 +74,7 @@ class Raft_Large_Weights(WeightsEnum):
# },
# )
-    # default = C_T_V1
+    default = C_T_V2
class Raft_Small_Weights(WeightsEnum):
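With the enum populated, a weights entry carries its preprocessing and metadata alongside the URL. A hedged usage sketch against the prototype API:
```
from torchvision.prototype.models.optical_flow import Raft_Large_Weights

weights = Raft_Large_Weights.C_T_V2          # also exposed as .default
preprocessing = weights.transforms()         # instantiates the RaftEval preset
print(weights.meta["sintel_train_cleanpass_epe"])  # 1.3822
```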
@@ -75,13 +90,13 @@ class Raft_Small_Weights(WeightsEnum):
# default = C_T_V1
@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args:
-        weights(Raft_Large_weights, optinal): TODO not implemented yet
+        weights(Raft_Large_Weights, optional): pretrained weights to use.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
@@ -92,7 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
weights = Raft_Large_Weights.verify(weights)
-    return _raft(
+    model = _raft(
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_block=ResidualBlock,
@@ -119,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
**kwargs,
)
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
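Because `handle_legacy_interface` now maps `pretrained=True` to `C_T_V2`, the old and new call styles should resolve to the same checkpoint. A sketch, assuming the prototype namespace:
```
from torchvision.prototype.models.optical_flow import raft_large, Raft_Large_Weights

model_new = raft_large(weights=Raft_Large_Weights.C_T_V2)
model_old = raft_large(pretrained=True)  # legacy flag, redirected to C_T_V2
```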
@@ -138,7 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
weights = Raft_Small_Weights.verify(weights)
-    return _raft(
+    model = _raft(
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
feature_encoder_block=BottleneckBlock,
@@ -164,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
use_mask_predictor=False,
**kwargs,
)
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model