Commit bf491463 authored by limm

add v0.19.1 release

parent e17f5ea2
import torch
from torchvision.transforms.functional import InterpolationMode
def get_module(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.transforms.v2
return torchvision.transforms.v2
else:
import torchvision.transforms
return torchvision.transforms
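# Illustrative usage sketch (not part of the upstream file; the helper name is
# hypothetical): both namespaces expose the same core transform names, so
# callers can build a pipeline against whichever module get_module() returns.
def _demo_get_module():
    T = get_module(use_v2=False)  # torchvision.transforms (v1 API)
    return T.Compose([T.RandomHorizontalFlip(p=0.5)])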
class ClassificationPresetTrain:
# Note: this transform assumes that the input to forward() are always PIL
# images, regardless of the backend parameter. We may change that in the
# future though, if we change the output type from the dataset.
def __init__(
self,
*,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
ra_magnitude=9,
augmix_severity=3,
random_erase_prob=0.0,
backend="pil",
use_v2=False,
):
T = get_module(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0:
transforms.append(T.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide":
transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
else:
aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))
if backend == "pil":
transforms.append(T.PILToTensor())
transforms.extend(
[
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
transforms.append(T.RandomErasing(p=random_erase_prob))
if use_v2:
transforms.append(T.ToPureTensor())
self.transforms = T.Compose(transforms)
def __call__(self, img):
return self.transforms(img)
class ClassificationPresetEval:
def __init__(
self,
*,
crop_size,
resize_size=256,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
backend="pil",
use_v2=False,
):
T = get_module(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
transforms += [
T.Resize(resize_size, interpolation=interpolation, antialias=True),
T.CenterCrop(crop_size),
]
if backend == "pil":
transforms.append(T.PILToTensor())
transforms += [
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
if use_v2:
transforms.append(T.ToPureTensor())
self.transforms = T.Compose(transforms)
def __call__(self, img):
return self.transforms(img)
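# Minimal usage sketch (assumed values, mirroring the defaults above; the demo
# name is hypothetical): both presets map a PIL image to a normalized float
# tensor of shape [3, crop_size, crop_size].
def _demo_presets():
    from PIL import Image

    img = Image.new("RGB", (320, 320))
    train_tf = ClassificationPresetTrain(crop_size=224)
    eval_tf = ClassificationPresetEval(crop_size=224, resize_size=256)
    return train_tf(img).shape, eval_tf(img).shape  # both torch.Size([3, 224, 224])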
import math
import torch
import torch.distributed as dist
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
    It ensures that each augmented version of a sample is visible to a
    different process (GPU).
Heavily based on 'torch.utils.data.DistributedSampler'.
This is borrowed from the DeiT Repo:
https://github.com/facebookresearch/deit/blob/main/samplers.py
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available!")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available!")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
self.seed = seed
self.repetitions = repetitions
def __iter__(self):
if self.shuffle:
# Deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# Add extra samples to make it evenly divisible
indices = [ele for ele in indices for i in range(self.repetitions)]
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# Subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[: self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch
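# Illustrative sketch (not part of the upstream file; the demo name is
# hypothetical): shard a toy dataset of 512 items across 4 ranks with 3
# repetitions. The repeated copies of any given index land on different ranks
# because the flat index list is strided by num_replicas.
def _demo_rasampler_sharding():
    dataset = list(range(512))  # RASampler only needs len(dataset)
    shards = [
        list(iter(RASampler(dataset, num_replicas=4, rank=r, shuffle=False, repetitions=3)))
        for r in range(4)
    ]
    return shards  # index 0 goes to ranks 0, 1 and 2; index 1 to ranks 3, 0 and 1; ...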
import datetime
import os
import time
import warnings
import presets
import torch
import torch.utils.data
import torchvision
import torchvision.transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_mixup_cutmix
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
header = f"Epoch: [{epoch}]"
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time()
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
if scaler is not None:
scaler.scale(loss).backward()
if args.clip_grad_norm is not None:
# we should unscale the gradients of optimizer's assigned params if do gradient clipping
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if args.clip_grad_norm is not None:
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()
if model_ema and i % args.model_ema_steps == 0:
model_ema.update_parameters(model)
if epoch < args.lr_warmup_epochs:
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = f"Test: {log_suffix}"
num_processed_samples = 0
with torch.inference_mode():
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
@@ -61,18 +77,34 @@ def evaluate(model, criterion, data_loader, device, print_freq=100):
# could have been padded in distributed setup
batch_size = image.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
num_processed_samples += batch_size
# gather the stats from all processes
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
if (
hasattr(data_loader.dataset, "__len__")
and len(data_loader.dataset) != num_processed_samples
and torch.distributed.get_rank() == 0
):
# See FIXME above
warnings.warn(
f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
"samples were used for the validation, which might bias the results. "
"Try adjusting the batch size and / or the world size. "
"Setting the world size to 1 is always a safe bet."
)
metric_logger.synchronize_between_processes()
print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
return metric_logger.acc1.global_avg
def _get_cache_path(filepath):
import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path)
@@ -82,24 +114,43 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
val_resize_size, val_crop_size, train_crop_size = (
args.val_resize_size,
args.val_crop_size,
args.train_crop_size,
)
interpolation = InterpolationMode(args.interpolation)
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
print(f"Loading dataset_train from {cache_path}")
# TODO: this could probably be weights_only=True
dataset, _ = torch.load(cache_path, weights_only=False)
else:
# We need a default value for the variables below because args may come
# from train_quantization.py which doesn't define them.
auto_augment_policy = getattr(args, "auto_augment", None)
random_erase_prob = getattr(args, "random_erase", 0.0)
ra_magnitude = getattr(args, "ra_magnitude", None)
augmix_severity = getattr(args, "augmix_severity", None)
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=train_crop_size,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
backend=args.backend,
use_v2=args.use_v2,
),
)
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
print(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)
@@ -108,21 +159,41 @@ def load_data(traindir, valdir, args):
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
print(f"Loading dataset_test from {cache_path}")
# TODO: this could probably be weights_only=True
dataset_test, _ = torch.load(cache_path, weights_only=False)
else:
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
preprocessing = weights.transforms(antialias=True)
if args.backend == "tensor":
preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
else:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size,
resize_size=val_resize_size,
interpolation=interpolation,
backend=args.backend,
use_v2=args.use_v2,
)
dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
print(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)
print("Creating data loaders")
if args.distributed:
if hasattr(args, "ra_sampler") and args.ra_sampler:
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
else:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -131,10 +202,6 @@ def load_data(traindir, valdir, args):
def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)
@@ -143,58 +210,154 @@ def main(args):
device = torch.device(args.device)
if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True
train_dir = os.path.join(args.data_path, "train")
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
num_classes = len(dataset.classes)
mixup_cutmix = get_mixup_cutmix(
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2
)
if mixup_cutmix is not None:
def collate_fn(batch):
return mixup_cutmix(*default_collate(batch))
else:
collate_fn = default_collate
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=args.workers,
pin_memory=True,
collate_fn=collate_fn,
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
)
print("Creating model")
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
custom_keys_weight_decay = []
if args.bias_weight_decay is not None:
custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
if args.transformer_embedding_decay is not None:
for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
parameters = utils.set_weight_decay(
model,
args.weight_decay,
norm_weight_decay=args.norm_weight_decay,
custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
)
opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "rmsprop":
optimizer = torch.optim.RMSprop(
parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
if args.apex:
model, optimizer = amp.initialize(model, optimizer,
opt_level=args.apex_opt_level
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
scaler = torch.cuda.amp.GradScaler() if args.amp else None
args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == "steplr":
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
elif args.lr_scheduler == "cosineannealinglr":
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
)
elif args.lr_scheduler == "exponentiallr":
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
else:
raise RuntimeError(
f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
"are supported."
)
if args.lr_warmup_epochs > 0:
if args.lr_warmup_method == "linear":
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
)
elif args.lr_warmup_method == "constant":
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
)
else:
raise RuntimeError(
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
)
else:
lr_scheduler = main_lr_scheduler
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
model_ema = None
if args.model_ema:
# Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
#
# total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
# We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
# adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
alpha = 1.0 - args.model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
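        # Worked example (illustrative numbers, not from the source): with
        # world_size=8, batch_size=32, model_ema_steps=32 and epochs=600,
        # adjust = 8 * 32 * 32 / 600 ≈ 13.65, so the configured decay of
        # 0.99998 becomes an effective per-update decay of
        # 1 - (1 - 0.99998) * 13.65 ≈ 0.99973.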
if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
if not args.test_only:
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
if model_ema:
model_ema.load_state_dict(checkpoint["model_ema"])
if scaler:
scaler.load_state_dict(checkpoint["scaler"])
if args.test_only:
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
else:
evaluate(model, criterion, data_loader_test, device=device)
return
print("Start training")
@@ -202,54 +365,94 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
if args.output_dir:
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
"args": args,
}
if model_ema:
checkpoint["model_ema"] = model_ema.state_dict()
if scaler:
checkpoint["scaler"] = scaler.state_dict()
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--model", default="resnet18", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
"-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
)
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument(
"-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
)
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--bias-weight-decay",
default=None,
type=float,
help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
)
parser.add_argument(
"--transformer-embedding-decay",
default=None,
type=float,
help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
)
parser.add_argument(
"--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
)
parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
parser.add_argument(
"--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
)
parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
@@ -268,29 +471,55 @@ def get_args_parser(add_help=True):
help="Only test the model",
action="store_true",
)
parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
parser.add_argument(
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
)
parser.add_argument(
"--model-ema-steps",
type=int,
default=32,
help="the number of iterations that controls how often to update the EMA model (default: 32)",
)
parser.add_argument(
"--model-ema-decay",
type=float,
default=0.99998,
help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
parser.add_argument(
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser
import copy
import datetime
import os
import time
import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import evaluate, load_data, train_one_epoch
def main(args):
@@ -20,51 +20,52 @@ def main(args):
print(args)
if args.post_training_quantize and args.distributed:
raise RuntimeError("Post training quantization example should not be performed "
"on distributed mode")
raise RuntimeError("Post training quantization example should not be performed on distributed mode")
# Set backend engine to ensure that quantized model runs on the correct kernels
if args.qbackend not in torch.backends.quantized.supported_engines:
raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
torch.backends.quantized.engine = args.qbackend
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
# Data loading code
print("Loading data")
train_dir = os.path.join(args.data_path, "train")
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
)
print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
prefix = "quantized_"
model_name = args.model
if not model_name.startswith(prefix):
model_name = prefix + model_name
model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
model.to(device)
if not (args.test_only or args.post_training_quantize):
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
torch.ao.quantization.prepare_qat(model, inplace=True)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
criterion = nn.CrossEntropyLoss()
model_without_ddp = model
@@ -73,34 +74,31 @@ def main(args):
model_without_ddp = model.module
if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
if args.post_training_quantize:
# perform calibration on a subset of the training dataset
# for that, create a subset of the training dataset
ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
data_loader_calibration = torch.utils.data.DataLoader(
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
)
model.eval()
model.fuse_model(is_qat=False)
model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend)
torch.ao.quantization.prepare(model, inplace=True)
# Calibrate first
print("Calibrating")
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
torch.ao.quantization.convert(model, inplace=True)
if args.output_dir:
print("Saving quantized model")
if utils.is_main_process():
torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
print("Evaluating post-training quantized model")
evaluate(model, criterion, data_loader_test, device=device)
return
@@ -109,113 +107,111 @@ def main(args):
evaluate(model, criterion, data_loader_test, device=device)
return
model.apply(torch.ao.quantization.enable_observer)
model.apply(torch.ao.quantization.enable_fake_quant)
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
print("Starting training for epoch", epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
lr_scheduler.step()
with torch.inference_mode():
if epoch >= args.num_observer_update_epochs:
print("Disabling observer for subseq epochs, epoch = ", epoch)
model.apply(torch.ao.quantization.disable_observer)
if epoch >= args.num_batch_norm_update_epochs:
print("Freezing BN for subseq epochs, epoch = ", epoch)
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
print("Evaluate QAT model")
evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
quantized_eval_model = copy.deepcopy(model_without_ddp)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device("cpu"))
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
print("Evaluate Quantized model")
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
model.train()
if args.output_dir:
checkpoint = {
"model": model_without_ddp.state_dict(),
"eval_model": quantized_eval_model.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
"args": args,
}
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
print("Saving models after epoch ", epoch)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
"-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
)
parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument(
"--num-observer-update-epochs",
default=4,
type=int,
metavar="N",
help="number of total epochs to update observers",
)
parser.add_argument(
"--num-batch-norm-update-epochs",
default=3,
type=int,
metavar="N",
help="number of total epochs to update batch norm stats",
)
parser.add_argument(
"--num-calibration-batches",
default=32,
type=int,
metavar="N",
help="number of batches of training set for \
observer calibration ",
)
parser.add_argument(
"-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
)
parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
@@ -243,15 +239,35 @@ def get_args_parser(add_help=True):
)
# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
if args.backend in ("fbgemm", "qnnpack"):
raise ValueError(
"The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) "
"instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend."
)
main(args)
import math
from typing import Tuple
import torch
from presets import get_module
from torch import Tensor
from torchvision.transforms import functional as F
def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
transforms_module = get_module(use_v2)
mixup_cutmix = []
if mixup_alpha > 0:
mixup_cutmix.append(
transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes)
if use_v2
else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
)
if cutmix_alpha > 0:
mixup_cutmix.append(
transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
if use_v2
else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
)
if not mixup_cutmix:
return None
return transforms_module.RandomChoice(mixup_cutmix)
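# Illustrative sketch of how the returned transform is wired into a DataLoader;
# this mirrors main() in train.py (the demo name and values below are local to
# the example, not part of the upstream file).
def _demo_mixup_collate(dataset, num_classes):
    import torch.utils.data
    from torch.utils.data.dataloader import default_collate

    mixup_cutmix = get_mixup_cutmix(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=num_classes, use_v2=False)

    def collate_fn(batch):
        # Mixing operates on whole collated batches, not on single samples.
        return mixup_cutmix(*default_collate(batch))

    return torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)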
class RandomMixUp(torch.nn.Module):
"""Randomly apply MixUp to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for mixup.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
"""
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
if num_classes < 1:
raise ValueError(
f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
)
if alpha <= 0:
raise ValueError("Alpha param can't be zero.")
self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
"""
if batch.ndim != 4:
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace:
batch = batch.clone()
target = target.clone()
if target.ndim == 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
if torch.rand(1).item() >= self.p:
return batch, target
# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)
        # Implemented as in the mixup paper, page 3.
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
batch_rolled.mul_(1.0 - lambda_param)
batch.mul_(lambda_param).add_(batch_rolled)
target_rolled.mul_(1.0 - lambda_param)
target.mul_(lambda_param).add_(target_rolled)
return batch, target
def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}("
f"num_classes={self.num_classes}"
f", p={self.p}"
f", alpha={self.alpha}"
f", inplace={self.inplace}"
f")"
)
return s
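# Minimal usage sketch (illustrative values; the demo name is hypothetical):
# mixup preserves the batch shape and turns integer labels into soft
# (B, num_classes) targets.
def _demo_random_mixup():
    batch = torch.rand(4, 3, 224, 224)
    target = torch.randint(0, 10, (4,))
    mixup = RandomMixUp(num_classes=10, p=1.0, alpha=0.2)
    mixed_batch, mixed_target = mixup(batch, target)
    return mixed_batch.shape, mixed_target.shape  # (4, 3, 224, 224), (4, 10)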
class RandomCutMix(torch.nn.Module):
"""Randomly apply CutMix to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
<https://arxiv.org/abs/1905.04899>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for cutmix.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
"""
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
if num_classes < 1:
raise ValueError("Please provide a valid positive value for the num_classes.")
if alpha <= 0:
raise ValueError("Alpha param can't be zero.")
self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
"""
if batch.ndim != 4:
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace:
batch = batch.clone()
target = target.clone()
if target.ndim == 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
if torch.rand(1).item() >= self.p:
return batch, target
# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)
        # Implemented as in the cutmix paper, page 12 (with minor corrections of typos).
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
_, H, W = F.get_dimensions(batch)
r_x = torch.randint(W, (1,))
r_y = torch.randint(H, (1,))
r = 0.5 * math.sqrt(1.0 - lambda_param)
r_w_half = int(r * W)
r_h_half = int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
target_rolled.mul_(1.0 - lambda_param)
target.mul_(lambda_param).add_(target_rolled)
return batch, target
def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}("
f"num_classes={self.num_classes}"
f", p={self.p}"
f", alpha={self.alpha}"
f", inplace={self.inplace}"
f")"
)
return s
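# Minimal usage sketch (illustrative values; the demo name is hypothetical):
# the pasted box re-weights the targets by the exact pasted area, so each
# soft-target row still sums to 1.
def _demo_random_cutmix():
    batch = torch.rand(4, 3, 224, 224)
    target = torch.randint(0, 10, (4,))
    cutmix = RandomCutMix(num_classes=10, p=1.0, alpha=1.0)
    mixed_batch, mixed_target = cutmix(batch, target)
    assert torch.allclose(mixed_target.sum(dim=1), torch.ones(4))
    return mixed_batch, mixed_target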
import copy
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
@@ -32,11 +33,7 @@ class SmoothedValue(object):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@@ -65,14 +62,11 @@ class SmoothedValue(object):
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
@@ -89,15 +83,12 @@ class MetricLogger(object):
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
@@ -110,31 +101,28 @@ class MetricLogger(object):
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
@@ -144,28 +132,51 @@ class MetricLogger(object):
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {}'.format(header, total_time_str))
print(f"{header} Total time: {total_time_str}")
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
"""Maintains moving averages of model parameters using an exponential decay.
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
is used to compute the EMA.
"""
def __init__(self, model, decay, device="cpu"):
def ema_avg(avg_model_param, model_param, num_averaged):
return decay * avg_model_param + (1 - decay) * model_param
super().__init__(model, device, ema_avg, use_buffers=True)
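
# Usage sketch (illustrative, not part of the reference script): keep an EMA
# copy of the weights and fold the live weights into it after each optimizer
# step. The toy model and optimizer below are assumptions for demonstration.
#
#   model = torch.nn.Linear(8, 2)
#   model_ema = ExponentialMovingAverage(model, decay=0.999)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   for _ in range(100):
#       loss = model(torch.randn(4, 8)).sum()
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
#       model_ema.update_parameters(model)  # update the running average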
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
with torch.inference_mode():
maxk = max(topk)
batch_size = target.size(0)
if target.ndim == 2:
target = target.max(dim=1)[1]
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
@@ -191,10 +202,11 @@ def setup_for_distributed(is_master):
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
@@ -231,28 +243,29 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print('Not using distributed mode')
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
args.dist_backend = "nccl"
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@@ -274,10 +287,7 @@ def average_checkpoints(inputs):
for fpath in inputs:
with open(fpath, "rb") as f:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
)
# Copies over the settings from the first checkpoint
if new_state is None:
@@ -288,8 +298,7 @@ def average_checkpoints(inputs):
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
"For checkpoint {}, expected list of params: {}, "
"but found: {}".format(f, params_keys, model_params_keys)
f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
)
for k in params_keys:
p = model_params[k]
@@ -311,7 +320,7 @@ def average_checkpoints(inputs):
return new_state
def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
"""
This method can be used to prepare weights files for new models. It receives as
input a model architecture and a checkpoint from the training script and produces
@@ -321,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
from torchvision import models as M
# Classification
model = M.mobilenet_v3_large(pretrained=False)
model = M.mobilenet_v3_large(weights=None)
print(store_model_weights(model, './class.pth'))
# Quantized Classification
model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
_ = torch.quantization.prepare_qat(model, inplace=True)
model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
_ = torch.ao.quantization.prepare_qat(model, inplace=True)
print(store_model_weights(model, './qat.pth'))
# Object Detection
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
print(store_model_weights(model, './obj.pth'))
# Segmentation
model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True)
model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
print(store_model_weights(model, './segm.pth', strict=False))
Args:
@@ -355,12 +364,15 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
checkpoint_path = os.path.abspath(checkpoint_path)
output_dir = os.path.dirname(checkpoint_path)
# Deep copy to avoid side-effects on the model object.
# Deep copy to avoid side effects on the model object.
model = copy.deepcopy(model)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc)
# and remove unnecessary weights (such as auxiliaries, etc.)
if checkpoint_key == "model_ema":
del checkpoint[checkpoint_key]["n_averaged"]
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
tmp_path = os.path.join(output_dir, str(model.__hash__()))
@@ -377,3 +389,76 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
os.replace(tmp_path, output_path)
return output_path
def reduce_across_processes(val):
if not is_dist_avail_and_initialized():
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
return torch.tensor(val)
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
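
# Usage sketch (illustrative): sum a per-process value across all workers,
# e.g. `total_samples = reduce_across_processes(num_local_samples)`; in the
# non-distributed case this simply returns `torch.tensor(num_local_samples)`.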
def set_weight_decay(
model: torch.nn.Module,
weight_decay: float,
norm_weight_decay: Optional[float] = None,
norm_classes: Optional[List[type]] = None,
custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
):
if not norm_classes:
norm_classes = [
torch.nn.modules.batchnorm._BatchNorm,
torch.nn.LayerNorm,
torch.nn.GroupNorm,
torch.nn.modules.instancenorm._InstanceNorm,
torch.nn.LocalResponseNorm,
]
norm_classes = tuple(norm_classes)
params = {
"other": [],
"norm": [],
}
params_weight_decay = {
"other": weight_decay,
"norm": norm_weight_decay,
}
custom_keys = []
if custom_keys_weight_decay is not None:
for key, weight_decay in custom_keys_weight_decay:
params[key] = []
params_weight_decay[key] = weight_decay
custom_keys.append(key)
def _add_params(module, prefix=""):
for name, p in module.named_parameters(recurse=False):
if not p.requires_grad:
continue
is_custom_key = False
for key in custom_keys:
target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
if key == target_name:
params[key].append(p)
is_custom_key = True
break
if not is_custom_key:
if norm_weight_decay is not None and isinstance(module, norm_classes):
params["norm"].append(p)
else:
params["other"].append(p)
for child_name, child_module in module.named_children():
child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
_add_params(child_module, prefix=child_prefix)
_add_params(model)
param_groups = []
for key in params:
if len(params[key]) > 0:
param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
return param_groups
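
# Usage sketch (illustrative): split parameters into groups so normalization
# layers can use their own weight decay; the values below are examples only.
#
#   param_groups = set_weight_decay(model, weight_decay=1e-4, norm_weight_decay=0.0)
#   optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)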
# Stereo Matching reference training scripts
This folder contains reference training scripts for Stereo Matching.
They serve as a log of how to train specific models, so as to provide baseline
training and evaluation scripts to quickly bootstrap research.
### CREStereo
The CREStereo model was trained on a dataset mixture of **CREStereo**, **ETH3D** and the additional split from **Middlebury2014**.
A ratio of **88-6-6** was used to train a baseline weight set. We provide a multi-set variant as well.
Both runs used 8 A100 GPUs and a batch size of 2 (so the effective batch size is 16). The
rest of the hyper-parameters loosely follow the recipe from https://github.com/megvii-research/CREStereo.
The original recipe trains for **300000** updates (or steps) on the dataset mixture. We modify the learning rate
schedule to one that starts decaying the rate much sooner. Throughout the experiments we found that this reduces
overfitting at evaluation time, and that gradient clipping helps stabilize the loss during a premature learning rate change.
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_cre \
--model crestereo_base \
--train-datasets crestereo eth3d-train middlebury2014-other \
    --dataset-steps 264000 18000 18000 \
--batch-size 2 \
--lr 0.0004 \
--min-lr 0.00002 \
--lr-decay-method cosine \
--warmup-steps 6000 \
--decay-after-steps 30000 \
    --clip-grad-norm 1.0
```
We employ a multi-set fine-tuning stage where we uniformly sample from multiple datasets. Given that some of these datasets have extremely large images (``2048x2048`` or more), we opt for a very aggressive scale range of ``[0.2 - 0.8]`` so that as much of the original frame composition as possible is captured inside the ``384x512`` crop.
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
--dataset-root $dataset_root \
--name $name_things \
--model crestereo_base \
--train-datasets crestereo eth3d-train middlebury2014-other instereo2k fallingthings carla-highres sintel sceneflow-monkaa sceneflow-driving \
    --dataset-steps 12000 12000 12000 12000 12000 12000 12000 12000 12000 \
--batch-size 2 \
--scale-range 0.2 0.8 \
--lr 0.0004 \
--lr-decay-method cosine \
--decay-after-steps 0 \
--warmup-steps 0 \
--min-lr 0.00002 \
--resume-path $checkpoint_dir/$name_cre.pth
```
### Evaluation
Evaluating the base weights
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.CRESTEREO_ETH_MBL_V1
```
This should give an **mae of about 1.416** on the train set of `Middlebury2014`. Results may vary slightly depending on the batch size and the number of GPUs. For the most accurate results use 1 GPU and `--batch-size 1`. The created log file should look like this, where the first key is the number of cascades and the nested key is the number of recurrent iterations:
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 2.363, 'rmse': 4.352, '1px': 0.611, '3px': 0.828, '5px': 0.891, 'relepe': 0.176, 'fl-all': 64.511}
5: {'mae': 1.618, 'rmse': 3.71, '1px': 0.761, '3px': 0.879, '5px': 0.918, 'relepe': 0.154, 'fl-all': 77.128}
10: {'mae': 1.416, 'rmse': 3.53, '1px': 0.777, '3px': 0.896, '5px': 0.933, 'relepe': 0.148, 'fl-all': 78.388}
20: {'mae': 1.448, 'rmse': 3.583, '1px': 0.771, '3px': 0.893, '5px': 0.931, 'relepe': 0.145, 'fl-all': 77.7}
},
}
{
2: {
2: {'mae': 1.972, 'rmse': 4.125, '1px': 0.73, '3px': 0.865, '5px': 0.908, 'relepe': 0.169, 'fl-all': 74.396}
5: {'mae': 1.403, 'rmse': 3.448, '1px': 0.793, '3px': 0.905, '5px': 0.937, 'relepe': 0.151, 'fl-all': 80.186}
10: {'mae': 1.312, 'rmse': 3.368, '1px': 0.799, '3px': 0.912, '5px': 0.943, 'relepe': 0.148, 'fl-all': 80.379}
20: {'mae': 1.376, 'rmse': 3.542, '1px': 0.796, '3px': 0.91, '5px': 0.942, 'relepe': 0.149, 'fl-all': 80.054}
},
}
```
You can also evaluate the fine-tuned weights:
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.CRESTEREO_FINETUNE_MULTI_V1
```
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 1.85, 'rmse': 3.797, '1px': 0.673, '3px': 0.862, '5px': 0.917, 'relepe': 0.171, 'fl-all': 69.736}
5: {'mae': 1.111, 'rmse': 3.166, '1px': 0.838, '3px': 0.93, '5px': 0.957, 'relepe': 0.134, 'fl-all': 84.596}
10: {'mae': 1.02, 'rmse': 3.073, '1px': 0.854, '3px': 0.938, '5px': 0.96, 'relepe': 0.129, 'fl-all': 86.042}
20: {'mae': 0.993, 'rmse': 3.059, '1px': 0.855, '3px': 0.942, '5px': 0.967, 'relepe': 0.126, 'fl-all': 85.784}
},
}
{
2: {
2: {'mae': 1.667, 'rmse': 3.867, '1px': 0.78, '3px': 0.891, '5px': 0.922, 'relepe': 0.165, 'fl-all': 78.89}
5: {'mae': 1.158, 'rmse': 3.278, '1px': 0.843, '3px': 0.926, '5px': 0.955, 'relepe': 0.135, 'fl-all': 84.556}
10: {'mae': 1.046, 'rmse': 3.13, '1px': 0.85, '3px': 0.934, '5px': 0.96, 'relepe': 0.13, 'fl-all': 85.464}
20: {'mae': 1.021, 'rmse': 3.102, '1px': 0.85, '3px': 0.935, '5px': 0.963, 'relepe': 0.129, 'fl-all': 85.417}
},
}
```
Evaluating the author-provided weights:
```
torchrun --nproc_per_node 1 --nnodes 1 cascade_evaluation.py --dataset middlebury2014-train --batch-size 1 --dataset-root $dataset_root --model crestereo_base --weights CREStereo_Base_Weights.MEGVII_V1
```
```
Dataset: middlebury2014-train @size: [384, 512]:
{
1: {
2: {'mae': 1.704, 'rmse': 3.738, '1px': 0.738, '3px': 0.896, '5px': 0.933, 'relepe': 0.157, 'fl-all': 76.464}
5: {'mae': 0.956, 'rmse': 2.963, '1px': 0.88, '3px': 0.948, '5px': 0.965, 'relepe': 0.124, 'fl-all': 88.186}
10: {'mae': 0.792, 'rmse': 2.765, '1px': 0.905, '3px': 0.958, '5px': 0.97, 'relepe': 0.114, 'fl-all': 90.429}
20: {'mae': 0.749, 'rmse': 2.706, '1px': 0.907, '3px': 0.961, '5px': 0.972, 'relepe': 0.113, 'fl-all': 90.807}
},
}
{
2: {
2: {'mae': 1.702, 'rmse': 3.784, '1px': 0.784, '3px': 0.894, '5px': 0.924, 'relepe': 0.172, 'fl-all': 80.313}
5: {'mae': 0.932, 'rmse': 2.907, '1px': 0.877, '3px': 0.944, '5px': 0.963, 'relepe': 0.125, 'fl-all': 87.979}
10: {'mae': 0.773, 'rmse': 2.768, '1px': 0.901, '3px': 0.958, '5px': 0.972, 'relepe': 0.117, 'fl-all': 90.43}
20: {'mae': 0.854, 'rmse': 2.971, '1px': 0.9, '3px': 0.957, '5px': 0.97, 'relepe': 0.122, 'fl-all': 90.269}
},
}
```
# Concerns when training
We encourage users to be aware of the **aspect-ratio** and **disparity scale** they are targeting when doing any sort of training or fine-tuning. The model is highly sensitive to these two factors; as a consequence of naive multi-set fine-tuning, one can achieve a `0.2 mae` relatively fast. We recommend that users pay close attention to how they **balance dataset sizing** when training such networks.
Ideally, dataset scaling should be treated at an individual level, and a thorough **EDA** of the disparity distribution in random crops at the desired training / inference size should be performed prior to any large compute investments; a minimal sketch of such a check follows.
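The sketch below is illustrative only and is **not** part of the reference scripts; it assumes a stereo dataset that yields `(image_left, image_right, disparity, valid_mask)` tuples, as the torchvision stereo datasets do.
```
import torch

def disparity_crop_stats(dataset, crop_size=(384, 512), n_samples=100):
    # sample random crops at the target size and summarize the valid disparities
    h, w = crop_size
    means, maxima = [], []
    for idx in torch.randint(len(dataset), (n_samples,)).tolist():
        _, _, disp, valid = dataset[idx]
        disp = torch.as_tensor(disp).float()
        valid = torch.as_tensor(valid, dtype=torch.bool)
        y = torch.randint(max(disp.shape[-2] - h, 1), (1,)).item()
        x = torch.randint(max(disp.shape[-1] - w, 1), (1,)).item()
        crop = disp[..., y : y + h, x : x + w].reshape(-1)
        mask = valid[..., y : y + h, x : x + w].reshape(-1)
        if mask.any():
            means.append(crop[mask].mean().item())
            maxima.append(crop[mask].max().item())
    print(f"mean disparity: {sum(means) / len(means):.2f}, max disparity: {max(maxima):.2f}")
```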
### Disparity scaling
##### Sample A
The top row contains a sample from `Sintel`, whereas the bottom row contains one from `Middlebury`.
![Disparity1](assets/disparity-domain-drift.jpg)
From left to right (`left_image`, `right_image`, `valid_mask`, `valid_mask & ground_truth`, `prediction`). **Darker is further away, lighter is closer**. In the case of `Sintel`, which is more closely aligned with the original distribution of `CREStereo`, we notice that the model accurately predicts the background scale, whereas in the case of `Middlebury2014` it cannot correctly estimate the continuous disparity. Notice that the frame composition is similar for both examples. The blue skybox in the `Sintel` scene behaves similarly to the `Middlebury` black background. However, because the `Middlebury` sample comes from an extremely large scene, the crop size of `384x512` does not correctly capture the general training distribution.
##### Sample B
The top row contains a scene from `Sceneflow` using the `Monkaa` split, whilst the bottom row is a scene from `Middlebury`. This sample exhibits the same issues when it comes to **background estimation**. Given the exaggerated size of the `Middlebury` samples, the model **collapses the smooth background** of the sample to what it considers to be a mean background disparity value.
![Disparity2](assets/disparity-background-mode-collapse.jpg)
For more detail on why this behaviour occurs as a function of the training distribution proportions, you can read more about the network at: https://github.com/pytorch/vision/pull/6629#discussion_r978160493
### Metric overfitting
##### Learning is critical in the beginning
We also advise users to make use of faster training schedules, as the performance gain over long periods of time is marginal. Here we exhibit the difference between a faster decay schedule and a later decay schedule.
![Loss1](assets/Loss.jpg)
In **grey** we set the lr decay to begin after `30000` steps, whilst in **orange** we opt for a very late learning rate decay at around `180000` steps. Although it exhibits stronger variance, we can notice that unfreezing the learning rate earlier whilst employing `gradient-norm` clipping outperforms the default configuration.
##### Gradient norm saves time
![Loss2](assets/gradient-norm-removal.jpg)
In **grey** we keep ``gradient norm`` clipping enabled, whilst in **orange** we do not. We can notice that removing the gradient norm exacerbates the performance decrease in the early stages, whilst also showcasing an almost complete collapse around the `60000` step mark, where we started decaying the lr for **orange**.
Although both runs achieve an improvement of about ``0.1`` mae after the lr decay starts, the benefits of it are observable much faster when ``gradient norm`` clipping is employed, as the recovery period is no longer accounted for.
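For reference, ``gradient norm`` clipping amounts to a single extra call between the backward pass and the optimizer step. The sketch below is illustrative only (the toy model, data and optimizer are stand-ins); `max_norm=1.0` matches the `--clip-grad-norm 1.0` flag used above.
```
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# rescale gradients so their global norm is at most 1.0 before stepping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```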
import os
import warnings
import torch
import torchvision
import torchvision.prototype.models.depth.stereo
import utils
from torch.nn import functional as F
from train import make_eval_loader
from utils.metrics import AVAILABLE_METRICS
from visualization import make_prediction_image_side_to_side
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Evaluation", add_help=add_help)
parser.add_argument("--dataset", type=str, default="middlebury2014-train", help="dataset to use")
parser.add_argument("--dataset-root", type=str, default="", help="root of the dataset")
parser.add_argument("--checkpoint", type=str, default="", help="path to weights")
parser.add_argument("--weights", type=str, default=None, help="torchvision API weight")
parser.add_argument(
"--model",
type=str,
default="crestereo_base",
help="which model to use if not speciffying a training checkpoint",
)
parser.add_argument("--img-folder", type=str, default="images")
parser.add_argument("--batch-size", type=int, default=1, help="batch size")
parser.add_argument("--workers", type=int, default=0, help="number of workers")
parser.add_argument("--eval-size", type=int, nargs="+", default=[384, 512], help="resize size")
parser.add_argument(
"--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
)
parser.add_argument(
"--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
)
parser.add_argument(
"--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
)
parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
parser.add_argument(
"--interpolation-strategy",
type=str,
default="bilinear",
help="interpolation strategy",
choices=["bilinear", "bicubic", "mixed"],
)
parser.add_argument("--n_iterations", nargs="+", type=int, default=[10], help="number of recurent iterations")
parser.add_argument("--n_cascades", nargs="+", type=int, default=[1], help="number of cascades")
parser.add_argument(
"--metrics",
type=str,
nargs="+",
default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
help="metrics to log",
choices=AVAILABLE_METRICS,
)
parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
parser.add_argument("--world-size", type=int, default=1, help="number of distributed processes")
parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
parser.add_argument("--device", type=str, default="cuda", help="device to use for training")
parser.add_argument("--save-images", action="store_true", help="save images of the predictions")
parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
return parser
def cascade_inference(model, image_left, image_right, iterations, cascades):
    # check that the image size is divisible by 16 * (2 ** (cascades - 1))
    for image in [image_left, image_right]:
        if image.shape[-2] % (16 * (2 ** (cascades - 1))) != 0:
            raise ValueError(
                f"image height is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-2]}"
            )
        if image.shape[-1] % (16 * (2 ** (cascades - 1))) != 0:
            raise ValueError(
                f"image width is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-1]}"
            )
left_image_pyramid = [image_left]
right_image_pyramid = [image_right]
for idx in range(0, cascades - 1):
ds_factor = int(2 ** (idx + 1))
ds_shape = (image_left.shape[-2] // ds_factor, image_left.shape[-1] // ds_factor)
left_image_pyramid += F.interpolate(image_left, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(0)
right_image_pyramid += F.interpolate(image_right, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(
0
)
flow_init = None
for left_image, right_image in zip(reversed(left_image_pyramid), reversed(right_image_pyramid)):
flow_pred = model(left_image, right_image, flow_init, num_iters=iterations)
# flow pred is a list
flow_init = flow_pred[-1]
return flow_init
@torch.inference_mode()
def _evaluate(
model,
args,
val_loader,
*,
padder_mode,
print_freq=10,
writer=None,
step=None,
iterations=10,
cascades=1,
batch_size=None,
header=None,
save_images=False,
save_path="",
):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp.
"""
model.eval()
header = header or "Test:"
device = torch.device(args.device)
metric_logger = utils.MetricLogger(delimiter=" ")
iterations = iterations or args.recurrent_updates
logger = utils.MetricLogger()
for meter_name in args.metrics:
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
if "fl-all" not in args.metrics:
logger.add_meter("fl-all", fmt="{global_avg:.4f}")
num_processed_samples = 0
with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
batch_idx = 0
for blob in metric_logger.log_every(val_loader, print_freq, header):
image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
padder = utils.InputPadder(image_left.shape, mode=padder_mode)
image_left, image_right = padder.pad(image_left, image_right)
disp_pred = cascade_inference(model, image_left, image_right, iterations, cascades)
disp_pred = disp_pred[:, :1, :, :]
disp_pred = padder.unpad(disp_pred)
if save_images:
if args.distributed:
rank_prefix = args.rank
else:
rank_prefix = 0
make_prediction_image_side_to_side(
disp_pred, disp_gt, valid_disp_mask, save_path, prefix=f"batch_{rank_prefix}_{batch_idx}"
)
metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
num_processed_samples += image_left.shape[0]
for name in metrics:
logger.meters[name].update(metrics[name], n=1)
batch_idx += 1
num_processed_samples = utils.reduce_across_processes(num_processed_samples) / args.world_size
print("Num_processed_samples: ", num_processed_samples)
if (
hasattr(val_loader.dataset, "__len__")
and len(val_loader.dataset) != num_processed_samples
and torch.distributed.get_rank() == 0
):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size for more accurate results."
        )
if writer is not None and args.rank == 0:
for meter_name, meter_value in logger.meters.items():
scalar_name = f"{meter_name} {header}"
writer.add_scalar(scalar_name, meter_value.avg, step)
logger.synchronize_between_processes()
print(header, logger)
logger_metrics = {k: v.global_avg for k, v in logger.meters.items()}
return logger_metrics
def evaluate(model, loader, args, writer=None, step=None):
os.makedirs(args.img_folder, exist_ok=True)
checkpoint_name = os.path.basename(args.checkpoint) or args.weights
image_checkpoint_folder = os.path.join(args.img_folder, checkpoint_name)
metrics = {}
base_image_folder = os.path.join(image_checkpoint_folder, args.dataset)
os.makedirs(base_image_folder, exist_ok=True)
for n_cascades in args.n_cascades:
for n_iters in args.n_iterations:
config = f"{n_cascades}c_{n_iters}i"
config_image_folder = os.path.join(base_image_folder, config)
os.makedirs(config_image_folder, exist_ok=True)
metrics[config] = _evaluate(
model,
args,
loader,
padder_mode=args.padder_type,
header=f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{n_cascades} n_iters:{n_iters}",
batch_size=args.batch_size,
writer=writer,
step=step,
iterations=n_iters,
cascades=n_cascades,
save_path=config_image_folder,
save_images=args.save_images,
)
metric_log = []
metric_log_dict = {}
# print the final results
for config in metrics:
config_tokens = config.split("_")
config_iters = config_tokens[1][:-1]
config_cascades = config_tokens[0][:-1]
metric_log_dict[config_cascades] = metric_log_dict.get(config_cascades, {})
metric_log_dict[config_cascades][config_iters] = metrics[config]
evaluation_str = f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{config_cascades} recurrent_updates:{config_iters}"
metrics_str = f"Metrics: {metrics[config]}"
metric_log.extend([evaluation_str, metrics_str])
print(evaluation_str)
print(metrics_str)
eval_log_name = f"{checkpoint_name.replace('.pth', '')}_eval.log"
print("Saving eval log to: ", eval_log_name)
with open(eval_log_name, "w") as f:
f.write(f"Dataset: {args.dataset} @size: {args.eval_size}:\n")
# write the dict line by line for each key, and each value in the keys
for config_cascades in metric_log_dict:
f.write("{\n")
f.write(f"\t{config_cascades}: {{\n")
for config_iters in metric_log_dict[config_cascades]:
                # round every metric to 3 decimal places
metrics = metric_log_dict[config_cascades][config_iters]
metrics = {k: float(f"{v:.3f}") for k, v in metrics.items()}
f.write(f"\t\t{config_iters}: {metrics}\n")
f.write("\t},\n")
f.write("}\n")
def load_checkpoint(args):
    utils.setup_ddp(args)

    if not args.weights:
        checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"), weights_only=True)
        if "model" in checkpoint:
            experiment_args = checkpoint["args"]
            model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
            model.load_state_dict(checkpoint["model"])
        else:
            model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=None)
            model.load_state_dict(checkpoint)
    else:
        model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # set the appropriate devices
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    else:
        model.to(device)

    return model
def main(args):
model = load_checkpoint(args)
loader = make_eval_loader(args.dataset, args)
evaluate(model, loader, args)
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
import argparse
from functools import partial
import torch
from presets import StereoMatchingEvalPreset, StereoMatchingTrainPreset
from torchvision.datasets import (
CarlaStereo,
CREStereo,
ETH3DStereo,
FallingThingsStereo,
InStereo2k,
Kitti2012Stereo,
Kitti2015Stereo,
Middlebury2014Stereo,
SceneFlowStereo,
SintelStereo,
)
VALID_DATASETS = {
"crestereo": partial(CREStereo),
"carla-highres": partial(CarlaStereo),
"instereo2k": partial(InStereo2k),
"sintel": partial(SintelStereo),
"sceneflow-monkaa": partial(SceneFlowStereo, variant="Monkaa", pass_name="both"),
"sceneflow-flyingthings": partial(SceneFlowStereo, variant="FlyingThings3D", pass_name="both"),
"sceneflow-driving": partial(SceneFlowStereo, variant="Driving", pass_name="both"),
"fallingthings": partial(FallingThingsStereo, variant="both"),
"eth3d-train": partial(ETH3DStereo, split="train"),
"eth3d-test": partial(ETH3DStereo, split="test"),
"kitti2015-train": partial(Kitti2015Stereo, split="train"),
"kitti2015-test": partial(Kitti2015Stereo, split="test"),
"kitti2012-train": partial(Kitti2012Stereo, split="train"),
"kitti2012-test": partial(Kitti2012Stereo, split="train"),
"middlebury2014-other": partial(
Middlebury2014Stereo, split="additional", use_ambient_view=True, calibration="both"
),
"middlebury2014-train": partial(Middlebury2014Stereo, split="train", calibration="perfect"),
"middlebury2014-test": partial(Middlebury2014Stereo, split="test", calibration=None),
"middlebury2014-train-ambient": partial(
Middlebury2014Stereo, split="train", use_ambient_views=True, calibrartion="perfect"
),
}
def make_train_transform(args: argparse.Namespace) -> torch.nn.Module:
return StereoMatchingTrainPreset(
resize_size=args.resize_size,
crop_size=args.crop_size,
rescale_prob=args.rescale_prob,
scaling_type=args.scaling_type,
scale_range=args.scale_range,
scale_interpolation_type=args.interpolation_strategy,
use_grayscale=args.use_grayscale,
mean=args.norm_mean,
std=args.norm_std,
horizontal_flip_prob=args.flip_prob,
gpu_transforms=args.gpu_transforms,
max_disparity=args.max_disparity,
spatial_shift_prob=args.spatial_shift_prob,
spatial_shift_max_angle=args.spatial_shift_max_angle,
spatial_shift_max_displacement=args.spatial_shift_max_displacement,
spatial_shift_interpolation_type=args.interpolation_strategy,
gamma_range=args.gamma_range,
brightness=args.brightness_range,
contrast=args.contrast_range,
saturation=args.saturation_range,
hue=args.hue_range,
asymmetric_jitter_prob=args.asymmetric_jitter_prob,
)
def make_eval_transform(args: argparse.Namespace) -> torch.nn.Module:
if args.eval_size is None:
resize_size = args.crop_size
else:
resize_size = args.eval_size
return StereoMatchingEvalPreset(
mean=args.norm_mean,
std=args.norm_std,
use_grayscale=args.use_grayscale,
resize_size=resize_size,
interpolation_type=args.interpolation_strategy,
)
def make_dataset(dataset_name: str, dataset_root: str, transforms: torch.nn.Module) -> torch.utils.data.Dataset:
return VALID_DATASETS[dataset_name](root=dataset_root, transforms=transforms)
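
# Usage sketch (illustrative): instantiate any dataset from the registry above,
# e.g. `dataset = make_dataset("crestereo", args.dataset_root, make_train_transform(args))`.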
from typing import Optional, Tuple, Union
import torch
import transforms as T
class StereoMatchingEvalPreset(torch.nn.Module):
def __init__(
self,
mean: float = 0.5,
std: float = 0.5,
resize_size: Optional[Tuple[int, ...]] = None,
max_disparity: Optional[float] = None,
interpolation_type: str = "bilinear",
use_grayscale: bool = False,
) -> None:
super().__init__()
transforms = [
T.ToTensor(),
T.ConvertImageDtype(torch.float32),
]
if use_grayscale:
transforms.append(T.ConvertToGrayscale())
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))
transforms.extend(
[
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity=max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
def forward(self, images, disparities, masks):
return self.transforms(images, disparities, masks)
class StereoMatchingTrainPreset(torch.nn.Module):
def __init__(
self,
*,
resize_size: Optional[Tuple[int, ...]],
resize_interpolation_type: str = "bilinear",
# RandomResizeAndCrop params
crop_size: Tuple[int, int],
rescale_prob: float = 1.0,
scaling_type: str = "exponential",
scale_range: Tuple[float, float] = (-0.2, 0.5),
scale_interpolation_type: str = "bilinear",
# convert to grayscale
use_grayscale: bool = False,
# normalization params
mean: float = 0.5,
std: float = 0.5,
# processing device
gpu_transforms: bool = False,
# masking
max_disparity: Optional[int] = 256,
# SpatialShift params
spatial_shift_prob: float = 0.5,
spatial_shift_max_angle: float = 0.5,
spatial_shift_max_displacement: float = 0.5,
spatial_shift_interpolation_type: str = "bilinear",
        # AsymmetricColorJitter params
gamma_range: Tuple[float, float] = (0.8, 1.2),
brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
saturation: Union[int, Tuple[int, int]] = 0.0,
hue: Union[int, Tuple[int, int]] = 0.0,
asymmetric_jitter_prob: float = 1.0,
# RandomHorizontalFlip
horizontal_flip_prob: float = 0.5,
# RandomOcclusion
occlusion_prob: float = 0.0,
occlusion_px_range: Tuple[int, int] = (50, 100),
# RandomErase
erase_prob: float = 0.0,
erase_px_range: Tuple[int, int] = (50, 100),
erase_num_repeats: int = 1,
) -> None:
if scaling_type not in ["linear", "exponential"]:
raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")
super().__init__()
transforms = [T.ToTensor()]
# when fixing size across multiple datasets, we ensure
# that the same size is used for all datasets when cropping
if resize_size is not None:
transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))
if gpu_transforms:
transforms.append(T.ToGPU())
# color handling
color_transforms = [
T.AsymmetricColorJitter(
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
),
T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
]
if use_grayscale:
color_transforms.append(T.ConvertToGrayscale())
transforms.extend(color_transforms)
transforms.extend(
[
T.RandomSpatialShift(
p=spatial_shift_prob,
max_angle=spatial_shift_max_angle,
max_px_shift=spatial_shift_max_displacement,
interpolation_type=spatial_shift_interpolation_type,
),
T.ConvertImageDtype(torch.float32),
T.RandomRescaleAndCrop(
crop_size=crop_size,
scale_range=scale_range,
rescale_prob=rescale_prob,
scaling_type=scaling_type,
interpolation_type=scale_interpolation_type,
),
T.RandomHorizontalFlip(horizontal_flip_prob),
# occlusion after flip, otherwise we're occluding the reference image
T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
T.Normalize(mean=mean, std=std),
T.MakeValidDisparityMask(max_disparity),
T.ValidateModelInput(),
]
)
self.transforms = T.Compose(transforms)
    def forward(self, images, disparities, mask):
        return self.transforms(images, disparities, mask)
import argparse
import os
import warnings
from pathlib import Path
from typing import List, Union
import numpy as np
import torch
import torch.distributed as dist
import torchvision.models.optical_flow
import torchvision.prototype.models.depth.stereo
import utils
import visualization
from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
from torch import nn
from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
from utils.metrics import AVAILABLE_METRICS
from utils.norm import freeze_batch_norm
def make_stereo_flow(flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Helper function to make stereo flow from a given model output"""
if isinstance(flow, list):
return [make_stereo_flow(flow_i, model_out_channels) for flow_i in flow]
B, C, H, W = flow.shape
# we need to add zero flow if the model outputs 2 channels
if C == 1 and model_out_channels == 2:
zero_flow = torch.zeros_like(flow)
# by convention the flow is X-Y axis, so we need the Y flow last
flow = torch.cat([flow, zero_flow], dim=1)
return flow
def make_lr_schedule(args: argparse.Namespace, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.SequentialLR:
    """Helper function to return a learning rate scheduler for CRE-stereo"""
    if args.decay_after_steps < args.warmup_steps:
        raise ValueError(f"decay_after_steps: {args.decay_after_steps} must be greater than warmup_steps: {args.warmup_steps}")
warmup_steps = args.warmup_steps if args.warmup_steps else 0
flat_lr_steps = args.decay_after_steps - warmup_steps if args.decay_after_steps else 0
decay_lr_steps = args.total_iterations - flat_lr_steps
max_lr = args.lr
min_lr = args.min_lr
schedulers = []
milestones = []
if warmup_steps > 0:
if args.lr_warmup_method == "linear":
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
)
elif args.lr_warmup_method == "constant":
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
)
else:
raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
schedulers.append(warmup_lr_scheduler)
milestones.append(warmup_steps)
if flat_lr_steps > 0:
flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=max_lr, total_iters=flat_lr_steps)
schedulers.append(flat_lr_scheduler)
milestones.append(flat_lr_steps + warmup_steps)
if decay_lr_steps > 0:
if args.lr_decay_method == "cosine":
decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=decay_lr_steps, eta_min=min_lr
)
elif args.lr_decay_method == "linear":
decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=max_lr, end_factor=min_lr, total_iters=decay_lr_steps
)
elif args.lr_decay_method == "exponential":
decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
)
else:
raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
schedulers.append(decay_lr_scheduler)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)
return scheduler
def shuffle_dataset(dataset):
"""Shuffle the dataset"""
perm = torch.randperm(len(dataset))
return torch.utils.data.Subset(dataset, perm)
def resize_dataset_to_n_steps(
dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
) -> torch.utils.data.Dataset:
original_size = len(dataset)
if args.steps_is_epochs:
samples_per_step = original_size
target_size = dataset_steps * samples_per_step
dataset_copies = []
n_expands, remainder = divmod(target_size, original_size)
for idx in range(n_expands):
dataset_copies.append(dataset)
if remainder > 0:
dataset_copies.append(torch.utils.data.Subset(dataset, list(range(remainder))))
if args.dataset_shuffle:
dataset_copies = [shuffle_dataset(dataset_copy) for dataset_copy in dataset_copies]
dataset = torch.utils.data.ConcatDataset(dataset_copies)
return dataset
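
# Example (illustrative): with dataset_steps=12000, batch_size=2 and world_size=8,
# samples_per_step = 8 * 2 = 16, so the dataset is tiled (and possibly truncated)
# to a target of 12000 * 16 = 192000 samples.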
def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
datasets = []
for dataset_name in args.train_datasets:
transform = make_train_transform(args)
dataset = make_dataset(dataset_name, dataset_root, transform)
datasets.append(dataset)
if len(datasets) == 0:
raise ValueError("No datasets specified for training")
samples_per_step = args.world_size * args.batch_size
for idx, (dataset, steps_per_dataset) in enumerate(zip(datasets, args.dataset_steps)):
datasets[idx] = resize_dataset_to_n_steps(dataset, steps_per_dataset, samples_per_step, args)
dataset = torch.utils.data.ConcatDataset(datasets)
if args.dataset_order_shuffle:
dataset = shuffle_dataset(dataset)
print(f"Training dataset: {len(dataset)} samples")
return dataset
@torch.inference_mode()
def _evaluate(
model,
args,
val_loader,
*,
padder_mode,
print_freq=10,
writer=None,
step=None,
iterations=None,
batch_size=None,
header=None,
):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset."""
model.eval()
header = header or "Test:"
device = torch.device(args.device)
metric_logger = utils.MetricLogger(delimiter=" ")
iterations = iterations or args.recurrent_updates
logger = utils.MetricLogger()
for meter_name in args.metrics:
logger.add_meter(meter_name, fmt="{global_avg:.4f}")
if "fl-all" not in args.metrics:
logger.add_meter("fl-all", fmt="{global_avg:.4f}")
num_processed_samples = 0
with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
for blob in metric_logger.log_every(val_loader, print_freq, header):
image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
padder = utils.InputPadder(image_left.shape, mode=padder_mode)
image_left, image_right = padder.pad(image_left, image_right)
disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
disp_pred = disp_predictions[-1][:, :1, :, :]
disp_pred = padder.unpad(disp_pred)
metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
num_processed_samples += image_left.shape[0]
for name in metrics:
logger.meters[name].update(metrics[name], n=1)
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print("Num_processed_samples: ", num_processed_samples)
if (
hasattr(val_loader.dataset, "__len__")
and len(val_loader.dataset) != num_processed_samples
and torch.distributed.get_rank() == 0
):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size or GPU number for more accurate results."
        )
if writer is not None and args.rank == 0:
for meter_name, meter_value in logger.meters.items():
scalar_name = f"{meter_name} {header}"
writer.add_scalar(scalar_name, meter_value.avg, step)
logger.synchronize_between_processes()
print(header, logger)
def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
if args.weights:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
def preprocessing(image_left, image_right, disp, valid_disp_mask):
C_o, H_o, W_o = get_dimensions(image_left)
image_left, image_right = trans(image_left, image_right)
C_t, H_t, W_t = get_dimensions(image_left)
scale_factor = W_t / W_o
if disp is not None and not isinstance(disp, torch.Tensor):
disp = torch.from_numpy(disp)
if W_t != W_o:
                disp = resize(disp, (H_t, W_t), interpolation=InterpolationMode.BILINEAR) * scale_factor
if valid_disp_mask is not None and not isinstance(valid_disp_mask, torch.Tensor):
valid_disp_mask = torch.from_numpy(valid_disp_mask)
if W_t != W_o:
                valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), interpolation=InterpolationMode.NEAREST)
return image_left, image_right, disp, valid_disp_mask
else:
preprocessing = make_eval_transform(args)
val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
else:
sampler = torch.utils.data.SequentialSampler(val_dataset)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.workers,
)
return val_loader
def evaluate(model, loaders, args, writer=None, step=None):
for loader_name, loader in loaders.items():
_evaluate(
model,
args,
loader,
iterations=args.recurrent_updates,
padder_mode=args.padder_type,
header=f"{loader_name} evaluation",
batch_size=args.batch_size,
writer=writer,
step=step,
)
def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
device = torch.device(args.device)
# wrap the loader in a logger
loader = iter(logger.log_every(train_loader))
# output channels
model_out_channels = model.module.output_channels if args.distributed else model.output_channels
torch.set_num_threads(args.threads)
sequence_criterion = utils.SequenceLoss(
gamma=args.gamma,
max_flow=args.max_disparity,
exclude_large_flows=args.flow_loss_exclude_large,
).to(device)
if args.consistency_weight:
consistency_criterion = utils.FlowSequenceConsistencyLoss(
args.gamma,
resize_factor=0.25,
rescale_factor=0.25,
rescale_mode="bilinear",
).to(device)
else:
consistency_criterion = None
if args.psnr_weight:
psnr_criterion = utils.PSNRLoss().to(device)
else:
psnr_criterion = None
if args.smoothness_weight:
smoothness_criterion = utils.SmoothnessLoss().to(device)
else:
smoothness_criterion = None
if args.photometric_weight:
photometric_criterion = utils.FlowPhotoMetricLoss(
ssim_weight=args.photometric_ssim_weight,
max_displacement_ratio=args.photometric_max_displacement_ratio,
ssim_use_padding=False,
).to(device)
else:
photometric_criterion = None
for step in range(args.start_step + 1, args.total_iterations + 1):
data_blob = next(loader)
optimizer.zero_grad()
# unpack the data blob
image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
# different models have different outputs, make sure we get the right ones for this task
disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
# should the architecture or training loop require it, we have to adjust the disparity mask
# target to possibly look like an optical flow mask
disp_mask = make_stereo_flow(disp_mask, model_out_channels)
# sequence loss on top of the model outputs
loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight
if args.consistency_weight > 0:
loss_consistency = consistency_criterion(disp_predictions)
loss += loss_consistency * args.consistency_weight
if args.psnr_weight > 0:
loss_psnr = 0.0
for pred in disp_predictions:
# predictions might have 2 channels
loss_psnr += psnr_criterion(
pred * valid_disp_mask.unsqueeze(1),
disp_mask * valid_disp_mask.unsqueeze(1),
).mean() # mean the psnr loss over the batch
loss += loss_psnr / len(disp_predictions) * args.psnr_weight
if args.photometric_weight > 0:
loss_photometric = 0.0
for pred in disp_predictions:
                    # predictions might have 1 channel, therefore we need to pad a zero second channel
if model_out_channels == 1:
pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)
loss_photometric += photometric_criterion(
image_left, image_right, pred, valid_disp_mask
) # photometric loss already comes out meaned over the batch
loss += loss_photometric / len(disp_predictions) * args.photometric_weight
if args.smoothness_weight > 0:
loss_smoothness = 0.0
for pred in disp_predictions:
# predictions might have 2 channels
loss_smoothness += smoothness_criterion(
image_left, pred[:, :1, :, :]
).mean() # mean the smoothness loss over the batch
loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight
with torch.no_grad():
metrics, _ = utils.compute_metrics(
disp_predictions[-1][:, :1, :, :], # predictions might have 2 channels
disp_mask[:, :1, :, :], # so does the ground truth
valid_disp_mask,
args.metrics,
)
metrics.pop("fl-all", None)
logger.update(loss=loss, **metrics)
if scaler is not None:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
if args.clip_grad_norm:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if args.clip_grad_norm:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
optimizer.step()
scheduler.step()
if not dist.is_initialized() or dist.get_rank() == 0:
if writer is not None and step % args.tensorboard_log_frequency == 0:
# log the loss and metrics to tensorboard
writer.add_scalar("loss", loss, step)
for name, value in logger.meters.items():
writer.add_scalar(name, value.avg, step)
# log the images to tensorboard
pred_grid = visualization.make_training_sample_grid(
image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
)
writer.add_image("predictions", pred_grid, step, dataformats="HWC")
# second thing we want to see is how relevant the iterative refinement is
pred_sequence_grid = visualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")
if step % args.save_frequency == 0:
if not args.distributed or args.rank == 0:
model_without_ddp = (
model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
)
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"args": args,
}
os.makedirs(args.checkpoint_dir, exist_ok=True)
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")
if step % args.valid_frequency == 0:
evaluate(model, val_loaders, args, writer, step)
model.train()
if args.freeze_batch_norm:
if isinstance(model, nn.parallel.DistributedDataParallel):
freeze_batch_norm(model.module)
else:
freeze_batch_norm(model)
# one final save at the end
if not args.distributed or args.rank == 0:
model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"args": args,
}
os.makedirs(args.checkpoint_dir, exist_ok=True)
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")
def main(args):
args.total_iterations = sum(args.dataset_steps)
# initialize DDP setting
utils.setup_ddp(args)
print(args)
args.test_only = args.train_datasets is None
# set the appropriate devices
if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)
# select model architecture
model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)
# convert to DDP if need be
if args.distributed:
model = model.to(args.gpu)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
else:
model.to(device)
model_without_ddp = model
os.makedirs(args.checkpoint_dir, exist_ok=True)
val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}
# EVAL ONLY configurations
if args.test_only:
evaluate(model, val_loaders, args)
return
# Sanity check for the parameter count
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# Compose the training dataset
train_dataset = get_train_dataset(args.dataset_root, args)
# initialize the optimizer
if args.optimizer == "adam":
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
else:
raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")
# initialize the learning rate schedule
scheduler = make_lr_schedule(args, optimizer)
# load them from checkpoint if needed
args.start_step = 0
if args.resume_path is not None:
checkpoint = torch.load(args.resume_path, map_location="cpu", weights_only=True)
if "model" in checkpoint:
# this means the user requested to resume from a training checkpoint
model_without_ddp.load_state_dict(checkpoint["model"])
# this means the user wants to continue training from where it was left off
if args.resume_schedule:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
args.start_step = checkpoint["step"] + 1
                # modify the starting point of the data consumption
sample_start_step = args.start_step * args.batch_size * args.world_size
train_dataset = train_dataset[sample_start_step:]
else:
# this means the user wants to finetune on top of a model state dict
# and that no other changes are required
model_without_ddp.load_state_dict(checkpoint)
torch.backends.cudnn.benchmark = True
# enable training mode
model.train()
if args.freeze_batch_norm:
freeze_batch_norm(model_without_ddp)
# put dataloader on top of the dataset
# make sure to disable shuffling since the dataset is already shuffled
# in order to guarantee quasi randomness whilst retaining a deterministic
# dataset consumption order
if args.distributed:
# the train dataset is preshuffled in order to respect the iteration order
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
else:
# the train dataset is already shuffled, so we can use a simple SequentialSampler
sampler = torch.utils.data.SequentialSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.workers,
)
# initialize the logger
if args.tensorboard_summaries:
from torch.utils.tensorboard import SummaryWriter
tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
os.makedirs(tensorboard_path, exist_ok=True)
tensorboard_run = tensorboard_path / f"{args.name}"
writer = SummaryWriter(tensorboard_run)
else:
writer = None
logger = utils.MetricLogger(delimiter=" ")
scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
# run the training loop
# this will perform optimization, respectively logging and saving checkpoints
# when need be
run(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
val_loaders=val_loaders,
logger=logger,
writer=writer,
scaler=scaler,
args=args,
)
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)
# checkpointing
parser.add_argument("--name", default="crestereo", help="name of the experiment")
parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")
# dataset
parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
parser.add_argument(
"--train-datasets",
type=str,
nargs="+",
default=["crestereo"],
help="dataset(s) to train on",
choices=list(VALID_DATASETS.keys()),
)
parser.add_argument(
"--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
)
parser.add_argument(
"--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
)
parser.add_argument(
"--test-datasets",
type=str,
nargs="+",
default=["middlebury2014-train"],
help="dataset(s) to test on",
choices=["middlebury2014-train"],
)
parser.add_argument("--dataset-shuffle", type=bool, help="shuffle the dataset", default=True)
parser.add_argument("--dataset-order-shuffle", type=bool, help="shuffle the dataset order", default=True)
parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
parser.add_argument(
"--threads",
type=int,
default=16,
help="number of CPU threads per GPU. This can be changed around to speed-up transforms if needed. This can lead to worker thread contention so use with care.",
)
# model architecture
parser.add_argument(
"--model",
type=str,
default="crestereo_base",
help="model architecture",
choices=["crestereo_base", "raft_stereo"],
)
parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")
# loss parameters
parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
parser.add_argument(
"--flow-loss-exclude-large",
action="store_true",
help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
default=False,
)
parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
parser.add_argument(
"--consistency-resize-factor",
type=float,
default=0.25,
help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
)
parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
parser.add_argument(
"--photometric-max-displacement-ratio",
type=float,
default=0.15,
help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
)
parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")
# transforms parameters
parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
parser.add_argument(
"--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
)
parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
parser.add_argument(
"--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
)
parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
parser.add_argument(
"--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
)
parser.add_argument(
"--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
)
parser.add_argument(
"--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
)
parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
parser.add_argument(
"--interpolation-strategy",
type=str,
default="bilinear",
help="interpolation strategy",
choices=["bilinear", "bicubic", "mixed"],
)
parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
parser.add_argument(
"--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
)
parser.add_argument(
"--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
)
parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
parser.add_argument(
"--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
)
parser.add_argument(
"--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
)
parser.add_argument(
"--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
)
parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
parser.add_argument(
"--asymmetric-jitter-prob",
type=float,
default=1.0,
help="probability of using asymmetric jitter instead of symmetric jitter",
)
parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the rightimage")
parser.add_argument(
"--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
)
parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
parser.add_argument(
"--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
)
parser.add_argument(
"--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
)
# optimizer parameters
parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")
# lr_scheduler parameters
parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
parser.add_argument(
"--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decay the lr"
)
parser.add_argument(
"--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "cosine"]
)
parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
parser.add_argument(
"--lr-decay-method",
type=str,
default="linear",
help="decay method",
choices=["linear", "cosine", "exponential"],
)
parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")
# deterministic behaviour
parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")
# mixed precision training
parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
# logging
parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
parser.add_argument(
"--metrics",
type=str,
nargs="+",
default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
help="metrics to log",
choices=AVAILABLE_METRICS,
)
# distributed parameters
parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
parser.add_argument("--device", type=str, default="cuda", help="device to use for training")
# weights API
parser.add_argument("--weights", type=str, default=None, help="weights API url")
parser.add_argument(
"--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
)
parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")
# padder parameters
parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
return parser
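# Example invocation (illustrative only; the script name, dataset path and flag
# values below are placeholders, not tested commands):
#   torchrun --nproc_per_node 8 train.py \
#       --dataset-root /path/to/datasets \
#       --train-datasets crestereo --dataset-steps 300000 \
#       --batch-size 2 --mixed-precision --tensorboard-summaries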
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
import random
from typing import Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import Tensor
T_FLOW = Union[Tensor, np.ndarray, None]
T_MASK = Union[Tensor, np.ndarray, None]
T_STEREO_TENSOR = Tuple[Tensor, Tensor]
T_COLOR_AUG_PARAM = Union[float, Tuple[float, float]]
def rand_float_range(size: Sequence[int], low: float, high: float) -> Tensor:
    # draw a uniform sample in [low, high]
    return (high - low) * torch.rand(size) + low
class InterpolationStrategy:
_valid_modes: List[str] = ["mixed", "bicubic", "bilinear"]
def __init__(self, mode: str = "mixed") -> None:
if mode not in self._valid_modes:
raise ValueError(f"Invalid interpolation mode: {mode}. Valid modes are: {self._valid_modes}")
if mode == "mixed":
self.strategies = [F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC]
elif mode == "bicubic":
self.strategies = [F.InterpolationMode.BICUBIC]
elif mode == "bilinear":
self.strategies = [F.InterpolationMode.BILINEAR]
def __call__(self) -> F.InterpolationMode:
return random.choice(self.strategies)
    @classmethod
    def is_valid(cls, mode: str) -> bool:
        return mode in cls._valid_modes
    @property
    def valid_modes(self) -> List[str]:
        return self._valid_modes
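# Example (illustrative sketch, not part of the pipeline): with the default
# "mixed" mode, each call draws one of the two interpolation modes at random:
#   strategy = InterpolationStrategy("mixed")
#   mode = strategy()  # F.InterpolationMode.BILINEAR or F.InterpolationMode.BICUBIC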
class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, images: T_STEREO_TENSOR, disparities: T_FLOW, masks: T_MASK):
if images[0].shape != images[1].shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = images[0].shape[-2:]
if disparities[0] is not None and disparities[0].shape != (1, h, w):
raise ValueError(f"disparities[0].shape should be (1, {h}, {w}) instead of {disparities[0].shape}")
if masks[0] is not None:
if masks[0].shape != (h, w):
raise ValueError(f"masks[0].shape should be ({h}, {w}) instead of {masks[0].shape}")
if masks[0].dtype != torch.bool:
raise TypeError(f"masks[0] should be of dtype torch.bool instead of {masks[0].dtype}")
return images, disparities, masks
class ConvertToGrayscale(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
images: Tuple[PIL.Image.Image, PIL.Image.Image],
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.rgb_to_grayscale(images[0], num_output_channels=3)
img_right = F.rgb_to_grayscale(images[1], num_output_channels=3)
return (img_left, img_right), disparities, masks
class MakeValidDisparityMask(torch.nn.Module):
def __init__(self, max_disparity: Optional[int] = 256) -> None:
super().__init__()
self.max_disparity = max_disparity
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
valid_masks = tuple(
torch.ones(images[idx].shape[-2:], dtype=torch.bool, device=images[idx].device) if mask is None else mask
for idx, mask in enumerate(masks)
)
valid_masks = tuple(
torch.logical_and(mask, disparity > 0).squeeze(0) if disparity is not None else mask
for mask, disparity in zip(valid_masks, disparities)
)
if self.max_disparity is not None:
valid_masks = tuple(
torch.logical_and(mask, disparity < self.max_disparity).squeeze(0) if disparity is not None else mask
for mask, disparity in zip(valid_masks, disparities)
)
return images, disparities, valid_masks
class ToGPU(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
dev_images = tuple(image.cuda() for image in images)
dev_disparities = tuple(map(lambda x: x.cuda() if x is not None else None, disparities))
dev_masks = tuple(map(lambda x: x.cuda() if x is not None else None, masks))
return dev_images, dev_disparities, dev_masks
class ConvertImageDtype(torch.nn.Module):
def __init__(self, dtype: torch.dtype):
super().__init__()
self.dtype = dtype
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.convert_image_dtype(images[0], dtype=self.dtype)
img_right = F.convert_image_dtype(images[1], dtype=self.dtype)
img_left = img_left.contiguous()
img_right = img_right.contiguous()
return (img_left, img_right), disparities, masks
class Normalize(torch.nn.Module):
def __init__(self, mean: List[float], std: List[float]) -> None:
super().__init__()
self.mean = mean
self.std = std
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left = F.normalize(images[0], mean=self.mean, std=self.std)
img_right = F.normalize(images[1], mean=self.mean, std=self.std)
img_left = img_left.contiguous()
img_right = img_right.contiguous()
return (img_left, img_right), disparities, masks
class ToTensor(torch.nn.Module):
def forward(
self,
images: Tuple[PIL.Image.Image, PIL.Image.Image],
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
if images[0] is None:
raise ValueError("img_left is None")
if images[1] is None:
raise ValueError("img_right is None")
img_left = F.pil_to_tensor(images[0])
img_right = F.pil_to_tensor(images[1])
disparity_tensors = ()
mask_tensors = ()
for idx in range(2):
disparity_tensors += (torch.from_numpy(disparities[idx]),) if disparities[idx] is not None else (None,)
mask_tensors += (torch.from_numpy(masks[idx]),) if masks[idx] is not None else (None,)
return (img_left, img_right), disparity_tensors, mask_tensors
class AsymmetricColorJitter(T.ColorJitter):
# p determines the probability of doing asymmetric vs symmetric color jittering
def __init__(
self,
brightness: T_COLOR_AUG_PARAM = 0,
contrast: T_COLOR_AUG_PARAM = 0,
saturation: T_COLOR_AUG_PARAM = 0,
hue: T_COLOR_AUG_PARAM = 0,
p: float = 0.2,
):
super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
if torch.rand(1) < self.p:
# asymmetric: different transform for img1 and img2
img_left = super().forward(images[0])
img_right = super().forward(images[1])
else:
# symmetric: same transform for img1 and img2
batch = torch.stack(images)
batch = super().forward(batch)
img_left, img_right = batch[0], batch[1]
return (img_left, img_right), disparities, masks
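# Example (illustrative sketch; shapes and values are placeholders): with p=1.0
# the jitter is always applied independently to each view:
#   jitter = AsymmetricColorJitter(brightness=(0.8, 1.2), p=1.0)
#   imgs = (torch.rand(3, 64, 64), torch.rand(3, 64, 64))
#   (left, right), _, _ = jitter(imgs, (None, None), (None, None))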
class AsymetricGammaAdjust(torch.nn.Module):
def __init__(self, p: float, gamma_range: Tuple[float, float], gain: float = 1) -> None:
super().__init__()
self.gamma_range = gamma_range
self.gain = gain
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        gamma = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()
        if torch.rand(1) < self.p:
            # asymmetric: sample a separate gamma for each image so that the two
            # views really do get different transforms
            gamma_right = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()
            img_left = F.adjust_gamma(images[0], gamma, gain=self.gain)
            img_right = F.adjust_gamma(images[1], gamma_right, gain=self.gain)
else:
# symmetric: same transform for img1 and img2
batch = torch.stack(images)
batch = F.adjust_gamma(batch, gamma, gain=self.gain)
img_left, img_right = batch[0], batch[1]
return (img_left, img_right), disparities, masks
class RandomErase(torch.nn.Module):
# Produces multiple symmetric random erasures
# these can be viewed as occlusions present in both camera views.
# Similarly to Optical Flow occlusion prediction tasks, we mask these pixels in the disparity map
def __init__(
self,
p: float = 0.5,
erase_px_range: Tuple[int, int] = (50, 100),
value: Union[Tensor, float] = 0,
inplace: bool = False,
max_erase: int = 2,
):
super().__init__()
self.min_px_erase = erase_px_range[0]
self.max_px_erase = erase_px_range[1]
        if self.max_px_erase < 0:
            raise ValueError("erase_px_range[1] should be greater than or equal to 0")
        if self.min_px_erase < 0:
            raise ValueError("erase_px_range[0] should be greater than or equal to 0")
        if self.min_px_erase > self.max_px_erase:
            raise ValueError("erase_px_range[0] should be less than or equal to erase_px_range[1]")
self.p = p
self.value = value
self.inplace = inplace
self.max_erase = max_erase
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if torch.rand(1) >= self.p:
            # skip the transform with probability 1 - p
            return images, disparities, masks
image_left, image_right = images
mask_left, mask_right = masks
for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
y, x, h, w, v = self._get_params(image_left)
image_right = F.erase(image_right, y, x, h, w, v, self.inplace)
image_left = F.erase(image_left, y, x, h, w, v, self.inplace)
# similarly to optical flow occlusion prediction, we consider
# any erasure pixels that are in both images to be occluded therefore
# we mark them as invalid
if mask_left is not None:
mask_left = F.erase(mask_left, y, x, h, w, False, self.inplace)
if mask_right is not None:
mask_right = F.erase(mask_right, y, x, h, w, False, self.inplace)
return (image_left, image_right), disparities, (mask_left, mask_right)
def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
img_h, img_w = img.shape[-2:]
crop_h, crop_w = (
random.randint(self.min_px_erase, self.max_px_erase),
random.randint(self.min_px_erase, self.max_px_erase),
)
crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
return crop_y, crop_x, crop_h, crop_w, self.value
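# Example (illustrative sketch; shapes are placeholders): the same patches are
# erased in both views and marked invalid in the masks:
#   erase = RandomErase(p=1.0, erase_px_range=(5, 10), max_erase=2)
#   imgs = (torch.rand(3, 64, 64), torch.rand(3, 64, 64))
#   masks = (torch.ones(64, 64, dtype=torch.bool), torch.ones(64, 64, dtype=torch.bool))
#   (left, right), _, (m_left, m_right) = erase(imgs, (None, None), masks)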
class RandomOcclusion(torch.nn.Module):
# This adds an occlusion in the right image
# the occluded patch works as a patch erase where the erase value is the mean
# of the pixels from the selected zone
def __init__(self, p: float = 0.5, occlusion_px_range: Tuple[int, int] = (50, 100), inplace: bool = False):
super().__init__()
self.min_px_occlusion = occlusion_px_range[0]
self.max_px_occlusion = occlusion_px_range[1]
        if self.max_px_occlusion < 0:
            raise ValueError("occlusion_px_range[1] should be greater than or equal to 0")
        if self.min_px_occlusion < 0:
            raise ValueError("occlusion_px_range[0] should be greater than or equal to 0")
        if self.min_px_occlusion > self.max_px_occlusion:
            raise ValueError("occlusion_px_range[0] should be less than or equal to occlusion_px_range[1]")
self.p = p
self.inplace = inplace
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
left_image, right_image = images
        if torch.rand(1) >= self.p:
            # skip the transform with probability 1 - p
            return images, disparities, masks
y, x, h, w, v = self._get_params(right_image)
right_image = F.erase(right_image, y, x, h, w, v, self.inplace)
return ((left_image, right_image), disparities, masks)
def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
img_h, img_w = img.shape[-2:]
crop_h, crop_w = (
random.randint(self.min_px_occlusion, self.max_px_occlusion),
random.randint(self.min_px_occlusion, self.max_px_occlusion),
)
crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
occlusion_value = img[..., crop_y : crop_y + crop_h, crop_x : crop_x + crop_w].mean(dim=(-2, -1), keepdim=True)
return (crop_y, crop_x, crop_h, crop_w, occlusion_value)
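# Example (illustrative sketch; shapes are placeholders): with p=1.0 a patch of
# the right view is replaced by its own mean value:
#   occlusion = RandomOcclusion(p=1.0, occlusion_px_range=(5, 10))
#   imgs = (torch.rand(3, 64, 64), torch.rand(3, 64, 64))
#   (left, right), _, _ = occlusion(imgs, (None, None), (None, None))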
class RandomSpatialShift(torch.nn.Module):
    # This transform applies a vertical shift and a slight angle rotation at the same time
def __init__(
self, p: float = 0.5, max_angle: float = 0.1, max_px_shift: int = 2, interpolation_type: str = "bilinear"
) -> None:
super().__init__()
self.p = p
self.max_angle = max_angle
self.max_px_shift = max_px_shift
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
def forward(
self,
images: T_STEREO_TENSOR,
disparities: T_STEREO_TENSOR,
masks: T_STEREO_TENSOR,
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
# the transform is applied only on the right image
# in order to mimic slight calibration issues
img_left, img_right = images
INTERP_MODE = self._interpolation_mode_strategy()
if torch.rand(1) < self.p:
# [0, 1] -> [-a, a]
shift = rand_float_range((1,), low=-self.max_px_shift, high=self.max_px_shift).item()
angle = rand_float_range((1,), low=-self.max_angle, high=self.max_angle).item()
# sample center point for the rotation matrix
y = torch.randint(size=(1,), low=0, high=img_right.shape[-2]).item()
x = torch.randint(size=(1,), low=0, high=img_right.shape[-1]).item()
# apply affine transformations
img_right = F.affine(
img_right,
angle=angle,
translate=[0, shift], # translation only on the y-axis
center=[x, y],
scale=1.0,
shear=0.0,
interpolation=INTERP_MODE,
)
return ((img_left, img_right), disparities, masks)
class RandomHorizontalFlip(torch.nn.Module):
def __init__(self, p: float = 0.5) -> None:
super().__init__()
self.p = p
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left, img_right = images
dsp_left, dsp_right = disparities
mask_left, mask_right = masks
if dsp_right is not None and torch.rand(1) < self.p:
img_left, img_right = F.hflip(img_left), F.hflip(img_right)
dsp_left, dsp_right = F.hflip(dsp_left), F.hflip(dsp_right)
if mask_left is not None and mask_right is not None:
mask_left, mask_right = F.hflip(mask_left), F.hflip(mask_right)
return ((img_right, img_left), (dsp_right, dsp_left), (mask_right, mask_left))
return images, disparities, masks
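# Note on the flip above: mirroring a rectified stereo pair swaps the roles of
# the two cameras, so the flipped right image becomes the new left image (and
# vice versa); returning everything in swapped order keeps disparities positive.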
class Resize(torch.nn.Module):
def __init__(self, resize_size: Tuple[int, ...], interpolation_type: str = "bilinear") -> None:
super().__init__()
self.resize_size = list(resize_size) # doing this to keep mypy happy
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
resized_images = ()
resized_disparities = ()
resized_masks = ()
INTERP_MODE = self._interpolation_mode_strategy()
for img in images:
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the stereo models with antialias=True?
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),)
for dsp in disparities:
if dsp is not None:
# rescale disparity to match the new image size
scale_x = self.resize_size[1] / dsp.shape[-1]
resized_disparities += (F.resize(dsp, self.resize_size, interpolation=INTERP_MODE) * scale_x,)
else:
resized_disparities += (None,)
for mask in masks:
if mask is not None:
resized_masks += (
# we squeeze and unsqueeze because the API requires > 3D tensors
F.resize(
mask.unsqueeze(0),
self.resize_size,
interpolation=F.InterpolationMode.NEAREST,
).squeeze(0),
)
else:
resized_masks += (None,)
return resized_images, resized_disparities, resized_masks
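# Example (illustrative): disparities are horizontal pixel offsets, hence the
# multiplication by scale_x above; resizing a 480x640 pair to 240x320 halves
# every disparity value.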
class RandomRescaleAndCrop(torch.nn.Module):
# This transform will resize the input with a given proba, and then crop it.
# These are the reversed operations of the built-in RandomResizedCrop,
# although the order of the operations doesn't matter too much: resizing a
# crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
#
# The reason we don't rely on RandomResizedCrop is because of a significant
# difference in the parametrization of both transforms, in particular,
# because of the way the random parameters are sampled in both transforms,
# which leads to fairly different results (and different epe). For more details see
# https://github.com/pytorch/vision/pull/5026/files#r762932579
def __init__(
self,
crop_size: Tuple[int, int],
scale_range: Tuple[float, float] = (-0.2, 0.5),
rescale_prob: float = 0.8,
scaling_type: str = "exponential",
interpolation_type: str = "bilinear",
) -> None:
super().__init__()
self.crop_size = crop_size
self.min_scale = scale_range[0]
self.max_scale = scale_range[1]
self.rescale_prob = rescale_prob
self.scaling_type = scaling_type
self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)
if self.scaling_type == "linear" and self.min_scale < 0:
raise ValueError("min_scale must be >= 0 for linear scaling")
def forward(
self,
images: T_STEREO_TENSOR,
disparities: Tuple[T_FLOW, T_FLOW],
masks: Tuple[T_MASK, T_MASK],
) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
img_left, img_right = images
dsp_left, dsp_right = disparities
mask_left, mask_right = masks
INTERP_MODE = self._interpolation_mode_strategy()
# randomly sample scale
h, w = img_left.shape[-2:]
# Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
# It shouldn't matter much
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
# exponential scaling will draw a random scale in (min_scale, max_scale) and then raise
# 2 to the power of that random value. This final scale distribution will have a different
# mean and variance than a uniform distribution. Note that a scale of 1 will result in
# a rescaling of 2X the original size, whereas a scale of -1 will result in a rescaling
# of 0.5X the original size.
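        # For example, with (min_scale, max_scale) = (-0.2, 0.5) the exponential
        # branch yields scales in (2**-0.2, 2**0.5) ~= (0.87, 1.41).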
if self.scaling_type == "exponential":
scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
# linear scaling will draw a random scale in (min_scale, max_scale)
elif self.scaling_type == "linear":
scale = torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
scale = max(scale, min_scale)
new_h, new_w = round(h * scale), round(w * scale)
if torch.rand(1).item() < self.rescale_prob:
# rescale the images
img_left = F.resize(img_left, size=(new_h, new_w), interpolation=INTERP_MODE)
img_right = F.resize(img_right, size=(new_h, new_w), interpolation=INTERP_MODE)
resized_masks, resized_disparities = (), ()
            for disparity, mask in zip(disparities, masks):
                if disparity is not None:
                    if mask is None:
                        resized_disparity = F.resize(disparity, size=(new_h, new_w), interpolation=INTERP_MODE)
                        # rescale the disparity to match the new image size
                        resized_disparity = (
                            resized_disparity * torch.tensor([scale], device=resized_disparity.device)[:, None, None]
                        )
                        resized_mask = None
                    else:
                        resized_disparity, resized_mask = _resize_sparse_flow(
                            disparity, mask, scale_x=scale, scale_y=scale
                        )
                else:
                    # keep missing disparities (and their masks) as-is instead of
                    # reusing stale values from a previous loop iteration
                    resized_disparity, resized_mask = None, mask
                resized_masks += (resized_mask,)
                resized_disparities += (resized_disparity,)
else:
resized_disparities = disparities
resized_masks = masks
disparities = resized_disparities
masks = resized_masks
# Note: For sparse datasets (Kitti), the original code uses a "margin"
# See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
# We don't, not sure if it matters much
y0 = torch.randint(0, img_left.shape[1] - self.crop_size[0], size=(1,)).item()
x0 = torch.randint(0, img_right.shape[2] - self.crop_size[1], size=(1,)).item()
img_left = F.crop(img_left, y0, x0, self.crop_size[0], self.crop_size[1])
img_right = F.crop(img_right, y0, x0, self.crop_size[0], self.crop_size[1])
if dsp_left is not None:
dsp_left = F.crop(disparities[0], y0, x0, self.crop_size[0], self.crop_size[1])
if dsp_right is not None:
dsp_right = F.crop(disparities[1], y0, x0, self.crop_size[0], self.crop_size[1])
cropped_masks = ()
for mask in masks:
if mask is not None:
mask = F.crop(mask, y0, x0, self.crop_size[0], self.crop_size[1])
cropped_masks += (mask,)
return ((img_left, img_right), (dsp_left, dsp_right), cropped_masks)
def _resize_sparse_flow(
    flow: Tensor, valid_flow_mask: Tensor, scale_x: float = 1.0, scale_y: float = 1.0
) -> Tuple[Tensor, Tensor]:
    # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse):
    # there are as many non-zero values in the resized flow as in the original one (up to out-of-bounds drops),
    # so for example if scale_x = scale_y = 2, the density of the output flow is divided by 4.
h, w = flow.shape[-2:]
h_new = int(round(h * scale_y))
w_new = int(round(w * scale_x))
flow_new = torch.zeros(size=[1, h_new, w_new], dtype=flow.dtype)
valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)
jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")
ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]
ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)
within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)
ii_valid = ii_valid[within_bounds_mask]
jj_valid = jj_valid[within_bounds_mask]
ii_valid_new = ii_valid_new[within_bounds_mask]
jj_valid_new = jj_valid_new[within_bounds_mask]
valid_flow_new = flow[:, ii_valid, jj_valid]
valid_flow_new *= scale_x
flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
valid_new[ii_valid_new, jj_valid_new] = valid_flow_mask[ii_valid, jj_valid]
return flow_new, valid_new.bool()
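# Example (illustrative sketch; values are placeholders): upscaling a sparse map
# keeps the number of valid entries constant, so its density drops:
#   flow = torch.zeros(1, 4, 4)
#   flow[0, 1, 2] = 3.0
#   flow2, valid2 = _resize_sparse_flow(flow, flow[0] != 0, scale_x=2.0, scale_y=2.0)
#   # flow2 is 1x8x8 with a single entry of 6.0 at (2, 4); valid2.sum() == 1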
class Compose(torch.nn.Module):
def __init__(self, transforms: List[Callable]):
super().__init__()
self.transforms = transforms
@torch.inference_mode()
def forward(self, images, disparities, masks):
for t in self.transforms:
images, disparities, masks = t(images, disparities, masks)
return images, disparities, masks
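# Example (illustrative sketch; `img_left_pil`/`img_right_pil` are placeholder
# PIL images): every transform shares the (images, disparities, masks) signature:
#   pipeline = Compose([ToTensor(), ConvertImageDtype(torch.float32)])
#   images, disparities, masks = pipeline((img_left_pil, img_right_pil), (None, None), (None, None))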
from .losses import *
from .metrics import *
from .distributed import *
from .logger import *
from .padder import *
from .norm import *
import os
import torch
import torch.distributed as dist
def _redefine_print(is_main):
"""disables printing when not in main process"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_main or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def setup_ddp(args):
# Set the local_rank, rank, and world_size values as args fields
# This is done differently depending on how we're running the script. We
# currently support either torchrun or the custom run_with_submitit.py
# If you're confused (like I was), this might help a bit
# https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print("Not using distributed mode")
args.distributed = False
args.world_size = 1
return
args.distributed = True
torch.cuda.set_device(args.gpu)
dist.init_process_group(
backend="nccl",
rank=args.rank,
world_size=args.world_size,
init_method=args.dist_url,
)
torch.distributed.barrier()
_redefine_print(is_main=(args.rank == 0))
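# Example (illustrative; values are placeholders): launching with torchrun sets
# RANK, WORLD_SIZE and LOCAL_RANK, so the first branch above is taken:
#   torchrun --nproc_per_node 8 train.py --dist-url env://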
def reduce_across_processes(val):
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
import datetime
import time
from collections import defaultdict, deque
import torch
from .distributed import reduce_across_processes
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
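# Example (illustrative sketch): a sliding window of values formatted as
# "median (global_avg)" by default:
#   v = SmoothedValue(window_size=20)
#   for x in (1.0, 2.0, 3.0):
#       v.update(x)
#   str(v)  # -> "2.0000 (2.0000)"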
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, **kwargs):
self.meters[name] = SmoothedValue(**kwargs)
def log_every(self, iterable, print_freq=5, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if print_freq is not None and i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")