OpenDAS / vision · Commits

Commit 46b6fb41, authored Dec 13, 2021 by Nicolas Hug, committed by GitHub on Dec 13, 2021

Support --epochs instead of --num-steps in optical flow references (#5082)

Parent: b8b2294e

Showing 2 changed files with 18 additions and 22 deletions:

references/optical_flow/README.md   +10  -3
references/optical_flow/train.py     +8  -19
references/optical_flow/README.md

@@ -10,7 +10,14 @@ training and evaluation scripts to quickly bootstrap research.
 The RAFT large model was trained on Flying Chairs and then on Flying Things.
 Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
 rest of the hyper-parameters are exactly the same as the original RAFT training
-recipe from https://github.com/princeton-vl/RAFT.
+recipe from https://github.com/princeton-vl/RAFT. The original recipe trains for
+100000 updates (or steps) on each dataset - this corresponds to about 72 and 20
+epochs on Chairs and Things respectively:
+
+```
+num_epochs = ceil(num_steps / number_of_steps_per_epoch)
+           = ceil(num_steps / (num_samples / effective_batch_size))
+```
 
 ```
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
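As a quick sanity check of the formula added above, here is a small standalone Python sketch of the arithmetic. The dataset sizes are approximate assumptions (roughly 22k pairs for Flying Chairs and 80k for Flying Things), not values taken from this commit:

```
from math import ceil

effective_batch_size = 8 * 2  # 8 GPUs, per-GPU batch size of 2
num_steps = 100_000

# Approximate training-set sizes; assumptions for illustration only.
for name, num_samples in [("Chairs", 22_232), ("Things", 80_000)]:
    steps_per_epoch = num_samples / effective_batch_size
    print(name, ceil(num_steps / steps_per_epoch))  # Chairs -> 72, Things -> 20
```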
@@ -21,7 +28,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.0004 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 72 \
     --output-dir $chairs_dir
 ```
@@ -34,7 +41,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.000125 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 20 \
     --freeze-batch-norm \
     --output-dir $things_dir\
     --resume $chairs_dir/$name_chairs.pth
references/optical_flow/train.py

 import argparse
 import warnings
+from math import ceil
 from pathlib import Path
 
 import torch
@@ -168,7 +169,7 @@ def validate(model, args):
             warnings.warn(f"Can't validate on {val_dataset}, skipping.")
 
 
-def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
     for data_blob in logger.log_every(train_loader):
 
         optimizer.zero_grad()
@@ -189,13 +190,6 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
         optimizer.step()
         scheduler.step()
 
-        current_step += 1
-
-        if current_step == args.num_steps:
-            return True, current_step
-
-    return False, current_step
-
 
 def main(args):
     utils.setup_ddp(args)
@@ -243,7 +237,8 @@ def main(args):
     scheduler = torch.optim.lr_scheduler.OneCycleLR(
         optimizer=optimizer,
         max_lr=args.lr,
-        total_steps=args.num_steps + 100,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
         pct_start=0.05,
         cycle_momentum=False,
         anneal_strategy="linear",
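OneCycleLR accepts either a single total_steps or the (epochs, steps_per_epoch) pair and multiplies the latter internally, so the new call still defines a full schedule; note that the old call padded the step count with 100 extra steps, which the new parameterization drops. A minimal standalone Python sketch with a dummy optimizer (not part of train.py), using the Chairs numbers as assumed values (72 epochs, ceil(22_232 / 16) = 1390 steps per epoch):

```
import torch

# Dummy one-parameter optimizer, only needed to construct the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=0.0004)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=0.0004,
    epochs=72,             # assumed: Chairs recipe
    steps_per_epoch=1390,  # assumed: ceil(22_232 / 16)
    pct_start=0.05,
    cycle_momentum=False,
    anneal_strategy="linear",
)
print(scheduler.total_steps)  # 72 * 1390 = 100080 (old call used 100000 + 100)
```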
@@ -252,26 +247,22 @@ def main(args):
 
     logger = utils.MetricLogger()
 
-    done = False
-    current_epoch = current_step = 0
-    while not done:
+    for current_epoch in range(args.epochs):
         print(f"EPOCH {current_epoch}")
         sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
 
-        done, current_step = train_one_epoch(
+        train_one_epoch(
             model=model,
             optimizer=optimizer,
             scheduler=scheduler,
             train_loader=train_loader,
             logger=logger,
-            current_step=current_step,
             args=args,
         )
 
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)
-        current_epoch += 1
 
         if args.rank == 0:
             # TODO: Also save the optimizer and scheduler
             torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
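The sampler.set_epoch(current_epoch) call in the new loop is what reshuffles the data across epochs under DistributedSampler: the sampler seeds its permutation with the epoch number, so without it every epoch would see the same order. A minimal standalone Python sketch (toy dataset, explicit num_replicas/rank so it runs without a DDP process group; not from this repo):

```
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy 8-sample dataset; num_replicas/rank given explicitly so no process group is needed.
dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

print(list(sampler))  # shuffled indices
print(list(sampler))  # same order again: the permutation depends only on seed + epoch

sampler.set_epoch(1)
print(list(sampler))  # a different order for the new epoch
```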
@@ -310,10 +301,8 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
     parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
-    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
-    # Keeping it this way for now to reproduce results more easily.
-    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
-    parser.add_argument("--batch-size", type=int, default=6)
+    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
+    parser.add_argument("--batch-size", type=int, default=2)
 
     parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
     parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")