Commit 4dd8b5cc (unverified)
Authored Dec 07, 2021 by Nicolas Hug, committed by GitHub on Dec 07, 2021

Add training reference for optical flow models (#5027)

Parent: 47bd9620

Showing 2 changed files with 616 additions and 0 deletions:

  references/optical_flow/train.py   +334  -0
  references/optical_flow/utils.py   +282  -0
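
For context, both files program against the RAFT models that torchvision exposes as raft_large and raft_small: the forward pass takes a pair of image batches plus a num_flow_updates argument and returns one flow estimate per recurrent update, the last one being the most refined. Below is a minimal smoke-test sketch of that contract (shapes are illustrative, and no pretrained weights are assumed):

import torch
from torchvision.models.optical_flow import raft_small

model = raft_small()
model.eval()

# Two RGB batches; RAFT needs H and W divisible by 8 (see InputPadder in utils.py below).
img1 = torch.rand(1, 3, 96, 128)
img2 = torch.rand(1, 3, 96, 128)

with torch.no_grad():
    flow_predictions = model(img1, img2, num_flow_updates=12)

print(len(flow_predictions))       # 12: one (N, 2, H, W) flow per recurrent update
print(flow_predictions[-1].shape)  # the last prediction is the most refined one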

references/optical_flow/train.py (new file, mode 100644)

import argparse
import warnings
from pathlib import Path

import torch
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small


def get_train_dataset(stage, dataset_root):
    if stage == "chairs":
        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
    elif stage == "things":
        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
    elif stage == "sintel_SKH":  # S + K + H as from paper
        crop_size = (368, 768)
        transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

        things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
        sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

        kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

        hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)

        # As future improvement, we could probably be using a distributed sampler here
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
    elif stage == "kitti":
        transforms = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=transforms)
    else:
        raise ValueError(f"Unknown stage {stage}")
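

# Note on the dataset arithmetic in the "sintel_SKH" stage above: the expression
# 100 * sintel + 200 * kitti + 5 * hd1k + things_clean relies on the flow
# datasets supporting integer multiplication (dataset repetition) and addition
# (concatenation). The multipliers compensate for the very different dataset
# sizes, so that the resulting mixture approximates the S(.71), T(.135),
# K(.135), H(.02) sampling distribution quoted in the comment.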


@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.
    """
    batch_size = batch_size or args.batch_size

    model.eval()

    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    num_flow_updates = num_flow_updates or args.num_flow_updates

    def inner_loop(blob):
        if blob[0].dim() == 3:
            # input is not batched so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        image1, image2 = image1.cuda(), image2.cuda()

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)

        logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print(
        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
        "Going to process the remaining samples individually, if any."
    )

    if args.rank == 0:  # we only need to process the rest on a single worker
        for i in range(num_processed_samples, len(val_dataset)):
            inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)


def validate(model, args):
    val_datasets = args.val_dataset or []
    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and args.rank == 0:
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
            _validate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
                )
                _validate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            warnings.warn(f"Can't validate on {name}, skipping.")


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
    for data_blob in logger.log_every(train_loader):

        optimizer.zero_grad()

        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

        metrics.pop("f1")
        logger.update(loss=loss, **metrics)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()

        current_step += 1

        if current_step == args.num_steps:
            return True, current_step

    return False, current_step
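

# Notes on train_one_epoch above: gradients are clipped to an L2 norm of 1
# before each optimizer step (as in the original RAFT recipe), and the
# scheduler is stepped once per batch rather than once per epoch, since the
# OneCycleLR schedule built in main() is defined over individual updates.
# Returning (done, current_step) lets main() stop exactly at args.num_steps,
# even in the middle of an epoch.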


def main(args):
    utils.setup_ddp(args)

    model = raft_small() if args.small else raft_large()
    model = model.to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    if args.resume is not None:
        d = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(d, strict=True)

    if args.train_dataset is None:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        validate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    torch.backends.cudnn.benchmark = True

    model.train()
    if args.freeze_batch_norm:
        utils.freeze_batch_norm(model.module)

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        total_steps=args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    logger = utils.MetricLogger()

    done = False
    current_epoch = current_step = 0
    while not done:
        print(f"EPOCH {current_epoch}")

        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
        done, current_step = train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            current_step=current_step,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {current_epoch} done. ", logger)

        current_epoch += 1

        if args.rank == 0:
            # TODO: Also save the optimizer and scheduler
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")

        if current_epoch % args.val_freq == 0 or done:
            validate(model, args)
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model.module)


def get_args_parser(add_help=True):
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument(
        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )

    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
    # Keeping it this way for now to reproduce results more easily.
    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
    parser.add_argument("--batch-size", type=int, default=6)

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )
    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")

    parser.add_argument(
        "--num_flow_updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    Path(args.output_dir).mkdir(exist_ok=True)
    main(args)
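
The LR schedule built in main() above warms up linearly over the first 5% of steps (pct_start=0.05) and then anneals linearly; total_steps is padded by 100, presumably as slack so the scheduler is never stepped past its horizon (OneCycleLR raises once it is stepped beyond total_steps). Here is a minimal sketch of that schedule in isolation, with the same hyper-parameters but a shrunken step count; the dummy parameter exists only to satisfy the optimizer:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=0.00002, weight_decay=0.00005, eps=1e-8)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=0.00002,
    total_steps=1000 + 100,
    pct_start=0.05,
    cycle_momentum=False,
    anneal_strategy="linear",
)

lrs = []
for _ in range(1000):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(f"peak lr:  {max(lrs):.2e}")  # reaches max_lr after ~5% of the steps
print(f"final lr: {lrs[-1]:.2e}")   # still decaying; would reach ~0 at step total_steps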

references/optical_flow/utils.py (new file, mode 100644)

import datetime
import os
import time
from collections import defaultdict
from collections import deque

import torch
import torch.distributed as dist
import torch.nn.functional as F


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, **kwargs):
        self.meters[name] = SmoothedValue(**kwargs)

    def log_every(self, iterable, print_freq=5, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if print_freq is not None and i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")


def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):

    epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
    flow_norm = (flow_gt ** 2).sum(dim=1).sqrt()

    if valid_flow_mask is not None:
        epe = epe[valid_flow_mask]
        flow_norm = flow_norm[valid_flow_mask]

    relative_epe = epe / flow_norm

    metrics = {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
        "f1": ((epe > 3) & (relative_epe > 0.05)).float().mean().item() * 100,
    }
    return metrics, epe.numel()
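

# Note on the "f1" metric above: it follows the KITTI "Fl-all" convention, i.e.
# the percentage of pixels whose endpoint error is both greater than 3px and
# greater than 5% of the ground-truth flow magnitude.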


def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
    """Loss function defined over sequence of flow predictions"""

    if gamma > 1:
        raise ValueError(f"Gamma should be < 1, got {gamma}.")

    # exclude invalid pixels and extremely large displacements
    flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()
    valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)

    valid_flow_mask = valid_flow_mask[:, None, :, :]

    flow_preds = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)

    abs_diff = (flow_preds - flow_gt).abs()
    abs_diff = (abs_diff * valid_flow_mask).mean(axis=(1, 2, 3, 4))

    num_predictions = flow_preds.shape[0]
    weights = gamma ** torch.arange(num_predictions - 1, -1, -1).to(flow_gt.device)
    flow_loss = (abs_diff * weights).sum()

    return flow_loss
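

# Worked example of the weighting above: with gamma=0.8 and num_predictions=3,
# weights = 0.8 ** torch.arange(2, -1, -1) = [0.64, 0.80, 1.00], so later (more
# refined) flow predictions contribute more to the loss, and the final
# prediction always gets weight 1.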


class InputPadder:
    """Pads images such that dimensions are divisible by 8"""

    # TODO: Ideally, this should be part of the eval transforms preset, instead
    # of being part of the validation code. It's not obvious what a good
    # solution would be, because we need to unpad the predicted flows according
    # to the input images' size, and in some datasets (Kitti) images can have
    # variable sizes.

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == "sintel":
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode="replicate") for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0] : c[1], c[2] : c[3]]


def _redefine_print(is_main):
    """disables printing when not in main process"""
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_main or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def setup_ddp(args):
    # Set the local_rank, rank, and world_size values as args fields
    # This is done differently depending on how we're running the script. We
    # currently support either torchrun or the custom run_with_submitit.py
    # If you're confused (like I was), this might help a bit
    # https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
    if all(key in os.environ for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")):
        # if we're here, the script was called with torchrun. Otherwise
        # these args will be set already by the run_with_submitit script
        args.local_rank = int(os.environ["LOCAL_RANK"])
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
    elif "gpu" in args:
        # if we're here, the script was called by run_with_submitit.py
        args.local_rank = args.gpu
    else:
        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")

    _redefine_print(is_main=(args.rank == 0))

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend="nccl",
        rank=args.rank,
        world_size=args.world_size,
        init_method=args.dist_url,
    )
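

# A hypothetical launch for the env-variable branch above (illustrative paths
# and values, not from this commit):
#   torchrun --nproc_per_node 8 train.py --dataset-root /data \
#       --train-dataset chairs --val-dataset sintel kitti
# torchrun sets LOCAL_RANK, RANK and WORLD_SIZE in the environment, which is
# exactly what setup_ddp() reads to configure the process group.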


def reduce_across_processes(val):
    t = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(t)
    return t


def freeze_batch_norm(model):
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.eval()
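
To make the padding contract concrete, here is a small CPU round-trip check of InputPadder; the shape is arbitrary, and the snippet assumes the file above is importable as utils:

import torch
import utils

img = torch.rand(1, 3, 100, 200)  # 100 is not divisible by 8
padder = utils.InputPadder(img.shape, mode="sintel")
(padded,) = padder.pad(img)
print(padded.shape)  # torch.Size([1, 3, 104, 200]): height grows to the next multiple of 8

unpadded = padder.unpad(padded)
print(torch.equal(unpadded, img))  # True: unpad exactly inverts pad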