Unverified commit 46b6fb41, authored by Nicolas Hug, committed by GitHub

Support --epochs instead of --num-steps in optical flow references (#5082)

parent b8b2294e
@@ -10,7 +10,14 @@ training and evaluation scripts to quickly bootstrap research.
 The RAFT large model was trained on Flying Chairs and then on Flying Things.
 Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
 rest of the hyper-parameters are exactly the same as the original RAFT training
-recipe from https://github.com/princeton-vl/RAFT.
+recipe from https://github.com/princeton-vl/RAFT. The original recipe trains for
+100000 updates (or steps) on each dataset - this corresponds to about 72 and 20
+epochs on Chairs and Things respectively:
+```
+num_epochs = ceil(num_steps / number_of_steps_per_epoch)
+           = ceil(num_steps / (num_samples / effective_batch_size))
+```

 ```
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
@@ -21,7 +28,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.0004 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 72 \
     --output-dir $chairs_dir
 ```
@@ -34,7 +41,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.000125 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 20 \
     --freeze-batch-norm \
     --output-dir $things_dir\
     --resume $chairs_dir/$name_chairs.pth
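The 72 and 20 epoch counts quoted in the README change above follow directly from that formula. Here is a small sanity-check in Python; the dataset sizes are assumptions chosen only to illustrate the arithmetic and are not taken from this commit:

```
from math import ceil

def steps_to_epochs(num_steps, num_samples, effective_batch_size):
    # num_epochs = ceil(num_steps / steps_per_epoch)
    #            = ceil(num_steps / (num_samples / effective_batch_size))
    steps_per_epoch = num_samples / effective_batch_size
    return ceil(num_steps / steps_per_epoch)

# Hypothetical dataset sizes; the effective batch size is 8 GPUs * batch size 2 = 16.
print(steps_to_epochs(100_000, num_samples=22_232, effective_batch_size=16))  # 72 (Chairs)
print(steps_to_epochs(100_000, num_samples=80_000, effective_batch_size=16))  # 20 (Things)
```

The corresponding changes to train.py follow.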
 import argparse
 import warnings
+from math import ceil
 from pathlib import Path

 import torch
@@ -168,7 +169,7 @@ def validate(model, args):
         warnings.warn(f"Can't validate on {val_dataset}, skipping.")


-def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
     for data_blob in logger.log_every(train_loader):

         optimizer.zero_grad()
@@ -189,13 +190,6 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
         optimizer.step()
         scheduler.step()

-        current_step += 1
-
-        if current_step == args.num_steps:
-            return True, current_step
-
-    return False, current_step
-

 def main(args):
     utils.setup_ddp(args)
@@ -243,7 +237,8 @@ def main(args):
     scheduler = torch.optim.lr_scheduler.OneCycleLR(
         optimizer=optimizer,
         max_lr=args.lr,
-        total_steps=args.num_steps + 100,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
         pct_start=0.05,
         cycle_momentum=False,
         anneal_strategy="linear",
@@ -252,26 +247,22 @@ def main(args):
     logger = utils.MetricLogger()

-    done = False
-    current_epoch = current_step = 0
-    while not done:
+    for current_epoch in range(args.epochs):
         print(f"EPOCH {current_epoch}")
         sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs

-        done, current_step = train_one_epoch(
+        train_one_epoch(
             model=model,
             optimizer=optimizer,
             scheduler=scheduler,
             train_loader=train_loader,
             logger=logger,
-            current_step=current_step,
             args=args,
         )

         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)

-        current_epoch += 1
-
         if args.rank == 0:
             # TODO: Also save the optimizer and scheduler
             torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
@@ -310,10 +301,8 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
     parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
-    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
-    # Keeping it this way for now to reproduce results more easily.
-    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
-    parser.add_argument("--batch-size", type=int, default=6)
+    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
+    parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
     parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")