"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "8bf46d4e32fb7ef44895c451b5cd2548e3badf30"
Unverified commit 4dd8b5cc, authored by Nicolas Hug and committed via GitHub

Add training reference for optical flow models (#5027)

parent 47bd9620
# train.py: training / evaluation script for the optical-flow models reference.
import argparse
import warnings
from pathlib import Path

import torch
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small

def get_train_dataset(stage, dataset_root):
    if stage == "chairs":
        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
    elif stage == "things":
        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
    elif stage == "sintel_SKH":  # S + K + H as from paper
        crop_size = (368, 768)
        transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

        things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
        sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

        kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

        hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)

        # As a future improvement, we could probably be using a distributed sampler here
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
    elif stage == "kitti":
        transforms = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=transforms)
    else:
        raise ValueError(f"Unknown stage {stage}")
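
# A short note on the dataset arithmetic in the "sintel_SKH" stage above
# (a sketch, not part of the training logic): `+` concatenates datasets
# (torch.utils.data.Dataset defines __add__ via ConcatDataset), and
# `int * dataset` repeats a dataset, presumably via an analogous __rmul__ on
# the flow datasets. So, with hypothetical lengths:
#
#   mixed = 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
#   # len(mixed) == 100 * len(sintel) + 200 * len(kitti) + 5 * len(hd1k) + len(things_clean)
#
# With a uniformly-shuffled sampler, this yields roughly the S/T/K/H
# proportions quoted in the comment above.
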
@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.
    """
    batch_size = batch_size or args.batch_size
    model.eval()

    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    num_flow_updates = num_flow_updates or args.num_flow_updates

    def inner_loop(blob):
        if blob[0].dim() == 3:
            # input is not batched, so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        image1, image2 = image1.cuda(), image2.cuda()

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in the paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)

        logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print(
        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
        "Going to process the remaining samples individually, if any."
    )

    if args.rank == 0:  # we only need to process the rest on a single worker
        for i in range(num_processed_samples, len(val_dataset)):
            inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)
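
# Worked example for the "process the rest on a single worker" logic above
# (numbers are hypothetical): with world_size=2 and len(val_dataset)=11,
# DistributedSampler(drop_last=True) gives each rank 11 // 2 = 5 samples, so
# 2 * 5 = 10 samples are batch-processed in total. After the all-reduce,
# num_processed_samples == 10 on every rank, and rank 0 then evaluates
# val_dataset[10], the single leftover sample, unbatched.
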
def validate(model, args):
    val_datasets = args.val_dataset or []

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes, so we need to pad them individually; we can't batch.
            # See comment in InputPadder.
            if args.batch_size != 1 and args.rank == 0:
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
            _validate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
                )
                _validate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            warnings.warn(f"Can't validate on {name}, skipping.")
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
    for data_blob in logger.log_every(train_loader):

        optimizer.zero_grad()

        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

        metrics.pop("f1")
        logger.update(loss=loss, **metrics)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()

        current_step += 1

        if current_step == args.num_steps:
            return True, current_step

    return False, current_step
def main(args):
    utils.setup_ddp(args)

    model = raft_small() if args.small else raft_large()
    model = model.to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    if args.resume is not None:
        d = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(d, strict=True)

    if args.train_dataset is None:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        validate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    torch.backends.cudnn.benchmark = True

    model.train()
    if args.freeze_batch_norm:
        utils.freeze_batch_norm(model.module)

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        total_steps=args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    logger = utils.MetricLogger()

    done = False
    current_epoch = current_step = 0
    while not done:
        print(f"EPOCH {current_epoch}")
        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs

        done, current_step = train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            current_step=current_step,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {current_epoch} done. ", logger)

        current_epoch += 1

        if args.rank == 0:
            # TODO: Also save the optimizer and scheduler
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")

        if current_epoch % args.val_freq == 0 or done:
            validate(model, args)
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model.module)
def get_args_parser(add_help=True):
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument(
        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )

    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")

    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
    # Keeping it this way for now to reproduce results more easily.
    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
    parser.add_argument("--batch-size", type=int, default=6)

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )

    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")

    parser.add_argument(
        "--num_flow_updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )

    return parser
if __name__ == "__main__":
    args = get_args_parser().parse_args()
    Path(args.output_dir).mkdir(exist_ok=True)
    main(args)
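
# A minimal usage sketch (paths, GPU count and experiment names are
# hypothetical). The script expects to be launched with a distributed
# launcher such as torchrun, which sets LOCAL_RANK / RANK / WORLD_SIZE:
#
#   torchrun --nproc_per_node 8 train.py \
#       --dataset-root /data/optical_flow \
#       --name raft_chairs \
#       --train-dataset chairs
#
# Evaluation only, from previously saved weights:
#
#   torchrun --nproc_per_node 1 train.py \
#       --dataset-root /data/optical_flow \
#       --resume checkpoints/raft.pth \
#       --val-dataset sintel kitti
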
# utils.py: helper module, imported by the training script above as `utils`.
import datetime
import os
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist
import torch.nn.functional as F

class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )
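
# Worked example of the window vs. global statistics above (values are
# hypothetical):
#
#   v = SmoothedValue(window_size=2)
#   for x in (1.0, 2.0, 9.0):
#       v.update(x)
#   v.avg         # 5.5 -> mean of the last 2 values (2.0, 9.0)
#   v.global_avg  # 4.0 -> (1.0 + 2.0 + 9.0) / 3, over all updates
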
class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, **kwargs):
        self.meters[name] = SmoothedValue(**kwargs)
    def log_every(self, iterable, print_freq=5, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if print_freq is not None and i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")
def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
    epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
    flow_norm = (flow_gt ** 2).sum(dim=1).sqrt()

    if valid_flow_mask is not None:
        epe = epe[valid_flow_mask]
        flow_norm = flow_norm[valid_flow_mask]

    relative_epe = epe / flow_norm

    metrics = {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
        "f1": ((epe > 3) & (relative_epe > 0.05)).float().mean().item() * 100,
    }
    return metrics, epe.numel()
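
# Worked example for the thresholds above (epe values are hypothetical): if
# the per-pixel epe tensor is [0.5, 2.0, 4.0], then "1px" = 1/3 (only 0.5 < 1),
# "3px" = 2/3, "5px" = 1.0, and "epe" = 6.5 / 3. A pixel only counts towards
# "f1" if its epe exceeds 3 pixels AND 5% of the ground-truth flow magnitude,
# i.e. f1 measures errors that are large relative to the true motion.
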
def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
    """Loss function defined over a sequence of flow predictions"""
    if gamma > 1:
        raise ValueError(f"Gamma should be < 1, got {gamma}.")

    # exclude invalid pixels and extremely large displacements
    flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()
    valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)

    valid_flow_mask = valid_flow_mask[:, None, :, :]

    flow_preds = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)

    abs_diff = (flow_preds - flow_gt).abs()
    abs_diff = (abs_diff * valid_flow_mask).mean(axis=(1, 2, 3, 4))

    num_predictions = flow_preds.shape[0]
    weights = gamma ** torch.arange(num_predictions - 1, -1, -1).to(flow_gt.device)
    flow_loss = (abs_diff * weights).sum()

    return flow_loss
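
# Worked example of the exponential weighting above: with gamma=0.8 and
# num_predictions=3, weights == gamma ** tensor([2, 1, 0]) == [0.64, 0.8, 1.0],
# so the last (most refined) flow update contributes most to the loss and
# earlier intermediate predictions are progressively down-weighted.
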
class InputPadder:
    """Pads images such that dimensions are divisible by 8"""

    # TODO: Ideally, this should be part of the eval transforms preset, instead
    # of being part of the validation code. It's not obvious what a good
    # solution would be, because we need to unpad the predicted flows according
    # to the input images' size, and in some datasets (Kitti) images can have
    # variable sizes.

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == "sintel":
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode="replicate") for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0] : c[1], c[2] : c[3]]
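
# Worked example of the padding arithmetic above, using Sintel's 436x1024
# frames: pad_ht = ((436 // 8 + 1) * 8 - 436) % 8 = 4 and
# pad_wd = ((1024 // 8 + 1) * 8 - 1024) % 8 = 0 (1024 is already divisible
# by 8). In "sintel" mode the 4 extra rows are split 2 on top / 2 at the
# bottom; in the other ("kitti") mode they would all go to the bottom.
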
def _redefine_print(is_main):
    """disables printing when not in main process"""
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_main or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
def setup_ddp(args):
    # Set the local_rank, rank, and world_size values as args fields
    # This is done differently depending on how we're running the script. We
    # currently support either torchrun or the custom run_with_submitit.py
    # If you're confused (like I was), this might help a bit
    # https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
    if all(key in os.environ for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")):
        # if we're here, the script was called with torchrun. Otherwise,
        # these args will be set already by the run_with_submitit script
        args.local_rank = int(os.environ["LOCAL_RANK"])
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
    elif "gpu" in args:
        # if we're here, the script was called by run_with_submitit.py
        args.local_rank = args.gpu
    else:
        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")

    _redefine_print(is_main=(args.rank == 0))

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend="nccl",
        rank=args.rank,
        world_size=args.world_size,
        init_method=args.dist_url,
    )
def reduce_across_processes(val):
    t = torch.tensor(val, device="cuda")

    dist.barrier()
    dist.all_reduce(t)
    return t
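
# Worked example (ranks and values are hypothetical): if rank 0 calls
# reduce_across_processes(5) and rank 1 calls reduce_across_processes(6),
# the all-reduce sums them and both ranks get back tensor(11). This is how
# the validation helper in the training script above obtains the global
# number of batch-processed samples.
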
def freeze_batch_norm(model):
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.eval()