Unverified commit 4dd8b5cc, authored Dec 07, 2021 by Nicolas Hug, committed by GitHub on Dec 07, 2021
Add training reference for optical flow models (#5027)
Parent: 47bd9620

Showing 2 changed files with 616 additions and 0 deletions:

    references/optical_flow/train.py  (+334, -0)
    references/optical_flow/utils.py  (+282, -0)
references/optical_flow/train.py (new file, mode 100644)
import argparse
import warnings
from pathlib import Path

import torch
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small


def get_train_dataset(stage, dataset_root):
    if stage == "chairs":
        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
    elif stage == "things":
        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
    elif stage == "sintel_SKH":  # S + K + H as from paper
        crop_size = (368, 768)
        transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

        things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
        sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

        kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

        hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)

        # As a future improvement, we could probably use a distributed sampler here
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
    elif stage == "kitti":
        transforms = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=transforms)
    else:
        raise ValueError(f"Unknown stage {stage}")
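

# Editor's sketch, not part of the committed file: the "sintel_SKH" stage above
# returns `100 * sintel + 200 * kitti + 5 * hd1k + things_clean`. Assuming the
# datasets behave like ordinary torch.utils.data.Dataset objects, the same
# weighted mix can be spelled out with ConcatDataset, where repeating a dataset
# k times makes its samples k times more frequent without a custom sampler.
def _example_weighted_mix(sintel, kitti, hd1k, things_clean):
    from torch.utils.data import ConcatDataset

    return ConcatDataset([sintel] * 100 + [kitti] * 200 + [hd1k] * 5 + [things_clean])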


@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.
    """
    batch_size = batch_size or args.batch_size

    model.eval()

    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    num_flow_updates = num_flow_updates or args.num_flow_updates

    def inner_loop(blob):
        if blob[0].dim() == 3:
            # input is not batched so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        image1, image2 = image1.cuda(), image2.cuda()

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)

        logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print(
        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
        "Going to process the remaining samples individually, if any."
    )

    if args.rank == 0:
        # we only need to process the rest on a single worker
        for i in range(num_processed_samples, len(val_dataset)):
            inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)
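

# Editor's sketch with hypothetical numbers, not part of the committed file:
# how many samples does the DDP path of _validate cover before the rank-0
# remainder loop runs? DistributedSampler(drop_last=True) gives each rank
# floor(len(dataset) / world_size) samples, and everything past that total is
# processed individually.
def _example_ddp_coverage(dataset_len=26, world_size=3):
    per_rank = dataset_len // world_size  # 8 samples per rank
    batch_processed = world_size * per_rank  # 24 samples go through the DataLoaders
    remainder = dataset_len - batch_processed  # 2 samples are left for rank 0
    return batch_processed, remainder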


def validate(model, args):
    val_datasets = args.val_dataset or []

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes, so we need to individually pad them; we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and args.rank == 0:
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
            _validate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
                )
                _validate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            warnings.warn(f"Can't validate on {name}, skipping.")


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
    for data_blob in logger.log_every(train_loader):

        optimizer.zero_grad()

        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

        metrics.pop("f1")
        logger.update(loss=loss, **metrics)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()

        current_step += 1

        if current_step == args.num_steps:
            return True, current_step

    return False, current_step
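

# Editor's example, not part of the committed file: train_one_epoch clips
# gradients with clip_grad_norm_, which rescales all gradients in place so
# that their joint L2 norm is at most max_norm.
def _example_grad_clipping():
    p = torch.nn.Parameter(torch.zeros(2))
    p.grad = torch.tensor([3.0, 4.0])  # gradient norm is 5
    torch.nn.utils.clip_grad_norm_([p], max_norm=1)
    return p.grad  # ~tensor([0.6000, 0.8000]), i.e. rescaled to norm 1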


def main(args):
    utils.setup_ddp(args)

    model = raft_small() if args.small else raft_large()
    model = model.to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    if args.resume is not None:
        d = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(d, strict=True)

    if args.train_dataset is None:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        validate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    torch.backends.cudnn.benchmark = True

    model.train()
    if args.freeze_batch_norm:
        utils.freeze_batch_norm(model.module)

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        total_steps=args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    logger = utils.MetricLogger()

    done = False
    current_epoch = current_step = 0
    while not done:
        print(f"EPOCH {current_epoch}")

        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
        done, current_step = train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            current_step=current_step,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {current_epoch} done. ", logger)

        current_epoch += 1

        if args.rank == 0:
            # TODO: Also save the optimizer and scheduler
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")

        if current_epoch % args.val_freq == 0 or done:
            validate(model, args)
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model.module)
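

# Editor's sketch, not part of the committed file: the OneCycleLR schedule set
# up in main() warms the learning rate up linearly for pct_start of the total
# steps, then anneals it linearly back down. With hypothetical numbers:
def _example_one_cycle(total_steps=1000, max_lr=2e-5):
    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=max_lr)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=opt,
        max_lr=max_lr,
        total_steps=total_steps,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )
    for _ in range(int(0.05 * total_steps)):  # 50 warmup steps
        opt.step()
        sched.step()
    return opt.param_groups[0]["lr"]  # ~max_lr, the peak of the cycle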


def get_args_parser(add_help=True):
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument(
        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )
    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")
    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
    # Keeping it this way for now to reproduce results more easily.
    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
    parser.add_argument("--batch-size", type=int, default=6)
    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")
    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )
    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
    parser.add_argument(
        "--num_flow_updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )
    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")
    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")
    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )
    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    Path(args.output_dir).mkdir(exist_ok=True)
    main(args)
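
Editor's note: a hypothetical single-process smoke test of the CLI surface
above. The script itself is meant to be launched with torchrun so that
setup_ddp() finds LOCAL_RANK, RANK, and WORLD_SIZE in the environment; the
dataset root below is a made-up path.

    from train import get_args_parser

    args = get_args_parser().parse_args(
        ["--dataset-root", "/path/to/datasets", "--train-dataset", "chairs", "--val-dataset", "sintel"]
    )
    print(args.name, args.batch_size, args.num_steps)  # raft 6 100000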

references/optical_flow/utils.py (new file, mode 100644)
import datetime
import os
import time
from collections import defaultdict
from collections import deque

import torch
import torch.distributed as dist
import torch.nn.functional as F


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )
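

# Editor's example, not part of the committed file: the windowed stats
# (median, avg, max, value) come from the deque, while global_avg uses every
# update ever seen.
def _example_smoothed_value():
    v = SmoothedValue(window_size=2)
    for x in (1.0, 2.0, 9.0):
        v.update(x)
    # The deque holds [2.0, 9.0]; torch.median picks the lower middle value,
    # so v.median == 2.0, while v.global_avg == 12.0 / 3 == 4.0
    return v.median, v.global_avg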


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, **kwargs):
        self.meters[name] = SmoothedValue(**kwargs)

    def log_every(self, iterable, print_freq=5, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if print_freq is not None and i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")
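

# Editor's example, not part of the committed file: meters are created lazily
# by update() with the default SmoothedValue format, or explicitly via
# add_meter() with a custom one.
def _example_metric_logger():
    logger = MetricLogger(delimiter="  ")
    logger.add_meter("epe", fmt="{global_avg:.4f}")
    logger.update(epe=1.5, loss=0.25)
    return str(logger)  # "epe: 1.5000  loss: 0.2500 (0.2500)"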


def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
    epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
    flow_norm = (flow_gt ** 2).sum(dim=1).sqrt()

    if valid_flow_mask is not None:
        epe = epe[valid_flow_mask]
        flow_norm = flow_norm[valid_flow_mask]

    relative_epe = epe / flow_norm

    metrics = {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
        "f1": ((epe > 3) & (relative_epe > 0.05)).float().mean().item() * 100,
    }
    return metrics, epe.numel()
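

# Editor's toy check, not part of the committed file: flow tensors have shape
# (N, 2, H, W). With an all-zero prediction against a ground truth whose two
# pixels have flow (3, 4) and (0.3, 0.4), the per-pixel epe values are 5 and
# 0.5, so "epe" is 2.75 and "f1" is 50.0 (only the first pixel has epe > 3
# with relative epe > 0.05).
def _example_compute_metrics():
    flow_gt = torch.tensor([[[[3.0, 0.3]], [[4.0, 0.4]]]])  # shape (1, 2, 1, 2)
    metrics, num_pixels = compute_metrics(torch.zeros(1, 2, 1, 2), flow_gt)
    return metrics, num_pixels  # num_pixels == 2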


def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
    """Loss function defined over a sequence of flow predictions"""

    if gamma > 1:
        raise ValueError(f"Gamma should be < 1, got {gamma}.")

    # exclude invalid pixels and extremely large displacements
    flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()
    valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)

    valid_flow_mask = valid_flow_mask[:, None, :, :]

    flow_preds = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)

    abs_diff = (flow_preds - flow_gt).abs()
    abs_diff = (abs_diff * valid_flow_mask).mean(axis=(1, 2, 3, 4))

    num_predictions = flow_preds.shape[0]
    weights = gamma ** torch.arange(num_predictions - 1, -1, -1).to(flow_gt.device)
    flow_loss = (abs_diff * weights).sum()

    return flow_loss
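

# Editor's example, not part of the committed file: the weights computed in
# sequence_loss decay toward earlier (less refined) predictions, so the final
# prediction contributes most. With gamma=0.8 and three predictions:
def _example_sequence_weights():
    return 0.8 ** torch.arange(2, -1, -1)  # tensor([0.6400, 0.8000, 1.0000])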


class InputPadder:
    """Pads images such that dimensions are divisible by 8"""

    # TODO: Ideally, this should be part of the eval transforms preset, instead
    # of being part of the validation code. It's not obvious what a good
    # solution would be, because we need to unpad the predicted flows according
    # to the input images' size, and in some datasets (Kitti) images can have
    # variable sizes.

    def __init__(self, dims, mode="sintel"):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == "sintel":
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode="replicate") for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0] : c[1], c[2] : c[3]]
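

# Editor's shape check, not part of the committed file: Sintel frames are
# 436x1024, so only the height needs padding to the next multiple of 8 (440),
# split evenly between top and bottom in "sintel" mode.
def _example_input_padder():
    padder = InputPadder((1, 3, 436, 1024), mode="sintel")
    (x,) = padder.pad(torch.zeros(1, 3, 436, 1024))
    assert x.shape == (1, 3, 440, 1024)
    assert padder.unpad(x).shape == (1, 3, 436, 1024)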


def _redefine_print(is_main):
    """disables printing when not in main process"""
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_main or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def setup_ddp(args):
    # Set the local_rank, rank, and world_size values as args fields
    # This is done differently depending on how we're running the script. We
    # currently support either torchrun or the custom run_with_submitit.py
    # If you're confused (like I was), this might help a bit
    # https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
    if all(key in os.environ for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")):
        # if we're here, the script was called with torchrun. Otherwise
        # these args will be set already by the run_with_submitit script
        args.local_rank = int(os.environ["LOCAL_RANK"])
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
    elif "gpu" in args:
        # if we're here, the script was called by run_with_submitit.py
        args.local_rank = args.gpu
    else:
        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")

    _redefine_print(is_main=(args.rank == 0))

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend="nccl",
        rank=args.rank,
        world_size=args.world_size,
        init_method=args.dist_url,
    )


def reduce_across_processes(val):
    t = torch.tensor(val, device="cuda")

    dist.barrier()
    dist.all_reduce(t)
    return t


def freeze_batch_norm(model):
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.eval()
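

# Editor's example, not part of the committed file: after freeze_batch_norm,
# BatchNorm2d layers stay in eval mode (using their running statistics) even
# though the surrounding model remains in training mode.
def _example_freeze_batch_norm():
    m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    m.train()
    freeze_batch_norm(m)
    return m[0].training, m[1].training  # (True, False)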