Unverified Commit 3aa2a93d authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

RAFT training reference Improvement (#5590)



* Change optical flow train.py function name from validate to evaluate so it is similar to other references

* Add --device as parameter and enable to run in non distributed mode

* Format with ufmt

* Fix unneccessary param and bug

* Enable saving the optimizer and scheduler on the checkpoint

* Fix bug when evaluate before resume and save or load model without ddp

* Fix case where --train-dataset is None
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 7be2f55b
......@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):
@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
We process as many samples as possible with ddp, and process the rest on a single worker.
"""
batch_size = batch_size or args.batch_size
device = torch.device(args.device)
model.eval()
sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
else:
sampler = torch.utils.data.SequentialSampler(val_dataset)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
......@@ -88,7 +93,7 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
image1, image2, flow_gt = blob[:3]
valid_flow_mask = None if len(blob) == 3 else blob[-1]
image1, image2 = image1.cuda(), image2.cuda()
image1, image2 = image1.to(device), image2.to(device)
padder = utils.InputPadder(image1.shape, mode=padder_mode)
image1, image2 = padder.pad(image1, image2)
......@@ -115,21 +120,22 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
inner_loop(blob)
num_processed_samples += blob[0].shape[0] # batch size
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print(
f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
"Going to process the remaining samples individually, if any."
)
if args.distributed:
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print(
f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
"Going to process the remaining samples individually, if any."
)
if args.rank == 0: # we only need to process the rest on a single worker
for i in range(num_processed_samples, len(val_dataset)):
inner_loop(val_dataset[i])
if args.rank == 0: # we only need to process the rest on a single worker
for i in range(num_processed_samples, len(val_dataset)):
inner_loop(val_dataset[i])
logger.synchronize_between_processes()
logger.synchronize_between_processes()
print(header, logger)
def validate(model, args):
def evaluate(model, args):
val_datasets = args.val_dataset or []
if args.prototype:
......@@ -145,13 +151,13 @@ def validate(model, args):
if name == "kitti":
# Kitti has different image sizes so we need to individually pad them, we can't batch.
# see comment in InputPadder
if args.batch_size != 1 and args.rank == 0:
if args.batch_size != 1 and (not args.distributed or args.rank == 0):
warnings.warn(
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)
val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
_validate(
_evaluate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
......@@ -159,7 +165,7 @@ def validate(model, args):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
)
_validate(
_evaluate(
model,
args,
val_dataset,
......@@ -172,11 +178,12 @@ def validate(model, args):
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
device = torch.device(args.device)
for data_blob in logger.log_every(train_loader):
optimizer.zero_grad()
image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
......@@ -200,36 +207,68 @@ def main(args):
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
utils.setup_ddp(args)
if args.distributed and args.device == "cpu":
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)
if args.prototype:
model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
else:
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
if args.distributed:
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
model_without_ddp = model.module
else:
model.to(device)
model_without_ddp = model
if args.resume is not None:
d = torch.load(args.resume, map_location="cpu")
model.load_state_dict(d, strict=True)
checkpoint = torch.load(args.resume, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"])
if args.train_dataset is None:
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
validate(model, args)
evaluate(model, args)
return
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
epochs=args.epochs,
steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)
if args.resume is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
else:
args.start_epoch = 0
torch.backends.cudnn.benchmark = True
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
else:
sampler = torch.utils.data.RandomSampler(train_dataset)
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
......@@ -238,25 +277,15 @@ def main(args):
num_workers=args.num_workers,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
epochs=args.epochs,
steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)
logger = utils.MetricLogger()
done = False
for current_epoch in range(args.epochs):
for current_epoch in range(args.start_epoch, args.epochs):
print(f"EPOCH {current_epoch}")
if args.distributed:
# needed on distributed mode, otherwise the data loading order would be the same for all epochs
sampler.set_epoch(current_epoch)
sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs
train_one_epoch(
model=model,
optimizer=optimizer,
......@@ -269,13 +298,19 @@ def main(args):
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {current_epoch} done. ", logger)
if args.rank == 0:
# TODO: Also save the optimizer and scheduler
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
if not args.distributed or args.rank == 0:
checkpoint = {
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"epoch": current_epoch,
"args": args,
}
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
if current_epoch % args.val_freq == 0 or done:
validate(model, args)
evaluate(model, args)
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)
......@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
return parser
......
......@@ -256,7 +256,12 @@ def setup_ddp(args):
# if we're here, the script was called by run_with_submitit.py
args.local_rank = args.gpu
else:
raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")
print("Not using distributed mode!")
args.distributed = False
args.world_size = 1
return
args.distributed = True
_redefine_print(is_main=(args.rank == 0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment