Unverified Commit 3aa2a93d authored by YosuaMichael, committed by GitHub

RAFT training reference Improvement (#5590)



* Change optical flow train.py function name from validate to evaluate so it is consistent with the other references

* Add --device as a parameter and enable running in non-distributed mode

* Format with ufmt

* Fix unnecessary param and bug

* Enable saving the optimizer and scheduler on the checkpoint

* Fix bug when evaluating before resume, and when saving or loading the model without DDP

* Fix case where --train-dataset is None
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent 7be2f55b
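
The checkpoints written after this change bundle the model, optimizer, scheduler, and epoch together. As a minimal sketch of consuming such a checkpoint outside the training script (the file path and the hyperparameters below are illustrative placeholders, not taken from this diff; the dict keys match what main() saves):

import torch
import torchvision

# Rebuild objects equivalent to the ones the training script would have.
model = torchvision.models.optical_flow.raft_large()
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=4e-4, total_steps=1000)

checkpoint = torch.load("raft_large.pth", map_location="cpu")  # example path
model.load_state_dict(checkpoint["model"])            # plain (non-DDP) state dict
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
start_epoch = checkpoint["epoch"] + 1                  # resume from the next epoch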
@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):
 @torch.no_grad()
-def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
+def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
     """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
     We process as many samples as possible with ddp, and process the rest on a single worker.
     """
     batch_size = batch_size or args.batch_size
+    device = torch.device(args.device)
     model.eval()
-    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    else:
+        sampler = torch.utils.data.SequentialSampler(val_dataset)
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         sampler=sampler,
@@ -88,7 +93,7 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
         image1, image2, flow_gt = blob[:3]
         valid_flow_mask = None if len(blob) == 3 else blob[-1]
-        image1, image2 = image1.cuda(), image2.cuda()
+        image1, image2 = image1.to(device), image2.to(device)
         padder = utils.InputPadder(image1.shape, mode=padder_mode)
         image1, image2 = padder.pad(image1, image2)
@@ -115,21 +120,22 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
         inner_loop(blob)
         num_processed_samples += blob[0].shape[0]  # batch size

-    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
-    print(
-        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
-        "Going to process the remaining samples individually, if any."
-    )
-    if args.rank == 0:  # we only need to process the rest on a single worker
-        for i in range(num_processed_samples, len(val_dataset)):
-            inner_loop(val_dataset[i])
+    if args.distributed:
+        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+        print(
+            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
+            "Going to process the remaining samples individually, if any."
+        )
+        if args.rank == 0:  # we only need to process the rest on a single worker
+            for i in range(num_processed_samples, len(val_dataset)):
+                inner_loop(val_dataset[i])

     logger.synchronize_between_processes()
     print(header, logger)

-def validate(model, args):
+def evaluate(model, args):
     val_datasets = args.val_dataset or []

     if args.prototype:
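
For context, utils.reduce_across_processes is expected to sum a scalar across all DDP workers so that every rank sees the global sample count. A rough sketch of such a helper, assuming torch.distributed is used (the actual implementation in the reference utils module may differ):

import torch
import torch.distributed as dist

def reduce_across_processes(val):
    # Single-process fallback: nothing to reduce.
    if not (dist.is_available() and dist.is_initialized()):
        return val
    t = torch.tensor(val, device="cuda")
    dist.barrier()      # make sure every rank reached this point
    dist.all_reduce(t)  # default reduction op is SUM
    return int(t.item())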
@@ -145,13 +151,13 @@ def validate(model, args):
         if name == "kitti":
             # Kitti has different image sizes so we need to individually pad them, we can't batch.
             # see comment in InputPadder
-            if args.batch_size != 1 and args.rank == 0:
+            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                 warnings.warn(
                     f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                 )
             val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
-            _validate(
+            _evaluate(
                 model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
             )
         elif name == "sintel":
@@ -159,7 +165,7 @@ def validate(model, args):
                 val_dataset = Sintel(
                     root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                 )
-                _validate(
+                _evaluate(
                     model,
                     args,
                     val_dataset,
@@ -172,11 +178,12 @@ def validate(model, args):
 def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
+    device = torch.device(args.device)
     for data_blob in logger.log_every(train_loader):
         optimizer.zero_grad()
-        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
+        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
         flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
         loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
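
utils.sequence_loss implements the RAFT-style loss over the sequence of refined flow predictions, with later iterations weighted more heavily via gamma. A simplified sketch of that idea (details such as flow-magnitude clipping or the exact per-pixel distance in the real helper may differ):

import torch

def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8):
    loss = 0.0
    num_preds = len(flow_preds)
    for i, flow_pred in enumerate(flow_preds):
        # Exponentially larger weight for later refinement iterations.
        weight = gamma ** (num_preds - i - 1)
        # Per-pixel endpoint error between prediction and ground truth.
        epe = torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()
        if valid_flow_mask is not None:
            epe = epe[valid_flow_mask]
        loss = loss + weight * epe.mean()
    return loss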
@@ -200,36 +207,68 @@ def main(args):
         raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")

     utils.setup_ddp(args)

+    if args.distributed and args.device == "cpu":
+        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
+    device = torch.device(args.device)
+
     if args.prototype:
         model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
     else:
         model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

-    model = model.to(args.local_rank)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+    if args.distributed:
+        model = model.to(args.local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+        model_without_ddp = model.module
+    else:
+        model.to(device)
+        model_without_ddp = model

     if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])

     if args.train_dataset is None:
         # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
-        validate(model, args)
+        evaluate(model, args)
         return

     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
+
     torch.backends.cudnn.benchmark = True

     model.train()
     if args.freeze_batch_norm:
         utils.freeze_batch_norm(model.module)

-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
-    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    else:
+        sampler = torch.utils.data.RandomSampler(train_dataset)
+
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         sampler=sampler,
@@ -238,25 +277,15 @@ def main(args):
         num_workers=args.num_workers,
     )

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )
-
     logger = utils.MetricLogger()

     done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
         print(f"EPOCH {current_epoch}")
-        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
+        if args.distributed:
+            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
+            sampler.set_epoch(current_epoch)

         train_one_epoch(
             model=model,
             optimizer=optimizer,
@@ -269,13 +298,19 @@ def main(args):
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)

-        if args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+        if not args.distributed or args.rank == 0:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

         if current_epoch % args.val_freq == 0 or done:
-            validate(model, args)
+            evaluate(model, args)
             model.train()
             if args.freeze_batch_norm:
                 utils.freeze_batch_norm(model.module)
@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")

     return parser
@@ -256,7 +256,12 @@ def setup_ddp(args):
         # if we're here, the script was called by run_with_submitit.py
         args.local_rank = args.gpu
     else:
-        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")
+        print("Not using distributed mode!")
+        args.distributed = False
+        args.world_size = 1
+        return
+
+    args.distributed = True

     _redefine_print(is_main=(args.rank == 0))
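
With this branch, setup_ddp degrades gracefully into single-process mode instead of raising. Roughly, the surrounding helper is expected to look like the sketch below, assuming a torchrun launch detected via environment variables (simplified; the real function also covers the SLURM / run_with_submitit.py path shown above and its process-group setup may differ):

import os
import torch.distributed as dist

def setup_ddp(args):
    if all(key in os.environ for key in ("RANK", "WORLD_SIZE", "LOCAL_RANK")):
        # Launched with torchrun: pick up the ranks it exported.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.local_rank = int(os.environ["LOCAL_RANK"])
    else:
        print("Not using distributed mode!")
        args.distributed = False
        args.world_size = 1
        return

    args.distributed = True
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=args.world_size, rank=args.rank
    )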