Unverified commit 46b6fb41 authored by Nicolas Hug, committed by GitHub

Support --epochs instead of --num-steps in optical flow references (#5082)

parent b8b2294e
@@ -10,7 +10,14 @@ training and evaluation scripts to quickly bootstrap research.
The RAFT large model was trained on Flying Chairs and then on Flying Things.
Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
rest of the hyper-parameters are exactly the same as the original RAFT training
- recipe from https://github.com/princeton-vl/RAFT.
+ recipe from https://github.com/princeton-vl/RAFT. The original recipe trains for
+ 100000 updates (or steps) on each dataset; this corresponds to about 72 and 20
+ epochs on Chairs and Things, respectively:
+ ```
+ num_epochs = ceil(num_steps / number_of_steps_per_epoch)
+            = ceil(num_steps / (num_samples / effective_batch_size))
+ ```
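As a quick sanity check on those numbers, here is a minimal sketch of the conversion. The `steps_to_epochs` helper is purely illustrative (it is not part of the reference scripts), and the sample counts are approximations of the training split sizes:

```
from math import ceil

def steps_to_epochs(num_steps, num_samples, effective_batch_size=16):
    # One epoch is a full pass over the dataset, i.e. roughly
    # num_samples / effective_batch_size optimizer updates.
    steps_per_epoch = ceil(num_samples / effective_batch_size)
    return ceil(num_steps / steps_per_epoch)

# Roughly 22k Chairs training pairs -> 100000 updates is about 72 epochs.
print(steps_to_epochs(100_000, num_samples=22_232))  # 72
# A Things split of roughly 80k pairs yields the quoted ~20 epochs.
print(steps_to_epochs(100_000, num_samples=80_000))  # 20
```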
```
torchrun --nproc_per_node 8 --nnodes 1 train.py \
@@ -21,7 +28,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
--batch-size 2 \
--lr 0.0004 \
--weight-decay 0.0001 \
- --num-steps 100000 \
+ --epochs 72 \
--output-dir $chairs_dir
```
@@ -34,7 +41,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
--batch-size 2 \
--lr 0.000125 \
--weight-decay 0.0001 \
- --num-steps 100000 \
+ --epochs 20 \
--freeze-batch-norm \
--output-dir $things_dir \
--resume $chairs_dir/$name_chairs.pth
......
import argparse
import warnings
+ from math import ceil
from pathlib import Path
import torch
@@ -168,7 +169,7 @@ def validate(model, args):
warnings.warn(f"Can't validate on {val_dataset}, skipping.")
- def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
for data_blob in logger.log_every(train_loader):
optimizer.zero_grad()
@@ -189,13 +190,6 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
optimizer.step()
scheduler.step()
- current_step += 1
- if current_step == args.num_steps:
- return True, current_step
- return False, current_step
def main(args):
utils.setup_ddp(args)
@@ -243,7 +237,8 @@ def main(args):
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
- total_steps=args.num_steps + 100,
+ epochs=args.epochs,
+ steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
@@ -252,26 +247,22 @@ def main(args):
logger = utils.MetricLogger()
- done = False
- current_epoch = current_step = 0
- while not done:
+ for current_epoch in range(args.epochs):
print(f"EPOCH {current_epoch}")
sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs
- done, current_step = train_one_epoch(
+ train_one_epoch(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
logger=logger,
- current_step=current_step,
args=args,
)
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {current_epoch} done. ", logger)
- current_epoch += 1
if args.rank == 0:
# TODO: Also save the optimizer and scheduler
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
@@ -310,10 +301,8 @@ def get_args_parser(add_help=True):
)
parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
- # TODO: eventually, it might be preferable to support epochs instead of num_steps.
- # Keeping it this way for now to reproduce results more easily.
- parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
- parser.add_argument("--batch-size", type=int, default=6)
+ parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
+ parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
......