OpenDAS / vision · Commits

Commit 46b6fb41 (unverified)
Authored Dec 13, 2021 by Nicolas Hug, committed by GitHub on Dec 13, 2021

Support --epochs instead of --num-steps in optical flow references (#5082)
Parent: b8b2294e

Showing 2 changed files with 18 additions and 22 deletions:

  references/optical_flow/README.md  +10 -3
  references/optical_flow/train.py   +8 -19

references/optical_flow/README.md
@@ -10,7 +10,14 @@ training and evaluation scripts to quickly bootstrap research.
 The RAFT large model was trained on Flying Chairs and then on Flying Things.
 Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The
 rest of the hyper-parameters are exactly the same as the original RAFT training
-recipe from https://github.com/princeton-vl/RAFT.
+recipe from https://github.com/princeton-vl/RAFT. The original recipe trains for
+100000 updates (or steps) on each dataset - this corresponds to about 72 and 20
+epochs on Chairs and Things respectively:
+
+```
+num_epochs = ceil(num_steps / number_of_steps_per_epoch)
+           = ceil(num_steps / (num_samples / effective_batch_size))
+```
 
 ```
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
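
For a quick sanity check of the 72 and 20 epoch figures in the new README text, the formula above can be evaluated directly. The per-dataset sample counts below are assumptions back-solved from those epoch counts (the diff itself does not state them), so this is only an illustration:

```
# Illustrative only: sample counts are assumed, not taken from the commit.
from math import ceil

num_steps = 100_000
effective_batch_size = 8 * 2  # 8 GPUs x per-GPU batch size of 2

for name, num_samples in [("Chairs", 22_232), ("Things", 80_000)]:
    steps_per_epoch = num_samples / effective_batch_size
    print(name, ceil(num_steps / steps_per_epoch))
# Chairs -> 72, Things -> 20
```
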
@@ -21,7 +28,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.0004 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 72 \
     --output-dir $chairs_dir
 ```
@@ -34,7 +41,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --batch-size 2 \
     --lr 0.000125 \
     --weight-decay 0.0001 \
-    --num-steps 100000 \
+    --epochs 20 \
     --freeze-batch-norm \
     --output-dir $things_dir\
     --resume $chairs_dir/$name_chairs.pth

references/optical_flow/train.py
 import argparse
 import warnings
+from math import ceil
 from pathlib import Path
 
 import torch
@@ -168,7 +169,7 @@ def validate(model, args):
             warnings.warn(f"Can't validate on {val_dataset}, skipping.")
 
 
-def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
     for data_blob in logger.log_every(train_loader):
 
         optimizer.zero_grad()
@@ -189,13 +190,6 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
         optimizer.step()
         scheduler.step()
 
-        current_step += 1
-
-        if current_step == args.num_steps:
-            return True, current_step
-
-    return False, current_step
-
 
 def main(args):
     utils.setup_ddp(args)
@@ -243,7 +237,8 @@ def main(args):
     scheduler = torch.optim.lr_scheduler.OneCycleLR(
         optimizer=optimizer,
         max_lr=args.lr,
-        total_steps=args.num_steps + 100,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
         pct_start=0.05,
         cycle_momentum=False,
         anneal_strategy="linear",
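
The scheduler change replaces OneCycleLR's total_steps argument with the equivalent epochs/steps_per_epoch pair; OneCycleLR multiplies the two internally to recover the total step count. A minimal, self-contained sketch of that equivalence, using a stand-in model and an assumed dataset size rather than the actual RAFT setup:

```
# Sketch of the OneCycleLR equivalence; not the actual training script.
from math import ceil
import torch

model = torch.nn.Linear(4, 4)  # stand-in for the RAFT model
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004)

num_samples = 22_232           # assumed dataset size, for illustration only
world_size, batch_size, epochs = 8, 2, 72
steps_per_epoch = ceil(num_samples / (world_size * batch_size))

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=0.0004,
    epochs=epochs,                    # OneCycleLR computes epochs * steps_per_epoch internally
    steps_per_epoch=steps_per_epoch,  # instead of passing total_steps directly
    pct_start=0.05,
    cycle_momentum=False,
    anneal_strategy="linear",
)
print(scheduler.total_steps)          # 100080, close to the original 100000-step recipe
```
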
@@ -252,26 +247,22 @@ def main(args):
     logger = utils.MetricLogger()
 
-    done = False
-    current_epoch = current_step = 0
-    while not done:
+    for current_epoch in range(args.epochs):
         print(f"EPOCH {current_epoch}")
         sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
 
-        done, current_step = train_one_epoch(
+        train_one_epoch(
             model=model,
             optimizer=optimizer,
             scheduler=scheduler,
             train_loader=train_loader,
             logger=logger,
-            current_step=current_step,
             args=args,
         )
 
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)
 
-        current_epoch += 1
-
         if args.rank == 0:
             # TODO: Also save the optimizer and scheduler
             torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
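
The rewritten loop keeps the existing sampler.set_epoch(current_epoch) call: DistributedSampler seeds its per-epoch shuffle with that value, so omitting it would replay the same sample order every epoch. A small standalone demonstration (not from the commit; it passes num_replicas and rank explicitly so no process group is needed):

```
# Standalone demo of why sampler.set_epoch() matters; not part of this commit.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(2):
    sampler.set_epoch(epoch)  # comment this out and both epochs yield the same order
    print(epoch, list(iter(sampler)))
```
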
@@ -310,10 +301,8 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
     parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
-    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
-    # Keeping it this way for now to reproduce results more easily.
-    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
-    parser.add_argument("--batch-size", type=int, default=6)
+    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
+    parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
     parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")