Unverified Commit 92eb12d6 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Minor updates to optical flow ref for consistency (#5654)

* Minor updates to optical flow ref for consistency

* Actually put back name

* linting
parent e8cb0bac
...@@ -75,7 +75,7 @@ def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b ...@@ -75,7 +75,7 @@ def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
sampler=sampler, sampler=sampler,
batch_size=batch_size, batch_size=batch_size,
pin_memory=True, pin_memory=True,
num_workers=args.num_workers, num_workers=args.workers,
) )
num_flow_updates = num_flow_updates or args.num_flow_updates num_flow_updates = num_flow_updates or args.num_flow_updates
...@@ -269,17 +269,17 @@ def main(args): ...@@ -269,17 +269,17 @@ def main(args):
sampler=sampler, sampler=sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
pin_memory=True, pin_memory=True,
num_workers=args.num_workers, num_workers=args.workers,
) )
logger = utils.MetricLogger() logger = utils.MetricLogger()
done = False done = False
for current_epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
print(f"EPOCH {current_epoch}") print(f"EPOCH {epoch}")
if args.distributed: if args.distributed:
# needed on distributed mode, otherwise the data loading order would be the same for all epochs # needed on distributed mode, otherwise the data loading order would be the same for all epochs
sampler.set_epoch(current_epoch) sampler.set_epoch(epoch)
train_one_epoch( train_one_epoch(
model=model, model=model,
...@@ -291,20 +291,20 @@ def main(args): ...@@ -291,20 +291,20 @@ def main(args):
) )
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0 # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {current_epoch} done. ", logger) print(f"Epoch {epoch} done. ", logger)
if not args.distributed or args.rank == 0: if not args.distributed or args.rank == 0:
checkpoint = { checkpoint = {
"model": model_without_ddp.state_dict(), "model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(), "optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(), "scheduler": scheduler.state_dict(),
"epoch": current_epoch, "epoch": epoch,
"args": args, "args": args,
} }
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth") torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth") torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
if current_epoch % args.val_freq == 0 or done: if epoch % args.val_freq == 0 or done:
evaluate(model, args) evaluate(model, args)
model.train() model.train()
if args.freeze_batch_norm: if args.freeze_batch_norm:
...@@ -319,16 +319,14 @@ def get_args_parser(add_help=True): ...@@ -319,16 +319,14 @@ def get_args_parser(add_help=True):
type=str, type=str,
help="The name of the experiment - determines the name of the files where weights are saved.", help="The name of the experiment - determines the name of the files where weights are saved.",
) )
parser.add_argument( parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
"--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
)
parser.add_argument( parser.add_argument(
"--resume", "--resume",
type=str, type=str,
help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.", help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
) )
parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.") parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
parser.add_argument( parser.add_argument(
"--train-dataset", "--train-dataset",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment