OpenDAS / pytorch3d · Commits · cdd2142d

Unverified commit cdd2142d, authored Mar 21, 2022 by Jeremy Reizenstein, committed by GitHub on Mar 21, 2022.

implicitron v0 (#1133)

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>

parent 0e377c68

Changes: 90 files. Showing 20 changed files with 5184 additions and 0 deletions (+5184, -0).
projects/implicitron_trainer/configs/repro_singleseq_srn.yaml  (+28, -0)
projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml  (+10, -0)
projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml  (+29, -0)
projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml  (+10, -0)
projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml  (+18, -0)
projects/implicitron_trainer/experiment.py  (+714, -0)
projects/implicitron_trainer/visualize_reconstruction.py  (+382, -0)
pytorch3d/implicitron/dataset/dataloader_zoo.py  (+97, -0)
pytorch3d/implicitron/dataset/dataset_zoo.py  (+260, -0)
pytorch3d/implicitron/dataset/implicitron_dataset.py  (+988, -0)
pytorch3d/implicitron/dataset/scene_batch_sampler.py  (+203, -0)
pytorch3d/implicitron/dataset/types.py  (+331, -0)
pytorch3d/implicitron/dataset/utils.py  (+44, -0)
pytorch3d/implicitron/dataset/visualize.py  (+95, -0)
pytorch3d/implicitron/eval_demo.py  (+216, -0)
pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py  (+649, -0)
pytorch3d/implicitron/models/autodecoder.py  (+172, -0)
pytorch3d/implicitron/models/base.py  (+883, -0)
pytorch3d/implicitron/models/implicit_function/__init__.py  (+5, -0)
pytorch3d/implicitron/models/implicit_function/base.py  (+50, -0)
projects/implicitron_trainer/configs/repro_singleseq_srn.yaml (new file, mode 100644)

defaults:
- repro_singleseq_base.yaml
- _self_
generic_model_args:
  num_passes: 1
  chunk_size_grid: 32000
  view_pool: false
  loss_weights:
    loss_rgb_mse: 200.0
    loss_prev_stage_rgb_mse: 0.0
    loss_mask_bce: 1.0
    loss_prev_stage_mask_bce: 0.0
    loss_autodecoder_norm: 0.0
    depth_neg_penalty: 10000.0
  raysampler_args:
    n_rays_per_image_sampled_from_mask: 2048
    min_depth: 0.05
    max_depth: 0.05
    scene_extent: 0.0
    n_pts_per_ray_training: 1
    n_pts_per_ray_evaluation: 1
    stratified_point_sampling_training: false
    stratified_point_sampling_evaluation: false
  renderer_class_type: LSTMRenderer
  implicit_function_class_type: SRNImplicitFunction
solver_args:
  breed: adam
  lr: 5.0e-05
projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml (new file, mode 100644)

defaults:
- repro_singleseq_srn.yaml
- _self_
generic_model_args:
  num_passes: 1
  implicit_function_SRNImplicitFunction_args:
    pixel_generator_args:
      n_harmonic_functions: 0
    raymarch_function_args:
      n_harmonic_functions: 0
projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml (new file, mode 100644)

defaults:
- repro_singleseq_wce_base
- repro_feat_extractor_normed.yaml
- _self_
generic_model_args:
  num_passes: 1
  chunk_size_grid: 32000
  view_pool: true
  loss_weights:
    loss_rgb_mse: 200.0
    loss_prev_stage_rgb_mse: 0.0
    loss_mask_bce: 1.0
    loss_prev_stage_mask_bce: 0.0
    loss_autodecoder_norm: 0.0
    depth_neg_penalty: 10000.0
  raysampler_args:
    n_rays_per_image_sampled_from_mask: 2048
    min_depth: 0.05
    max_depth: 0.05
    scene_extent: 0.0
    n_pts_per_ray_training: 1
    n_pts_per_ray_evaluation: 1
    stratified_point_sampling_training: false
    stratified_point_sampling_evaluation: false
  renderer_class_type: LSTMRenderer
  implicit_function_class_type: SRNImplicitFunction
solver_args:
  breed: adam
  lr: 5.0e-05
projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml (new file, mode 100644)

defaults:
- repro_singleseq_srn_wce.yaml
- _self_
generic_model_args:
  num_passes: 1
  implicit_function_SRNImplicitFunction_args:
    pixel_generator_args:
      n_harmonic_functions: 0
    raymarch_function_args:
      n_harmonic_functions: 0
projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml (new file, mode 100644)

defaults:
- repro_singleseq_base
- _self_
dataloader_args:
  batch_size: 10
  dataset_len: 1000
  dataset_len_val: 1
  num_workers: 8
  images_per_seq_options:
  - 2
  - 3
  - 4
  - 5
  - 6
  - 7
  - 8
  - 9
  - 10
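
These configs compose through Hydra's `defaults` lists: each file layers its overrides on top of `repro_singleseq_base.yaml` (and, for the wce variants, the feature-extractor config). The snippet below is an illustrative sketch only, not part of the commit; it approximates the layering with `OmegaConf.merge` on made-up inline dictionaries so the override behaviour is easy to see.

# Illustrative only: mimics how a child config such as repro_singleseq_srn_noharm.yaml
# layers on top of repro_singleseq_srn.yaml. The inline dicts are stand-ins, not the
# real config contents.
from omegaconf import OmegaConf

base = OmegaConf.create(
    {
        "generic_model_args": {
            "num_passes": 1,
            "implicit_function_SRNImplicitFunction_args": {
                "pixel_generator_args": {"n_harmonic_functions": 10}
            },
        }
    }
)
child = OmegaConf.create(
    {
        "generic_model_args": {
            "implicit_function_SRNImplicitFunction_args": {
                "pixel_generator_args": {"n_harmonic_functions": 0}
            }
        }
    }
)
merged = OmegaConf.merge(base, child)  # later configs win; untouched keys are kept
print(OmegaConf.to_yaml(merged))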
projects/implicitron_trainer/experiment.py (new file, mode 100755)

#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This file is the entry point for launching experiments with Implicitron.

Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
    and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
    pass, visualizations and metric printing

Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:

```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```

To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.

Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
  - Checkpoints:
        Model, optimizer and stats are stored in the directory
        named by the `exp_dir` key from the config file / CLI parameters.
  - Stats
        Stats are logged and plotted to the file "train_stats.pdf" in the
        same directory. The stats are also saved as part of the checkpoint file.
  - Visualizations
        Predictions are plotted to a visdom server running at the
        port specified by the `visdom_server` and `visdom_port` keys in the
        config file.
"""

import copy
import json
import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import hydra
import lpips
import numpy as np
import torch
import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
    ImplicitronDataset,
    FrameData,
)
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    get_default_args_field,
    remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase

logger = logging.getLogger(__name__)

if version.parse(hydra.__version__) < version.Version("1.1"):
    raise ValueError(
        f"Hydra version {hydra.__version__} is too old."
        " (Implicitron requires version 1.1 or later.)"
    )

try:
    # only makes sense in FAIR cluster
    import pytorch3d.implicitron.fair_cluster.slurm  # noqa: F401
except ModuleNotFoundError:
    pass


def init_model(
    cfg: DictConfig,
    force_load: bool = False,
    clear_stats: bool = False,
    load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
    """
    Returns an instance of `GenericModel`.

    If `cfg.resume` is set or `force_load` is true,
    attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so
    will return the model with initial weights, unless `force_load` is passed,
    in which case a FileNotFoundError is raised.

    Args:
        force_load: If true, force load model from checkpoint even if
            cfg.resume is false.
        clear_stats: If true, clear the stats object loaded from checkpoint
        load_model_only: If true, load only the model weights from checkpoint
            and do not load the state of the optimizer and stats.

    Returns:
        model: The model with optionally loaded weights from checkpoint
        stats: The stats structure (optionally loaded from checkpoint)
        optimizer_state: The optimizer state dict containing
            `state` and `param_groups` keys (optionally loaded from checkpoint)

    Raise:
        FileNotFoundError if `force_load` is passed but checkpoint is not found.
    """

    # Initialize the model
    if cfg.architecture == "generic":
        model = GenericModel(**cfg.generic_model_args)
    else:
        raise ValueError(f"No such arch {cfg.architecture}.")

    # Determine the network outputs that should be logged
    if hasattr(model, "log_vars"):
        log_vars = copy.deepcopy(list(model.log_vars))
    else:
        log_vars = ["objective"]

    visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"

    # Init the stats struct
    stats = Stats(
        log_vars,
        visdom_env=visdom_env_charts,
        verbose=False,
        visdom_server=cfg.visdom_server,
        visdom_port=cfg.visdom_port,
    )

    # Retrieve the last checkpoint
    if cfg.resume_epoch > 0:
        model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
    else:
        model_path = model_io.find_last_checkpoint(cfg.exp_dir)

    optimizer_state = None
    if model_path is not None:
        logger.info("found previous model %s" % model_path)
        if force_load or cfg.resume:
            logger.info("  -> resuming")
            if load_model_only:
                model_state_dict = torch.load(model_io.get_model_path(model_path))
                stats_load, optimizer_state = None, None
            else:
                model_state_dict, stats_load, optimizer_state = model_io.load_model(
                    model_path
                )

            # Determine if stats should be reset
            if not clear_stats:
                if stats_load is None:
                    logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    stats = stats_load

                # Update stats properties in case it was reset on load
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = cfg.visdom_server
                stats.visdom_port = cfg.visdom_port
                stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info("  -> clearing stats")

            try:
                # TODO: fix on creation of the buffers
                # after the hack above, this will not pass in most cases
                # ... but this is fine for now
                model.load_state_dict(model_state_dict, strict=True)
            except RuntimeError as e:
                logger.error(e)
                logger.info("Cant load state dict in strict mode! -> trying non-strict")
                model.load_state_dict(model_state_dict, strict=False)
            model.log_vars = log_vars
        else:
            logger.info("  -> but not resuming -> starting from scratch")
    elif force_load:
        raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")

    return model, stats, optimizer_state


def init_optimizer(
    model: GenericModel,
    optimizer_state: Optional[Dict[str, Any]],
    last_epoch: int,
    breed: str = "adam",
    weight_decay: float = 0.0,
    lr_policy: str = "multistep",
    lr: float = 0.0005,
    gamma: float = 0.1,
    momentum: float = 0.9,
    betas: Tuple[float] = (0.9, 0.999),
    milestones: tuple = (),
    max_epochs: int = 1000,
):
    """
    Initialize the optimizer (optionally from checkpoint state)
    and the learning rate scheduler.

    Args:
        model: The model with optionally loaded weights
        optimizer_state: The state dict for the optimizer. If None
            it has not been loaded from checkpoint
        last_epoch: If the model was loaded from checkpoint this will be the
            number of the last epoch that was saved
        breed: The type of optimizer to use e.g. adam
        weight_decay: The optimizer weight_decay (L2 penalty on model weights)
        lr_policy: The policy to use for learning rate. Currently, only "multistep"
            is supported.
        lr: The value for the initial learning rate
        gamma: Multiplicative factor of learning rate decay
        momentum: Momentum factor for SGD optimizer
        betas: Coefficients used for computing running averages of gradient and its square
            in the Adam optimizer
        milestones: List of increasing epoch indices at which the learning rate is
            modified
        max_epochs: The maximum number of epochs to run the optimizer for

    Returns:
        optimizer: Optimizer module, optionally loaded from checkpoint
        scheduler: Learning rate scheduler module

    Raise:
        ValueError if `breed` or `lr_policy` are not supported.
    """

    # Get the parameters to optimize
    if hasattr(model, "_get_param_groups"):  # use the model function
        p_groups = model._get_param_groups(lr, wd=weight_decay)
    else:
        allprm = [prm for prm in model.parameters() if prm.requires_grad]
        p_groups = [{"params": allprm, "lr": lr}]

    # Initialize the optimizer
    if breed == "sgd":
        optimizer = torch.optim.SGD(
            p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
    elif breed == "adagrad":
        optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
    elif breed == "adam":
        optimizer = torch.optim.Adam(
            p_groups, lr=lr, betas=betas, weight_decay=weight_decay
        )
    else:
        raise ValueError("no such solver type %s" % breed)
    logger.info("  -> solver type = %s" % breed)

    # Load state from checkpoint
    if optimizer_state is not None:
        logger.info("  -> setting loaded optimizer state")
        optimizer.load_state_dict(optimizer_state)

    # Initialize the learning rate scheduler
    if lr_policy == "multistep":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=gamma,
        )
    else:
        raise ValueError("no such lr policy %s" % lr_policy)

    # When loading from checkpoint, this will make sure that the
    # lr is correctly set even after returning
    for _ in range(last_epoch):
        scheduler.step()

    # Add the max epochs here
    scheduler.max_epochs = max_epochs

    optimizer.zero_grad()
    return optimizer, scheduler


def trainvalidate(
    model,
    stats,
    epoch,
    loader,
    optimizer,
    validation,
    bp_var: str = "objective",
    metric_print_interval: int = 5,
    visualize_interval: int = 100,
    visdom_env_root: str = "trainvalidate",
    clip_grad: float = 0.0,
    device: str = "cuda:0",
    **kwargs,
) -> None:
    """
    This is the main loop for training and evaluation including:
    model forward pass, loss computation, backward pass and visualization.

    Args:
        model: The model module optionally loaded from checkpoint
        stats: The stats struct, also optionally loaded from checkpoint
        epoch: The index of the current epoch
        loader: The dataloader to use for the loop
        optimizer: The optimizer module optionally loaded from checkpoint
        validation: If true, run the loop with the model in eval mode
            and skip the backward pass
        bp_var: The name of the key in the model output `preds` dict which
            should be used as the loss for the backward pass.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted
        visdom_env_root: The name of the visdom environment to use for plotting
        clip_grad: Optionally clip the gradient norms.
            If set to a value <=0.0, no clipping
        device: The device on which to run the model.

    Returns:
        None
    """
    if validation:
        model.eval()
        trainmode = "val"
    else:
        model.train()
        trainmode = "train"

    t_start = time.time()

    # get the visdom env name
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode
    viz = vis_utils.get_visdom_connection(
        server=stats.visdom_server,
        port=stats.visdom_port,
    )

    # Iterate through the batches
    n_batches = len(loader)
    for it, batch in enumerate(loader):
        last_iter = it == n_batches - 1

        # move to gpu where possible (in place)
        net_input = batch.to(device)

        # run the forward pass
        if not validation:
            optimizer.zero_grad()
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
        else:
            with torch.no_grad():
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                )

        # make sure we dont overwrite something
        assert all(k not in preds for k in net_input.keys())
        # merge everything into one big dict
        preds.update(net_input)

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update
        if it % metric_print_interval == 0 or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results
        if visualize_interval > 0 and it % visualize_interval == 0:
            prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
            model.visualize(
                viz,
                visdom_env_imgs,
                preds,
                prefix,
            )

        # optimizer step
        if not validation:
            loss = preds[bp_var]
            assert torch.isfinite(loss).all(), "Non-finite loss!"
            # backprop
            loss.backward()
            if clip_grad > 0.0:
                # Optionally clip the gradient norms.
                total_norm = torch.nn.utils.clip_grad_norm(
                    model.parameters(), clip_grad
                )
                if total_norm > clip_grad:
                    logger.info(
                        f"Clipping gradient: {total_norm}"
                        + f" with coef {clip_grad / total_norm}."
                    )

            optimizer.step()


def run_training(cfg: DictConfig, device: str = "cpu"):
    """
    Entry point to run the training and validation loops
    based on the specified config file.
    """

    # set the debug mode
    if cfg.detect_anomaly:
        logger.info("Anomaly detection!")
    torch.autograd.set_detect_anomaly(cfg.detect_anomaly)

    # create the output folder
    os.makedirs(cfg.exp_dir, exist_ok=True)
    _seed_all_random_engines(cfg.seed)
    remove_unused_components(cfg)

    # dump the exp config to the exp dir
    try:
        cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
        OmegaConf.save(config=cfg, f=cfg_filename)
    except PermissionError:
        warnings.warn("Cant dump config due to insufficient permissions!")

    # setup datasets
    datasets = dataset_zoo(**cfg.dataset_args)
    cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
    dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)

    # init the model
    model, stats, optimizer_state = init_model(cfg)
    start_epoch = stats.epoch + 1

    # move model to gpu
    model.to(device)

    # only run evaluation on the test dataloader
    if cfg.eval_only:
        _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
        return

    # init the optimizer
    optimizer, scheduler = init_optimizer(
        model,
        optimizer_state=optimizer_state,
        last_epoch=start_epoch,
        **cfg.solver_args,
    )

    # check the scheduler and stats have been initialized correctly
    assert scheduler.last_epoch == stats.epoch + 1
    assert scheduler.last_epoch == start_epoch

    past_scheduler_lrs = []
    # loop through epochs
    for epoch in range(start_epoch, cfg.solver_args.max_epochs):
        # automatic new_epoch and plotting of stats at every epoch start
        with stats:

            # Make sure to re-seed random generators to ensure reproducibility
            # even after restart.
            _seed_all_random_engines(cfg.seed + epoch)

            cur_lr = float(scheduler.get_last_lr()[-1])
            logger.info(f"scheduler lr = {cur_lr:1.2e}")
            past_scheduler_lrs.append(cur_lr)

            # train loop
            trainvalidate(
                model,
                stats,
                epoch,
                dataloaders["train"],
                optimizer,
                False,
                visdom_env_root=vis_utils.get_visdom_env(cfg),
                device=device,
                **cfg,
            )

            # val loop (optional)
            if "val" in dataloaders and epoch % cfg.validation_interval == 0:
                trainvalidate(
                    model,
                    stats,
                    epoch,
                    dataloaders["val"],
                    optimizer,
                    True,
                    visdom_env_root=vis_utils.get_visdom_env(cfg),
                    device=device,
                    **cfg,
                )

            # eval loop (optional)
            if (
                "test" in dataloaders
                and cfg.test_interval > 0
                and epoch % cfg.test_interval == 0
            ):
                run_eval(cfg, model, stats, dataloaders["test"], device=device)

            assert stats.epoch == epoch, "inconsistent stats!"

            # delete previous models if required
            # save model
            if cfg.store_checkpoints:
                if cfg.store_checkpoints_purge > 0:
                    for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
                        model_io.purge_epoch(cfg.exp_dir, prev_epoch)
                outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
                model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)

            scheduler.step()
            new_lr = float(scheduler.get_last_lr()[-1])
            if new_lr != cur_lr:
                logger.info(f"LR change! {cur_lr} -> {new_lr}")

    if cfg.test_when_finished:
        _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)


def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
    """
    Run the evaluation loop with the test data loader and
    save the predictions to the `exp_dir`.
    """

    if "test" not in dataloaders:
        raise ValueError('Dataloaders have to contain the "test" entry for eval!')

    eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
    all_source_cameras = (
        _get_all_source_cameras(datasets["train"])
        if eval_task == "singlesequence"
        else None
    )

    results = run_eval(
        cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
    )

    # add the evaluation epoch to the results
    for r in results:
        r["eval_epoch"] = int(stats.epoch)

    logger.info("Evaluation results")
    evaluate.pretty_print_nvs_metrics(results)

    with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
        json.dump(results, f)


def _get_eval_frame_data(frame_data):
    """
    Masks the unknown image data to make sure we cannot use it at model evaluation time.
    """
    frame_data_for_eval = copy.deepcopy(frame_data)
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
        frame_data.image_rgb
    )[:, None, None, None]
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
        value_masked = getattr(frame_data_for_eval, k).clone() * is_known
        setattr(frame_data_for_eval, k, value_masked)
    return frame_data_for_eval


def run_eval(cfg, model, all_source_cameras, loader, task, device):
    """
    Run the evaluation loop on the test dataloader
    """
    lpips_model = lpips.LPIPS(net="vgg")
    lpips_model = lpips_model.to(device)

    model.eval()

    per_batch_eval_results = []
    logger.info("Evaluating model ...")
    for frame_data in tqdm.tqdm(loader):
        frame_data = frame_data.to(device)

        # mask out the unknown images so that the model does not see them
        frame_data_for_eval = _get_eval_frame_data(frame_data)

        with torch.no_grad():
            preds = model(
                **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
            )
            nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
            per_batch_eval_results.append(
                evaluate.eval_batch(
                    frame_data,
                    nvs_prediction,
                    bg_color="black",
                    lpips_model=lpips_model,
                    source_cameras=all_source_cameras,
                )
            )

    _, category_result = evaluate.summarize_nvs_eval_results(
        per_batch_eval_results, task
    )

    return category_result["results"]


def _get_all_source_cameras(
    dataset: ImplicitronDataset,
    num_workers: int = 8,
) -> CamerasBase:
    """
    Load and return all the source cameras in the training dataset
    """

    all_frame_data = next(
        iter(
            torch.utils.data.DataLoader(
                dataset,
                shuffle=False,
                batch_size=len(dataset),
                num_workers=num_workers,
                collate_fn=FrameData.collate,
            )
        )
    )

    is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
    source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
    return source_cameras


def _seed_all_random_engines(seed: int):
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


@dataclass(eq=False)
class ExperimentConfig:
    generic_model_args: DictConfig = get_default_args_field(GenericModel)
    solver_args: DictConfig = get_default_args_field(init_optimizer)
    dataset_args: DictConfig = get_default_args_field(dataset_zoo)
    dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
    architecture: str = "generic"
    detect_anomaly: bool = False
    eval_only: bool = False
    exp_dir: str = "./data/default_experiment/"
    exp_idx: int = 0
    gpu_idx: int = 0
    metric_print_interval: int = 5
    resume: bool = True
    resume_epoch: int = -1
    seed: int = 0
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1
    visdom_env: str = ""
    visdom_port: int = 8097
    visdom_server: str = "http://127.0.0.1"
    visualize_interval: int = 1000
    clip_grad: float = 0.0

    hydra: dict = field(
        default_factory=lambda: {
            "run": {"dir": "."},  # Make hydra not change the working dir.
            "output_subdir": None,  # disable storing the .hydra logs
        }
    )


cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)


@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)

    # Set the device
    device = "cpu"
    if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
        device = f"cuda:{cfg.gpu_idx}"
    logger.info(f"Running experiment on device: {device}")
    run_training(cfg, device)


if __name__ == "__main__":
    experiment()
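
One detail worth noting in `init_optimizer` above is how resuming restores the learning-rate schedule: after building the `MultiStepLR` scheduler it is simply stepped `last_epoch` times, so the LR picks up where the checkpointed run left off. The following standalone sketch (not part of the commit; a toy `torch.nn.Linear` stands in for `GenericModel`, and the epoch number is made up) reproduces that behaviour with plain PyTorch.

# Minimal sketch: replay scheduler steps after a resume so the restored LR
# matches the saved epoch, mirroring the loop in init_optimizer.
import torch

model = torch.nn.Linear(4, 4)  # toy stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=5.0e-05)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[200, 300], gamma=0.1
)

last_epoch = 250  # pretend the checkpoint was saved at epoch 250
for _ in range(last_epoch):
    scheduler.step()

# After replaying past milestone 200, the LR has decayed once: 5.0e-06.
print(scheduler.get_last_lr()[-1])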
projects/implicitron_trainer/visualize_reconstruction.py (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Script to visualize a previously trained model. Example call:

    projects/implicitron_trainer/visualize_reconstruction.py
    exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
    n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""

import math
import os
import random
import sys
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
    FrameData,
    ImplicitronDataset,
)
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.tools.configurable import get_default_args
from pytorch3d.implicitron.tools.eval_video_trajectory import (
    generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
    get_visdom_connection,
    make_depth_image,
)
from tqdm import tqdm


def render_sequence(
    dataset: ImplicitronDataset,
    sequence_name: str,
    model: torch.nn.Module,
    video_path,
    n_eval_cameras=40,
    fps=20,
    max_angle=2 * math.pi,
    trajectory_type="circular_lsq_fit",
    trajectory_scale=1.1,
    scene_center=(0.0, 0.0, 0.0),
    up=(0.0, -1.0, 0.0),
    traj_offset=0.0,
    n_source_views=9,
    viz_env="debug",
    visdom_show_preds=False,
    visdom_server="http://127.0.0.1",
    visdom_port=8097,
    num_workers=10,
    seed=None,
    video_resize=None,
):
    if seed is None:
        seed = hash(sequence_name)
    print(f"Loading all data of sequence '{sequence_name}'.")
    seq_idx = dataset.seq_to_idx[sequence_name]
    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
    print(f"Sequence set = {sequence_set_name}.")
    train_cameras = train_data.camera
    time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
    test_cameras = generate_eval_video_cameras(
        train_cameras,
        time=time,
        n_eval_cams=n_eval_cameras,
        trajectory_type=trajectory_type,
        trajectory_scale=trajectory_scale,
        scene_center=scene_center,
        up=up,
        focal_length=None,
        principal_point=torch.zeros(n_eval_cameras, 2),
        traj_offset_canonical=[0.0, 0.0, traj_offset],
    )

    # sample the source views reproducibly
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
    # add the first dummy view that will get replaced with the target camera
    source_views_i = Fu.pad(source_views_i, [1, 0])
    source_views = [seq_idx[i] for i in source_views_i.tolist()]
    batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
    assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)

    preds_total = []
    for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
        # set the first batch camera to the target camera
        for k in ("R", "T", "focal_length", "principal_point"):
            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)

        # Move to cuda
        net_input = batch.cuda()
        with torch.no_grad():
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})

            # make sure we dont overwrite something
            assert all(k not in preds for k in net_input.keys())
            preds.update(net_input)  # merge everything into one big dict

            # Render the predictions to images
            rendered_pred = images_from_preds(preds)
            preds_total.append(rendered_pred)

            # show the preds every 5% of the export iterations
            if visdom_show_preds and (
                n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
            ):
                viz = get_visdom_connection(server=visdom_server, port=visdom_port)
                show_predictions(
                    preds_total,
                    sequence_name=batch.sequence_name[0],
                    viz=viz,
                    viz_env=viz_env,
                )

    print(f"Exporting videos for sequence {sequence_name} ...")
    generate_prediction_videos(
        preds_total,
        sequence_name=batch.sequence_name[0],
        viz=viz,
        viz_env=viz_env,
        fps=fps,
        video_path=video_path,
        resize=video_resize,
    )


def _load_whole_dataset(dataset, idx, num_workers=10):
    load_all_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, idx),
        batch_size=len(idx),
        num_workers=num_workers,
        shuffle=False,
        collate_fn=FrameData.collate,
    )
    return next(iter(load_all_dataloader))


def images_from_preds(preds):
    imout = {}
    for k in (
        "image_rgb",
        "images_render",
        "fg_probability",
        "masks_render",
        "depths_render",
        "depth_map",
        "_all_source_images",
    ):
        if k == "_all_source_images" and "image_rgb" in preds:
            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
            v = _stack_images(src_ims, None)[None]
        else:
            if k not in preds or preds[k] is None:
                print(f"cant show {k}")
                continue
            v = preds[k].cpu().detach().clone()
        if k.startswith("depth"):
            mask_resize = Fu.interpolate(
                preds["masks_render"],
                size=preds[k].shape[2:],
                mode="nearest",
            )
            v = make_depth_image(preds[k], mask_resize)
        if v.shape[1] == 1:
            v = v.repeat(1, 3, 1, 1)
        imout[k] = v.detach().cpu()

    return imout


def _stack_images(ims, size):
    ba = ims.shape[0]
    H = int(np.ceil(np.sqrt(ba)))
    W = H
    n_add = H * W - ba
    if n_add > 0:
        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))

    ims = ims.view(H, W, *ims.shape[1:])
    cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
    if size is not None:
        cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
    return cated.clamp(0.0, 1.0)


def show_predictions(
    preds,
    sequence_name,
    viz,
    viz_env="visualizer",
    predicted_keys=(
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    n_samples=10,
    one_image_width=200,
):
    """Given a list of predictions visualize them into a single image using visdom."""
    assert isinstance(preds, list)

    pred_all = []
    # Randomly choose a subset of the rendered images, sort by order in the sequence
    n_samples = min(n_samples, len(preds))
    pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
    for predi in pred_idx:
        # Make the concatenation for the same camera vertically
        pred_all.append(
            torch.cat(
                [
                    torch.nn.functional.interpolate(
                        preds[predi][k].cpu(),
                        scale_factor=one_image_width / preds[predi][k].shape[3],
                        mode="bilinear",
                    ).clamp(0.0, 1.0)
                    for k in predicted_keys
                ],
                dim=2,
            )
        )
    # Concatenate the images horizontally
    pred_all_cat = torch.cat(pred_all, dim=3)[0]
    viz.image(
        pred_all_cat,
        win="show_predictions",
        env=viz_env,
        opts={"title": f"pred_{sequence_name}"},
    )


def generate_prediction_videos(
    preds,
    sequence_name,
    viz,
    viz_env="visualizer",
    predicted_keys=(
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    fps=20,
    video_path="/tmp/video",
    resize=None,
):
    """Given a list of predictions create and visualize rotating videos of the
    objects using visdom.
    """
    assert isinstance(preds, list)

    # make sure the target video directory exists
    os.makedirs(os.path.dirname(video_path), exist_ok=True)

    # init a video writer for each predicted key
    vws = {}
    for k in predicted_keys:
        vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)

    for rendered_pred in tqdm(preds):
        for k in predicted_keys:
            vws[k].write_frame(
                rendered_pred[k][0].detach().cpu().numpy(),
                resize=resize,
            )

    for k in predicted_keys:
        vws[k].get_video(quiet=True)
        print(f"Generated {vws[k].out_path}.")
        viz.video(
            videofile=vws[k].out_path,
            env=viz_env,
            win=k,  # we reuse the same window otherwise visdom dies
            opts={"title": sequence_name + " " + k},
        )


def export_scenes(
    exp_dir: str = "",
    restrict_sequence_name: Optional[str] = None,
    output_directory: Optional[str] = None,
    render_size: Tuple[int, int] = (512, 512),
    video_size: Optional[Tuple[int, int]] = None,
    split: str = "train",  # train | test
    n_source_views: int = 9,
    n_eval_cameras: int = 40,
    visdom_server="http://127.0.0.1",
    visdom_port=8097,
    visdom_show_preds: bool = False,
    visdom_env: Optional[str] = None,
    gpu_idx: int = 0,
):
    # In case an output directory is specified use it. If no output_directory
    # is specified create a vis folder inside the experiment directory
    if output_directory is None:
        output_directory = os.path.join(exp_dir, "vis")
    else:
        output_directory = output_directory
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # Set the random seeds
    torch.manual_seed(0)
    np.random.seed(0)

    # Get the config from the experiment_directory,
    # and overwrite relevant fields
    config = _get_config_from_experiment_directory(exp_dir)
    config.gpu_idx = gpu_idx
    config.exp_dir = exp_dir
    # important so that the CO3D dataset gets loaded in full
    config.dataset_args.test_on_train = False
    # Set the rendering image size
    config.generic_model_args.render_image_width = render_size[0]
    config.generic_model_args.render_image_height = render_size[1]
    if restrict_sequence_name is not None:
        config.dataset_args.restrict_sequence_name = restrict_sequence_name

    # Set up the CUDA env for the visualization
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)

    # Load the previously trained model
    model, _, _ = init_model(config, force_load=True, load_model_only=True)
    model.cuda()
    model.eval()

    # Setup the dataset
    dataset = dataset_zoo(**config.dataset_args)[split]

    # iterate over the sequences in the dataset
    for sequence_name in dataset.seq_to_idx.keys():
        with torch.no_grad():
            render_sequence(
                dataset,
                sequence_name,
                model,
                video_path="{}/video".format(output_directory),
                n_source_views=n_source_views,
                visdom_show_preds=visdom_show_preds,
                n_eval_cameras=n_eval_cameras,
                visdom_server=visdom_server,
                visdom_port=visdom_port,
                viz_env=f"visualizer_{config.visdom_env}"
                if visdom_env is None
                else visdom_env,
                video_resize=video_size,
            )


def _get_config_from_experiment_directory(experiment_directory):
    cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
    config = OmegaConf.load(cfg_file)
    return config


def main(argv):
    # automatically parses arguments of export_scenes
    cfg = OmegaConf.create(get_default_args(export_scenes))
    cfg.update(OmegaConf.from_cli())
    with torch.no_grad():
        export_scenes(**cfg)


if __name__ == "__main__":
    main(sys.argv)
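
The `_stack_images` helper above tiles the source views into a near-square montage before display. The standalone sketch below is illustrative only (the `stack_into_grid` name and the random input tensor are made up for the example); it shows the same padding-and-tiling idea in isolation.

# Standalone sketch of the square-montage idea behind _stack_images: pad a batch
# of images up to the next square count, then tile them into a side x side grid.
import math

import torch


def stack_into_grid(ims: torch.Tensor) -> torch.Tensor:
    # ims: (B, C, H, W) with values in [0, 1]
    ba = ims.shape[0]
    side = int(math.ceil(math.sqrt(ba)))
    n_add = side * side - ba
    if n_add > 0:
        # pad with black frames so the grid is exactly side * side images
        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
    ims = ims.view(side, side, *ims.shape[1:])
    rows = [torch.cat(list(row), dim=2) for row in ims]  # concatenate along width
    return torch.cat(rows, dim=1).clamp(0.0, 1.0)  # stack rows along height


grid = stack_into_grid(torch.rand(7, 3, 8, 8))
print(grid.shape)  # torch.Size([3, 24, 24]): a 3 x 3 grid of 8 x 8 tiles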
pytorch3d/implicitron/dataset/dataloader_zoo.py (new file, mode 100644)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Sequence

import torch

from .implicitron_dataset import FrameData, ImplicitronDatasetBase
from .scene_batch_sampler import SceneBatchSampler


def dataloader_zoo(
    datasets: Dict[str, ImplicitronDatasetBase],
    dataset_name: str = "co3d_singlesequence",
    batch_size: int = 1,
    num_workers: int = 0,
    dataset_len: int = 1000,
    dataset_len_val: int = 1,
    images_per_seq_options: Sequence[int] = (2,),
    sample_consecutive_frames: bool = False,
    consecutive_frames_max_gap: int = 0,
    consecutive_frames_max_gap_seconds: float = 0.1,
) -> Dict[str, torch.utils.data.DataLoader]:
    """
    Returns a set of dataloaders for a given set of datasets.

    Args:
        datasets: A dictionary containing the
            `"dataset_subset_name": torch_dataset_object` key, value pairs.
        dataset_name: The name of the returned dataset.
        batch_size: The size of the batch of the dataloader.
        num_workers: Number of data-loading threads.
        dataset_len: The number of batches in a training epoch.
        dataset_len_val: The number of batches in a validation epoch.
        images_per_seq_options: Possible numbers of images sampled per sequence.
        sample_consecutive_frames: if True, will sample a contiguous interval of frames
            in the sequence. It first sorts the frames by timestamps when available,
            otherwise by frame numbers, finds the connected segments within the sequence
            of sufficient length, then samples a random pivot element among them and
            ideally uses it as a middle of the temporal window, shifting the borders
            where necessary. This strategy mitigates the bias against shorter segments
            and their boundaries.
        consecutive_frames_max_gap: if a number > 0, then used to define the maximum
            difference in frame_number of neighbouring frames when forming connected
            segments; if both this and consecutive_frames_max_gap_seconds are 0s,
            the whole sequence is considered a segment regardless of frame numbers.
        consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
            maximum difference in frame_timestamp of neighbouring frames when forming
            connected segments; if both this and consecutive_frames_max_gap are 0s,
            the whole sequence is considered a segment regardless of frame timestamps.

    Returns:
        dataloaders: A dictionary containing the
            `"dataset_subset_name": torch_dataloader_object` key, value pairs.
    """

    if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    dataloaders = {}

    if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
        for dataset_set, dataset in datasets.items():
            num_samples = {
                "train": dataset_len,
                "val": dataset_len_val,
                "test": None,
            }[dataset_set]

            if dataset_set == "test":
                batch_sampler = dataset.get_eval_batches()
            else:
                assert num_samples is not None
                num_samples = len(dataset) if num_samples <= 0 else num_samples
                batch_sampler = SceneBatchSampler(
                    dataset,
                    batch_size,
                    num_batches=num_samples,
                    images_per_seq_options=images_per_seq_options,
                    sample_consecutive_frames=sample_consecutive_frames,
                    consecutive_frames_max_gap=consecutive_frames_max_gap,
                )

            dataloaders[dataset_set] = torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                collate_fn=FrameData.collate,
            )
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    return dataloaders
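
`dataloader_zoo` builds each `DataLoader` around a `batch_sampler` (either `SceneBatchSampler` or the precomputed eval batches) together with `FrameData.collate`. The toy sketch below is not CO3D; the dataset, collate function, and index lists are made up purely to show the same wiring, with a plain list of index lists standing in for the sampler.

# Toy sketch of the batch_sampler + collate_fn wiring used above.
import torch


class ToyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, i):
        return {"idx": torch.tensor(i)}


def collate(batch):
    # stand-in for FrameData.collate: stack the per-field tensors of a batch
    return {"idx": torch.stack([b["idx"] for b in batch])}


# each inner list is one batch, like the eval batches returned by get_eval_batches()
batch_sampler = [[0, 3, 5], [1, 2], [7, 8, 9]]
loader = torch.utils.data.DataLoader(
    ToyDataset(), batch_sampler=batch_sampler, collate_fn=collate
)
for batch in loader:
    print(batch["idx"])  # tensor([0, 3, 5]), tensor([1, 2]), tensor([7, 8, 9])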
pytorch3d/implicitron/dataset/dataset_zoo.py (new file, mode 100644)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import json
import os
from typing import Any, Dict, List, Optional, Sequence

from iopath.common.file_io import PathManager

from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
from .utils import (
    DATASET_TYPE_KNOWN,
    DATASET_TYPE_TEST,
    DATASET_TYPE_TRAIN,
    DATASET_TYPE_UNKNOWN,
)

# TODO from dataset.dataset_configs import DATASET_CONFIGS
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
    "default": {
        "box_crop": True,
        "box_crop_context": 0.3,
        "image_width": 800,
        "image_height": 800,
        "remove_empty_masks": True,
    }
}

# fmt: off
CO3D_CATEGORIES: List[str] = list(reversed([
    "baseballbat", "banana", "bicycle", "microwave", "tv",
    "cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
    "umbrella", "wineglass", "laptop",
    "hotdog", "stopsign", "frisbee", "baseballglove",
    "cup", "parkingmeter", "backpack", "toyplane", "toybus",
    "handbag", "chair", "keyboard", "car", "motorcycle",
    "carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
    "toaster", "mouse", "toytrain", "book", "toytruck",
    "orange", "broccoli", "plant", "teddybear",
    "suitcase", "bench", "ball", "cake",
    "vase", "hydrant", "apple", "donut",
]))
# fmt: on

_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")


def dataset_zoo(
    dataset_name: str = "co3d_singlesequence",
    dataset_root: str = _CO3D_DATASET_ROOT,
    category: str = "DEFAULT",
    limit_to: int = -1,
    limit_sequences_to: int = -1,
    n_frames_per_sequence: int = -1,
    test_on_train: bool = False,
    load_point_clouds: bool = False,
    mask_images: bool = False,
    mask_depths: bool = False,
    restrict_sequence_name: Sequence[str] = (),
    test_restrict_sequence_id: int = -1,
    assert_single_seq: bool = False,
    only_test_set: bool = False,
    aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
    path_manager: Optional[PathManager] = None,
) -> Dict[str, ImplicitronDatasetBase]:
    """
    Generates the training / validation and testing dataset objects.

    Args:
        dataset_name: The name of the returned dataset.
        dataset_root: The root folder of the dataset.
        category: The object category of the dataset.
        limit_to: Limit the dataset to the first #limit_to frames.
        limit_sequences_to: Limit the dataset to the first
            #limit_sequences_to sequences.
        n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
            in each sequence.
        test_on_train: Construct validation and test datasets from
            the training subset.
        load_point_clouds: Enable returning scene point clouds from the dataset.
        mask_images: Mask the loaded images with segmentation masks.
        mask_depths: Mask the loaded depths with segmentation masks.
        restrict_sequence_name: Restrict the dataset sequences to the ones
            present in the given list of names.
        test_restrict_sequence_id: The ID of the loaded sequence.
            Active for dataset_name='co3d_singlesequence'.
        assert_single_seq: Assert that only frames from a single sequence
            are present in all generated datasets.
        only_test_set: Load only the test set.
        aux_dataset_kwargs: Specifies additional arguments to the
            ImplicitronDataset constructor call.

    Returns:
        datasets: A dictionary containing the
            `"dataset_subset_name": torch_dataset_object` key, value pairs.
    """

    datasets = {}

    # TODO:
    # - implement loading multiple categories

    if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
        # This maps the common names of the dataset subsets ("train"/"val"/"test")
        # to the names of the subsets in the CO3D dataset.
        set_names_mapping = _get_co3d_set_names_mapping(
            dataset_name,
            test_on_train,
            only_test_set,
        )

        # load the evaluation batches
        task = dataset_name.split("_")[-1]
        batch_indices_path = os.path.join(
            dataset_root,
            category,
            f"eval_batches_{task}.json",
        )
        if not os.path.isfile(batch_indices_path):
            # The batch indices file does not exist.
            # Most probably the user has not specified the root folder.
            raise ValueError("Please specify a correct dataset_root folder.")

        with open(batch_indices_path, "r") as f:
            eval_batch_index = json.load(f)

        if task == "singlesequence":
            assert (
                test_restrict_sequence_id is not None
                and test_restrict_sequence_id >= 0
            ), (
                "Please specify an integer id 'test_restrict_sequence_id'"
                + " of the sequence considered for 'singlesequence'"
                + " training and evaluation."
            )
            assert len(restrict_sequence_name) == 0, (
                "For the 'singlesequence' task, the restrict_sequence_name has"
                " to be unset while test_restrict_sequence_id has to be set to an"
                " integer defining the order of the evaluation sequence."
            )
            # a sort-stable set() equivalent:
            eval_batches_sequence_names = list(
                {b[0][0]: None for b in eval_batch_index}.keys()
            )
            eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
            eval_batch_index = [
                b for b in eval_batch_index if b[0][0] == eval_sequence_name
            ]
            # overwrite the restrict_sequence_name
            restrict_sequence_name = [eval_sequence_name]

        for dataset, subsets in set_names_mapping.items():
            frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
            assert os.path.isfile(frame_file)

            sequence_file = os.path.join(
                dataset_root, category, "sequence_annotations.jgz"
            )
            assert os.path.isfile(sequence_file)

            subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
            assert os.path.isfile(subset_lists_file)

            # TODO: maybe directly in param list
            params = {
                **copy.deepcopy(aux_dataset_kwargs),
                "frame_annotations_file": frame_file,
                "sequence_annotations_file": sequence_file,
                "subset_lists_file": subset_lists_file,
                "dataset_root": dataset_root,
                "limit_to": limit_to,
                "limit_sequences_to": limit_sequences_to,
                "n_frames_per_sequence": n_frames_per_sequence
                if dataset == "train"
                else -1,
                "subsets": subsets,
                "load_point_clouds": load_point_clouds,
                "mask_images": mask_images,
                "mask_depths": mask_depths,
                "pick_sequence": restrict_sequence_name,
                "path_manager": path_manager,
            }

            datasets[dataset] = ImplicitronDataset(**params)
            if dataset == "test":
                if len(restrict_sequence_name) > 0:
                    eval_batch_index = [
                        b for b in eval_batch_index if b[0][0] in restrict_sequence_name
                    ]

                datasets[dataset].eval_batches = datasets[
                    dataset
                ].seq_frame_index_to_dataset_index(eval_batch_index)

        if assert_single_seq:
            # check there's only one sequence in all datasets
            assert (
                len(
                    {
                        e["frame_annotation"].sequence_name
                        for dset in datasets.values()
                        for e in dset.frame_annots
                    }
                )
                <= 1
            ), "Multiple sequences loaded but expected one"

    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    if test_on_train:
        datasets["val"] = datasets["train"]
        datasets["test"] = datasets["train"]

    return datasets


def _get_co3d_set_names_mapping(
    dataset_name: str,
    test_on_train: bool,
    only_test: bool,
) -> Dict[str, List[str]]:
    """
    Returns the mapping of the common dataset subset names ("train"/"val"/"test")
    to the names of the corresponding subsets in the CO3D dataset
    ("test_known"/"test_unseen"/"train_known"/"train_unseen").
    """
    single_seq = dataset_name == "co3d_singlesequence"

    if only_test:
        set_names_mapping = {}
    else:
        set_names_mapping = {
            "train": [
                (DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN)
                + "_"
                + DATASET_TYPE_KNOWN
            ]
        }
    if not test_on_train:
        prefixes = [DATASET_TYPE_TEST]
        if not single_seq:
            prefixes.append(DATASET_TYPE_TRAIN)
        set_names_mapping.update(
            {
                dset: [
                    p + "_" + t
                    for p in prefixes
                    for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
                ]
                for dset in ["val", "test"]
            }
        )

    return set_names_mapping
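
The "sort-stable set() equivalent" used in `dataset_zoo` deserves a quick illustration: inserting keys into a dict keeps first-seen order while dropping duplicates, which is what makes `test_restrict_sequence_id` index the evaluation sequences deterministically. A tiny example with made-up data:

# Ordered de-duplication via dict keys (Python dicts preserve insertion order).
eval_batch_index = [[("seq_b", 3)], [("seq_a", 1)], [("seq_b", 7)]]  # made-up data
names = list({b[0][0]: None for b in eval_batch_index}.keys())
print(names)  # ['seq_b', 'seq_a'] -- order of first appearance, no duplicates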
pytorch3d/implicitron/dataset/implicitron_dataset.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
functools
import
gzip
import
hashlib
import
json
import
os
import
random
import
warnings
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
,
fields
from
itertools
import
islice
from
pathlib
import
Path
from
typing
import
(
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
TypedDict
,
Union
,
)
import
numpy
as
np
import
torch
from
iopath.common.file_io
import
PathManager
from
PIL
import
Image
from
pytorch3d.io
import
IO
from
pytorch3d.renderer.camera_utils
import
join_cameras_as_batch
from
pytorch3d.renderer.cameras
import
CamerasBase
,
PerspectiveCameras
from
pytorch3d.structures.pointclouds
import
Pointclouds
,
join_pointclouds_as_batch
from
.
import
types
@
dataclass
class
FrameData
:
"""
A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches of thereof;
in this documentation, the sizes of tensors refer to single frames;
add the first batch dimension for the collation result.
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
frame_timestamp: The time elapsed since the start of a sequence in sec.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
image_size_hw: The size of the image in pixels; (height, width) tuple.
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
are a result of zero-padding of the image after cropping around
the object bounding box; elements are floats in {0.0, 1.0}.
depth_path: The qualified path to the frame's depth map.
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
of the frame; values correspond to distances from the camera;
use `depth_mask` and `mask_crop` to filter for valid pixels.
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
depth map that are valid for evaluation, they have been checked for
consistency across views; elements are floats in {0.0, 1.0}.
mask_path: A qualified path to the foreground probability mask.
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
pixels belonging to the captured object; elements are floats
in [0, 1].
bbox_xywh: The bounding box capturing the object in the
format (x0, y0, width, height).
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
corrected for cropping if it happened.
camera_quality_score: The score proportional to the confidence of the
frame's camera estimation (the higher the more accurate).
point_cloud_quality_score: The score proportional to the accuracy of the
frame's sequence point cloud (the higher the more accurate).
sequence_point_cloud_path: The path to the sequence's point cloud.
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
point cloud corresponding to the frame's sequence. When the object
represents a batch of frames, point clouds may be deduplicated;
see `sequence_point_cloud_idx`.
sequence_point_cloud_idx: Integer indices mapping frame indices to the
corresponding point clouds in `sequence_point_cloud`; to get the
corresponding point cloud to `image_rgb[i]`, use
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
"""
frame_number
:
Optional
[
torch
.
LongTensor
]
frame_timestamp
:
Optional
[
torch
.
Tensor
]
sequence_name
:
Union
[
str
,
List
[
str
]]
sequence_category
:
Union
[
str
,
List
[
str
]]
image_size_hw
:
Optional
[
torch
.
Tensor
]
=
None
image_path
:
Union
[
str
,
List
[
str
],
None
]
=
None
image_rgb
:
Optional
[
torch
.
Tensor
]
=
None
# masks out padding added due to cropping the square bit
mask_crop
:
Optional
[
torch
.
Tensor
]
=
None
depth_path
:
Union
[
str
,
List
[
str
],
None
]
=
None
depth_map
:
Optional
[
torch
.
Tensor
]
=
None
depth_mask
:
Optional
[
torch
.
Tensor
]
=
None
mask_path
:
Union
[
str
,
List
[
str
],
None
]
=
None
fg_probability
:
Optional
[
torch
.
Tensor
]
=
None
bbox_xywh
:
Optional
[
torch
.
Tensor
]
=
None
camera
:
Optional
[
PerspectiveCameras
]
=
None
camera_quality_score
:
Optional
[
torch
.
Tensor
]
=
None
point_cloud_quality_score
:
Optional
[
torch
.
Tensor
]
=
None
sequence_point_cloud_path
:
Union
[
str
,
List
[
str
],
None
]
=
None
sequence_point_cloud
:
Optional
[
Pointclouds
]
=
None
sequence_point_cloud_idx
:
Optional
[
torch
.
Tensor
]
=
None
frame_type
:
Union
[
str
,
List
[
str
],
None
]
=
None
# seen | unseen
meta
:
dict
=
field
(
default_factory
=
lambda
:
{})
def
to
(
self
,
*
args
,
**
kwargs
):
new_params
=
{}
for
f
in
fields
(
self
):
value
=
getattr
(
self
,
f
.
name
)
if
isinstance
(
value
,
(
torch
.
Tensor
,
Pointclouds
,
CamerasBase
)):
new_params
[
f
.
name
]
=
value
.
to
(
*
args
,
**
kwargs
)
else
:
new_params
[
f
.
name
]
=
value
return
type
(
self
)(
**
new_params
)
def
cpu
(
self
):
return
self
.
to
(
device
=
torch
.
device
(
"cpu"
))
def
cuda
(
self
):
return
self
.
to
(
device
=
torch
.
device
(
"cuda"
))
# the following functions make sure **frame_data can be passed to functions
def
keys
(
self
):
for
f
in
fields
(
self
):
yield
f
.
name
def
__getitem__
(
self
,
key
):
return
getattr
(
self
,
key
)
    @classmethod
    def collate(cls, batch):
        """
        Given a list of objects `batch` of class `cls`, collates them into a batched
        representation suitable for processing with deep networks.
        """
        elem = batch[0]

        if isinstance(elem, cls):
            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
            id_to_idx = defaultdict(list)
            for i, pc_id in enumerate(pointcloud_ids):
                id_to_idx[pc_id].append(i)

            sequence_point_cloud = []
            sequence_point_cloud_idx = -np.ones((len(batch),))
            for i, ind in enumerate(id_to_idx.values()):
                sequence_point_cloud_idx[ind] = i
                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
            assert (sequence_point_cloud_idx >= 0).all()

            override_fields = {
                "sequence_point_cloud": sequence_point_cloud,
                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
            }
            # note that the pre-collate value of sequence_point_cloud_idx is unused

            collated = {}
            for f in fields(elem):
                list_values = override_fields.get(
                    f.name, [getattr(d, f.name) for d in batch]
                )
                collated[f.name] = (
                    cls.collate(list_values)
                    if all(list_value is not None for list_value in list_values)
                    else None
                )
            return cls(**collated)

        elif isinstance(elem, Pointclouds):
            return join_pointclouds_as_batch(batch)

        elif isinstance(elem, CamerasBase):
            # TODO: don't store K; enforce working in NDC space
            return join_cameras_as_batch(batch)
        else:
            return torch.utils.data._utils.collate.default_collate(batch)
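Because FrameData implements keys(), __getitem__ and a collate classmethod, a batch of frames can be built with a standard PyTorch DataLoader and expanded directly into keyword arguments. The sketch below is illustrative only (not part of this commit); `dataset` stands for any ImplicitronDatasetBase instance and `model` is a hypothetical module that accepts the FrameData fields as keyword arguments.

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    # merges tensors, joins cameras, and deduplicates per-sequence point clouds
    collate_fn=FrameData.collate,
)
frame_data = next(iter(loader))
frame_data = frame_data.cuda()   # FrameData.to / .cuda move all tensor-like fields
preds = model(**frame_data)      # ** expansion works via keys() / __getitem__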
@dataclass(eq=False)
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
    """
    Base class to describe a dataset to be used with Implicitron.

    The dataset is made up of frames, and the frames are grouped into sequences.
    Each sequence has a name (a string). (A sequence could be a video, or a set
    of images of one scene.)

    Subclasses have a __getitem__ which returns an instance of FrameData,
    describing one frame in one sequence.

    Members:
        seq_to_idx: For each sequence, the indices of its frames.
    """

    seq_to_idx: Dict[str, List[int]] = field(init=False)

    def __len__(self) -> int:
        raise NotImplementedError

    def get_frame_numbers_and_timestamps(
        self, idxs: Sequence[int]
    ) -> List[Tuple[int, float]]:
        """
        If the sequences in the dataset are videos rather than
        unordered views, then the dataset should override this method to
        return the index and timestamp in their videos of the frames whose
        indices are given in `idxs`. In addition,
        the values in seq_to_idx should be in ascending order.
        If timestamps are absent, they should be replaced with a constant.

        This is used for letting SceneBatchSampler identify consecutive
        frames.

        Args:
            idxs: frame indices in self

        Returns:
            tuple of
                - frame index in video
                - timestamp of frame in video
        """
        raise ValueError("This dataset does not contain videos.")

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        return None
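For illustration only (not part of this commit), a minimal in-memory dataset could derive from ImplicitronDatasetBase as below; the only hard requirements are populating seq_to_idx and returning a FrameData from __getitem__. The example reuses the imports already present in this file, and the class itself is hypothetical.

@dataclass(eq=False)
class _TinyInMemoryDataset(ImplicitronDatasetBase):
    # hypothetical example: one pre-built FrameData object per frame
    frames: List[FrameData] = field(default_factory=list)

    def __post_init__(self) -> None:
        seq_to_idx: Dict[str, List[int]] = defaultdict(list)
        for i, frame in enumerate(self.frames):
            seq_to_idx[frame.sequence_name].append(i)
        self.seq_to_idx = dict(seq_to_idx)

    def __len__(self) -> int:
        return len(self.frames)

    def __getitem__(self, index) -> FrameData:
        return self.frames[index]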
class FrameAnnotsEntry(TypedDict):
    subset: Optional[str]
    frame_annotation: types.FrameAnnotation


@dataclass(eq=False)
class ImplicitronDataset(ImplicitronDatasetBase):
"""
A class for the Common Objects in 3D (CO3D) dataset.
Args:
frame_annotations_file: A zipped json file containing metadata of the
frames in the dataset, serialized List[types.FrameAnnotation].
sequence_annotations_file: A zipped json file containing metadata of the
sequences in the dataset, serialized List[types.SequenceAnnotation].
subset_lists_file: A json file containing the lists of frames corresponding
to different subsets (e.g. train/val/test) of the dataset;
format: {subset: (sequence_name, frame_id, file_path)}.
subsets: Restrict frames/sequences only to the given list of subsets
as defined in subset_lists_file (see above).
limit_to: Limit the dataset to the first #limit_to frames (after other
filters have been applied).
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences (after other sequence filters have been
applied but before frame-based filters).
pick_sequence: A list of sequence names to restrict the dataset to.
exclude_sequence: A list of the names of the sequences to exclude.
limit_category_to: Restrict the dataset to the given list of categories.
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
depth values used for evaluation (the points consistent across views).
load_masks: Enable loading frame foreground masks.
load_point_clouds: Enable loading sequence-level point clouds.
max_points: Cap on the number of loaded points in the point cloud;
if reached, they are randomly sampled without replacement.
mask_images: Whether to mask the images with the loaded foreground masks;
0 value is used for background.
mask_depths: Whether to mask the depth maps with the loaded foreground
masks; 0 value is used for background.
image_height: The height of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
image_width: The width of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
box_crop: Enable cropping of the image around the bounding box inferred
from the foreground region of the loaded segmentation mask; masks
and depth maps are cropped accordingly; cameras are corrected.
box_crop_mask_thr: The threshold used to separate pixels into foreground
and background based on the foreground_probability mask; if no value
is greater than this threshold, the loader lowers it and repeats.
box_crop_context: The amount of additional padding added to each
dimension of the cropping bounding box, relative to box size.
remove_empty_masks: Removes the frames with no active foreground pixels
in the segmentation mask after thresholding (see box_crop_mask_thr).
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
frames in each sequences uniformly without replacement if it has
more frames than that; applied before other frame-level filters.
seed: The seed of the random generator sampling #n_frames_per_sequence
random frames per sequence.
sort_frames: Enable frame annotations sorting to group frames from the
same sequences together and order them by timestamps
eval_batches: A list of batches that form the evaluation set;
list of batch-sized lists of indices corresponding to __getitem__
of this class, thus it can be used directly as a batch sampler.
"""
    frame_annotations_type: ClassVar[Type[types.FrameAnnotation]] = types.FrameAnnotation

    path_manager: Optional[PathManager] = None
    frame_annotations_file: str = ""
    sequence_annotations_file: str = ""
    subset_lists_file: str = ""
    subsets: Optional[List[str]] = None
    limit_to: int = 0
    limit_sequences_to: int = 0
    pick_sequence: Sequence[str] = ()
    exclude_sequence: Sequence[str] = ()
    limit_category_to: Sequence[int] = ()
    dataset_root: str = ""
    load_images: bool = True
    load_depths: bool = True
    load_depth_masks: bool = True
    load_masks: bool = True
    load_point_clouds: bool = False
    max_points: int = 0
    mask_images: bool = False
    mask_depths: bool = False
    image_height: Optional[int] = 256
    image_width: Optional[int] = 256
    box_crop: bool = False
    box_crop_mask_thr: float = 0.4
    box_crop_context: float = 1.0
    remove_empty_masks: bool = False
    n_frames_per_sequence: int = -1
    seed: int = 0
    sort_frames: bool = False
    eval_batches: Optional[List[List[int]]] = None

    frame_annots: List[FrameAnnotsEntry] = field(init=False)
    seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
    def __post_init__(self) -> None:
        # pyre-fixme[16]: `ImplicitronDataset` has no attribute `subset_to_image_path`.
        self.subset_to_image_path = None
        self._load_frames()
        self._load_sequences()
        if self.sort_frames:
            self._sort_frames()
        self._load_subset_lists()
        self._filter_db()  # also computes sequence indices
        print(str(self))

    def seq_frame_index_to_dataset_index(
        self,
        seq_frame_index: Union[
            List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
        ],
    ) -> List[List[int]]:
        """
        Obtain indices into the dataset object given a list of frames specified as
        `seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`.
        """
        # TODO: check the frame numbers are unique
        _dataset_seq_frame_n_index = {
            seq: {
                self.frame_annots[idx]["frame_annotation"].frame_number: idx
                for idx in seq_idx
            }
            for seq, seq_idx in self.seq_to_idx.items()
        }

        def _get_batch_idx(seq_name, frame_no, path=None) -> int:
            idx = _dataset_seq_frame_n_index[seq_name][frame_no]
            if path is not None:
                # Check that the loaded frame path is consistent
                # with the one stored in self.frame_annots.
                assert os.path.normpath(
                    self.frame_annots[idx]["frame_annotation"].image.path
                ) == os.path.normpath(
                    path
                ), f"Inconsistent batch {seq_name, frame_no, path}."
            return idx

        batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index]
        return batches_idx

    def __str__(self) -> str:
        return f"ImplicitronDataset #frames={len(self.frame_annots)}"

    def __len__(self) -> int:
        return len(self.frame_annots)

    def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
        return entry["subset"]
    def __getitem__(self, index) -> FrameData:
        if index >= len(self.frame_annots):
            raise IndexError(f"index {index} out of range {len(self.frame_annots)}")

        entry = self.frame_annots[index]["frame_annotation"]
        point_cloud = self.seq_annots[entry.sequence_name].point_cloud
        frame_data = FrameData(
            frame_number=_safe_as_tensor(entry.frame_number, torch.long),
            frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
            sequence_name=entry.sequence_name,
            sequence_category=self.seq_annots[entry.sequence_name].category,
            camera_quality_score=_safe_as_tensor(
                self.seq_annots[entry.sequence_name].viewpoint_quality_score,
                torch.float,
            ),
            point_cloud_quality_score=_safe_as_tensor(
                point_cloud.quality_score, torch.float
            )
            if point_cloud is not None
            else None,
        )

        # The rest of the fields are optional
        frame_data.frame_type = self._get_frame_type(self.frame_annots[index])

        (
            frame_data.fg_probability,
            frame_data.mask_path,
            frame_data.bbox_xywh,
            clamp_bbox_xyxy,
        ) = self._load_crop_fg_probability(entry)

        scale = 1.0
        if self.load_images and entry.image is not None:
            # original image size
            frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)

            (
                frame_data.image_rgb,
                frame_data.image_path,
                frame_data.mask_crop,
                scale,
            ) = self._load_crop_images(
                entry, frame_data.fg_probability, clamp_bbox_xyxy
            )

        if self.load_depths and entry.depth is not None:
            (
                frame_data.depth_map,
                frame_data.depth_path,
                frame_data.depth_mask,
            ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)

        if entry.viewpoint is not None:
            frame_data.camera = self._get_pytorch3d_camera(
                entry,
                scale,
                clamp_bbox_xyxy,
            )

        if self.load_point_clouds and point_cloud is not None:
            frame_data.sequence_point_cloud_path = pcl_path = os.path.join(
                self.dataset_root, point_cloud.path
            )
            frame_data.sequence_point_cloud = _load_pointcloud(
                self._local_path(pcl_path), max_points=self.max_points
            )

        return frame_data
    def _load_crop_fg_probability(
        self, entry: types.FrameAnnotation
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[str],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = (None, None, None, None)
        if (self.load_masks or self.box_crop) and entry.mask is not None:
            full_path = os.path.join(self.dataset_root, entry.mask.path)
            mask = _load_mask(self._local_path(full_path))

            if mask.shape[-2:] != entry.image.size:
                raise ValueError(
                    f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
                )

            bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))

            if self.box_crop:
                clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context)
                mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)

            fg_probability, _, _ = self._resize_image(mask, mode="nearest")

        return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy

    def _load_crop_images(
        self,
        entry: types.FrameAnnotation,
        fg_probability: Optional[torch.Tensor],
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
        assert self.dataset_root is not None and entry.image is not None
        path = os.path.join(self.dataset_root, entry.image.path)
        image_rgb = _load_image(self._local_path(path))

        if image_rgb.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
            )

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)

        image_rgb, scale, mask_crop = self._resize_image(image_rgb)

        if self.mask_images:
            assert fg_probability is not None
            image_rgb *= fg_probability

        return image_rgb, path, mask_crop, scale

    def _load_mask_depth(
        self,
        entry: types.FrameAnnotation,
        clamp_bbox_xyxy: Optional[torch.Tensor],
        fg_probability: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
        entry_depth = entry.depth
        assert entry_depth is not None
        path = os.path.join(self.dataset_root, entry_depth.path)
        depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            depth_bbox_xyxy = _rescale_bbox(
                clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
            )
            depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)

        depth_map, _, _ = self._resize_image(depth_map, mode="nearest")

        if self.mask_depths:
            assert fg_probability is not None
            depth_map *= fg_probability

        if self.load_depth_masks:
            assert entry_depth.mask_path is not None
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
            depth_mask = _load_depth_mask(self._local_path(mask_path))

            if self.box_crop:
                assert clamp_bbox_xyxy is not None
                depth_mask_bbox_xyxy = _rescale_bbox(
                    clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
                )
                depth_mask = _crop_around_box(
                    depth_mask, depth_mask_bbox_xyxy, mask_path
                )

            depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
        else:
            depth_mask = torch.ones_like(depth_map)

        return depth_map, path, depth_mask
    def _get_pytorch3d_camera(
        self,
        entry: types.FrameAnnotation,
        scale: float,
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> PerspectiveCameras:
        entry_viewpoint = entry.viewpoint
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

        half_image_size_wh_orig = (
            torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
        )

        # first, we convert from the dataset's NDC convention to pixels
        format = entry_viewpoint.intrinsics_format
        if format.lower() == "ndc_norm_image_bounds":
            # this is e.g. currently used in CO3D for storing intrinsics
            rescale = half_image_size_wh_orig
        elif format.lower() == "ndc_isotropic":
            rescale = half_image_size_wh_orig.min()
        else:
            raise ValueError(f"Unknown intrinsics format: {format}")

        # principal point and focal length in pixels
        principal_point_px = half_image_size_wh_orig - principal_point * rescale
        focal_length_px = focal_length * rescale
        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            principal_point_px -= clamp_bbox_xyxy[:2]

        # now, convert from pixels to PyTorch3D v0.5+ NDC convention
        if self.image_height is None or self.image_width is None:
            out_size = list(reversed(entry.image.size))
        else:
            out_size = [self.image_width, self.image_height]

        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
        half_min_image_size_output = half_image_size_output.min()

        # rescaled principal point and focal length in ndc
        principal_point = (
            half_image_size_output - principal_point_px * scale
        ) / half_min_image_size_output
        focal_length = focal_length_px * scale / half_min_image_size_output

        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )
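For intuition, the snippet below works through the first conversion step above, from the "ndc_norm_image_bounds" convention used by CO3D to pixel-space intrinsics. It is illustrative only; the annotation values are made up.

# assume an 800 (W) x 600 (H) image, i.e. entry.image.size == (600, 800)
half_image_size_wh_orig = torch.tensor([400.0, 300.0])
principal_point_ndc = torch.tensor([0.05, -0.1])   # hypothetical annotation values
focal_length_ndc = torch.tensor([2.0, 2.5])

rescale = half_image_size_wh_orig                  # "ndc_norm_image_bounds" branch
principal_point_px = half_image_size_wh_orig - principal_point_ndc * rescale
focal_length_px = focal_length_ndc * rescale
# principal_point_px == tensor([380., 330.]); focal_length_px == tensor([800., 750.])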
    def _load_frames(self) -> None:
        print(f"Loading Co3D frames from {self.frame_annotations_file}.")
        local_file = self._local_path(self.frame_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            frame_annots_list = types.load_dataclass(
                zipfile, List[self.frame_annotations_type]
            )
        if not frame_annots_list:
            raise ValueError("Empty dataset!")
        self.frame_annots = [
            FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
        ]

    def _load_sequences(self) -> None:
        print(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
        local_file = self._local_path(self.sequence_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
        if not seq_annots:
            raise ValueError("Empty sequences file!")
        self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}

    def _load_subset_lists(self) -> None:
        print(f"Loading Co3D subset lists from {self.subset_lists_file}.")
        if not self.subset_lists_file:
            return
        with open(self._local_path(self.subset_lists_file), "r") as f:
            subset_to_seq_frame = json.load(f)
        frame_path_to_subset = {
            path: subset
            for subset, frames in subset_to_seq_frame.items()
            for _, _, path in frames
        }
        for frame in self.frame_annots:
            frame["subset"] = frame_path_to_subset.get(
                frame["frame_annotation"].image.path, None
            )
            if frame["subset"] is None:
                warnings.warn(
                    "Subset lists are given but don't include "
                    + frame["frame_annotation"].image.path
                )

    def _sort_frames(self) -> None:
        # Sort frames to have them grouped by sequence, ordered by timestamp
        self.frame_annots = sorted(
            self.frame_annots,
            key=lambda f: (
                f["frame_annotation"].sequence_name,
                f["frame_annotation"].frame_timestamp or 0,
            ),
        )
    def _filter_db(self) -> None:
        if self.remove_empty_masks:
            print("Removing images with empty masks.")
            old_len = len(self.frame_annots)

            msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."

            def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
                mask = frame_annot.mask
                if mask is None:
                    return False
                if mask.mass is None:
                    raise ValueError(msg)
                return mask.mass > 1

            self.frame_annots = [
                frame
                for frame in self.frame_annots
                if positive_mass(frame["frame_annotation"])
            ]
            print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))

        # this has to be called after joining with categories!!
        subsets = self.subsets
        if subsets:
            if not self.subset_lists_file:
                raise ValueError(
                    "Subset filter is on but subset_lists_file was not given"
                )

            print(f"Limiting Co3D dataset to the '{subsets}' subsets.")

            # truncate the list of subsets to the valid one
            self.frame_annots = [
                entry for entry in self.frame_annots if entry["subset"] in subsets
            ]
            if len(self.frame_annots) == 0:
                raise ValueError(f"There are no frames in the '{subsets}' subsets!")

            self._invalidate_indexes(filter_seq_annots=True)

        if len(self.limit_category_to) > 0:
            print(f"Limiting dataset to categories: {self.limit_category_to}")
            self.seq_annots = {
                name: entry
                for name, entry in self.seq_annots.items()
                if entry.category in self.limit_category_to
            }

        # sequence filters
        for prefix in ("pick", "exclude"):
            orig_len = len(self.seq_annots)
            attr = f"{prefix}_sequence"
            arr = getattr(self, attr)
            if len(arr) > 0:
                print(f"{attr}: {str(arr)}")
                self.seq_annots = {
                    name: entry
                    for name, entry in self.seq_annots.items()
                    if (name in arr) == (prefix == "pick")
                }
                print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))

        if self.limit_sequences_to > 0:
            self.seq_annots = dict(
                islice(self.seq_annots.items(), self.limit_sequences_to)
            )

        # retain only frames from retained sequences
        self.frame_annots = [
            f
            for f in self.frame_annots
            if f["frame_annotation"].sequence_name in self.seq_annots
        ]

        self._invalidate_indexes()

        if self.n_frames_per_sequence > 0:
            print(f"Taking max {self.n_frames_per_sequence} per sequence.")
            keep_idx = []
            for seq, seq_indices in self.seq_to_idx.items():
                # infer the seed from the sequence name, this is reproducible
                # and makes the selection differ for different sequences
                seed = _seq_name_to_seed(seq) + self.seed
                seq_idx_shuffled = random.Random(seed).sample(
                    sorted(seq_indices), len(seq_indices)
                )
                keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])

            print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
            self.frame_annots = [self.frame_annots[i] for i in keep_idx]
            self._invalidate_indexes(filter_seq_annots=False)
            # sequences are not decimated, so self.seq_annots is valid

        if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
            print(
                "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
            )
            self.frame_annots = self.frame_annots[: self.limit_to]
            self._invalidate_indexes(filter_seq_annots=True)
    def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
        # update seq_to_idx and filter seq_meta according to frame_annots change
        # if filter_seq_annots, also updates seq_annots based on the changed seq_to_idx
        self._invalidate_seq_to_idx()
        if filter_seq_annots:
            self.seq_annots = {
                k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx
            }

    def _invalidate_seq_to_idx(self) -> None:
        seq_to_idx = defaultdict(list)
        for idx, entry in enumerate(self.frame_annots):
            seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
        self.seq_to_idx = seq_to_idx

    def _resize_image(
        self, image, mode="bilinear"
    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
        image_height, image_width = self.image_height, self.image_width
        if image_height is None or image_width is None:
            # skip the resizing
            imre_ = torch.from_numpy(image)
            return imre_, 1.0, torch.ones_like(imre_[:1])
        # takes numpy array, returns pytorch tensor
        minscale = min(
            image_height / image.shape[-2],
            image_width / image.shape[-1],
        )
        imre = torch.nn.functional.interpolate(
            torch.from_numpy(image)[None],
            # pyre-ignore[6]
            scale_factor=minscale,
            mode=mode,
            align_corners=False if mode == "bilinear" else None,
            recompute_scale_factor=True,
        )[0]
        imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
        imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
        mask = torch.zeros(1, self.image_height, self.image_width)
        mask[:, 0 : imre.shape[1] - 1, 0 : imre.shape[2] - 1] = 1.0
        return imre_, minscale, mask

    def _local_path(self, path: str) -> str:
        if self.path_manager is None:
            return path
        return self.path_manager.get_local_path(path)

    def get_frame_numbers_and_timestamps(
        self, idxs: Sequence[int]
    ) -> List[Tuple[int, float]]:
        out: List[Tuple[int, float]] = []
        for idx in idxs:
            frame_annotation = self.frame_annots[idx]["frame_annotation"]
            out.append(
                (frame_annotation.frame_number, frame_annotation.frame_timestamp)
            )
        return out

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        return self.eval_batches
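For orientation, here is a sketch of constructing the dataset directly and reading one frame. It is illustrative only and not part of this commit; the paths and the subset name are hypothetical placeholders following the CO3D file layout.

dataset = ImplicitronDataset(
    dataset_root="/path/to/co3d/teddybear",                                   # hypothetical
    frame_annotations_file="/path/to/co3d/teddybear/frame_annotations.jgz",   # hypothetical
    sequence_annotations_file="/path/to/co3d/teddybear/sequence_annotations.jgz",
    subset_lists_file="/path/to/co3d/teddybear/set_lists.json",
    subsets=["train_known"],          # placeholder subset name
    box_crop=True,
    image_height=256,
    image_width=256,
)
frame = dataset[0]                    # a FrameData instance
print(frame.sequence_name, frame.image_rgb.shape)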
def _seq_name_to_seed(seq_name) -> int:
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)


def _load_image(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def _load_16big_png_depth(depth_png) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth


def _load_1bit_png_mask(file: str) -> np.ndarray:
    with Image.open(file) as pil_im:
        mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
    return mask


def _load_depth_mask(path) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth mask file name "%s"' % path)
    m = _load_1bit_png_mask(path)
    return m[None]  # fake feature channel


def _load_depth(path, scale_adjustment) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)
    d = _load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0
    return d[None]  # fake feature channel


def _load_mask(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)
    mask = mask.astype(np.float32) / 255.0
    return mask[None]  # fake feature channel


def _get_1d_bounds(arr) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1]


def _get_bbox_from_mask(
    mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
        if thr <= 0.0:
            warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")

    x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0


def _get_clamp_bbox(
    bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = ""
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            f"squashed image {impath}!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(bbox[2:], 2)
    bbox[2:] += bbox[0:2] + 1  # convert to [xmin, ymin, xmax, ymax]
    # +1 because upper bound is not inclusive

    return bbox


def _crop_around_box(tensor, bbox, impath: str = ""):
    # bbox is xyxy, where the upper bound is corrected with +1
    bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1])
    bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2])
    bbox = bbox.round().long()
    tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
    assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
    return tensor


def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
    assert bbox is not None
    assert np.prod(orig_res) > 1e-8
    # average ratio of dimensions
    rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
    return bbox * rel_size


def _safe_as_tensor(data, dtype):
    if data is None:
        return None
    return torch.tensor(data, dtype=dtype)


# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl
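To see how the cropping helpers above fit together, here is a small self-contained sketch (illustrative only, not part of this commit) of the box-crop path applied to a synthetic foreground mask.

import numpy as np
import torch

# synthetic 1 x 200 x 200 probability mask with a bright rectangle in the middle
mask = np.zeros((1, 200, 200), dtype=np.float32)
mask[:, 60:140, 80:160] = 1.0

bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, thr=0.4))  # x, y, width, height
bbox_xyxy = _get_clamp_bbox(bbox_xywh, box_crop_context=1.0)  # expanded, converted to xyxy
cropped = _crop_around_box(torch.from_numpy(mask), bbox_xyxy)
print(bbox_xywh, bbox_xyxy, cropped.shape)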
pytorch3d/implicitron/dataset/scene_batch_sampler.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple

import numpy as np
from torch.utils.data.sampler import Sampler

from .implicitron_dataset import ImplicitronDatasetBase


@dataclass(eq=False)
# TODO: do we need this if not init from config?
class SceneBatchSampler(Sampler[List[int]]):
    """
    A class for sampling training batches with a controlled composition
    of sequences.
    """

    dataset: ImplicitronDatasetBase
    batch_size: int
    num_batches: int
    # the sampler first samples a random element k from this list and then
    # takes k random frames per sequence
    images_per_seq_options: Sequence[int]

    # if True, will sample a contiguous interval of frames in the sequence
    # it first finds the connected segments within the sequence of sufficient length,
    # then samples a random pivot element among them and ideally uses it as a middle
    # of the temporal window, shifting the borders where necessary.
    # This strategy mitigates the bias against shorter segments and their boundaries.
    sample_consecutive_frames: bool = False
    # if a number > 0, then used to define the maximum difference in frame_number
    # of neighbouring frames when forming connected segments; otherwise the whole
    # sequence is considered a segment regardless of frame numbers
    consecutive_frames_max_gap: int = 0
    # same but for timestamps if they are available
    consecutive_frames_max_gap_seconds: float = 0.1

    seq_names: List[str] = field(init=False)

    def __post_init__(self) -> None:
        if self.batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integral value, "
                f"but got batch_size={self.batch_size}"
            )

        if len(self.images_per_seq_options) < 1:
            raise ValueError("images_per_seq_options list cannot be empty")

        self.seq_names = list(self.dataset.seq_to_idx.keys())

    def __len__(self) -> int:
        return self.num_batches

    def __iter__(self) -> Iterator[List[int]]:
        for batch_idx in range(len(self)):
            batch = self._sample_batch(batch_idx)
            yield batch

    def _sample_batch(self, batch_idx) -> List[int]:
        n_per_seq = np.random.choice(self.images_per_seq_options)
        n_seqs = -(-self.batch_size // n_per_seq)  # round up
        chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
        if self.sample_consecutive_frames:
            frame_idx = []
            for seq in chosen_seq:
                segment_index = self._build_segment_index(
                    list(self.dataset.seq_to_idx[seq]), n_per_seq
                )

                segment, idx = segment_index[np.random.randint(len(segment_index))]
                if len(segment) <= n_per_seq:
                    frame_idx.append(segment)
                else:
                    start = np.clip(idx - n_per_seq // 2, 0, len(segment) - n_per_seq)
                    frame_idx.append(segment[start : start + n_per_seq])
        else:
            frame_idx = [
                _capped_random_choice(
                    self.dataset.seq_to_idx[seq], n_per_seq, replace=False
                )
                for seq in chosen_seq
            ]
        frame_idx = np.concatenate(frame_idx)[: self.batch_size].tolist()
        if len(frame_idx) < self.batch_size:
            warnings.warn(
                "Batch size smaller than self.batch_size!"
                + " (This is fine for experiments with a single scene and viewpooling)"
            )
        return frame_idx

    def _build_segment_index(
        self, seq_frame_indices: List[int], size: int
    ) -> List[Tuple[List[int], int]]:
        """
        Returns a list of (segment, index) tuples, one per eligible frame, where
        segment is a list of frame indices in the contiguous segment the frame
        belongs to, and index is the frame's index within that segment.
        Segment references are repeated but the memory is shared.
        """
        if (
            self.consecutive_frames_max_gap > 0
            or self.consecutive_frames_max_gap_seconds > 0.0
        ):
            sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
                seq_frame_indices, self.dataset
            )
            # TODO: use new API to access frame numbers / timestamps
            segments = self._split_to_segments(sequence_timestamps)
            segments = _cull_short_segments(segments, size)
            if not segments:
                raise AssertionError("Empty segments after culling")
        else:
            segments = [seq_frame_indices]

        # build an index of segment for random selection of a pivot frame
        segment_index = [
            (segment, i) for segment in segments for i in range(len(segment))
        ]

        return segment_index

    def _split_to_segments(
        self, sequence_timestamps: List[Tuple[float, int, int]]
    ) -> List[List[int]]:
        if (
            self.consecutive_frames_max_gap <= 0
            and self.consecutive_frames_max_gap_seconds <= 0.0
        ):
            raise AssertionError("This function is only needed for non-trivial max_gap")

        segments = []
        last_no = -self.consecutive_frames_max_gap - 1  # will trigger a new segment
        last_ts = -self.consecutive_frames_max_gap_seconds - 1.0
        for ts, no, idx in sequence_timestamps:
            if ts <= 0.0 and no <= last_no:
                raise AssertionError(
                    "Frames are not ordered in seq_to_idx while timestamps are not given"
                )

            if (
                no - last_no > self.consecutive_frames_max_gap > 0
                or ts - last_ts > self.consecutive_frames_max_gap_seconds > 0.0
            ):  # new group
                segments.append([idx])
            else:
                segments[-1].append(idx)

            last_no = no
            last_ts = ts

        return segments


def _sort_frames_by_timestamps_then_numbers(
    seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
    """Build the list of triplets (timestamp, frame_no, dataset_idx).

    We attempt to first sort by timestamp, then by frame number.
    Timestamps are coalesced with 0s.
    """
    nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)

    return sorted(
        [
            (timestamp, frame_no, idx)
            for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
        ]
    )


def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
    lengths = [(len(segment), segment) for segment in segments]
    max_len, longest_segment = max(lengths)

    if max_len < min_size:
        return [longest_segment]

    return [segment for segment in segments if len(segment) >= min_size]


def _capped_random_choice(x, size, replace: bool = True):
    """
    if replace==True
        randomly chooses from x `size` elements without replacement if len(x)>size
        else allows replacement and selects `size` elements again.
    if replace==False
        randomly chooses from x `min(len(x), size)` elements without replacement
    """
    len_x = x if isinstance(x, int) else len(x)
    if replace:
        return np.random.choice(x, size=size, replace=len_x < size)
    else:
        return np.random.choice(x, size=min(size, len_x), replace=False)
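An illustrative way to plug the sampler into a training dataloader (not part of this commit); `dataset` stands for any ImplicitronDatasetBase instance, and num_batches fixes how many batches make up one epoch.

from torch.utils.data import DataLoader

from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler

batch_sampler = SceneBatchSampler(
    dataset=dataset,
    batch_size=10,
    num_batches=1000,
    images_per_seq_options=[2, 5, 10],   # k frames per sequence, k drawn from this list
)
train_loader = DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=FrameData.collate,
)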
pytorch3d/implicitron/dataset/types.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
dataclasses
import
gzip
import
json
import
sys
from
dataclasses
import
MISSING
,
Field
,
dataclass
from
typing
import
IO
,
Any
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
cast
import
numpy
as
np
_X
=
TypeVar
(
"_X"
)
if
sys
.
version_info
>=
(
3
,
8
,
0
):
from
typing
import
get_args
,
get_origin
elif
sys
.
version_info
>=
(
3
,
7
,
0
):
def
get_origin
(
cls
):
return
getattr
(
cls
,
"__origin__"
,
None
)
def
get_args
(
cls
):
return
getattr
(
cls
,
"__args__"
,
None
)
else
:
raise
ImportError
(
"This module requires Python 3.7+"
)
TF3
=
Tuple
[
float
,
float
,
float
]
@
dataclass
class
ImageAnnotation
:
# path to jpg file, relative w.r.t. dataset_root
path
:
str
# H x W
size
:
Tuple
[
int
,
int
]
# TODO: rename size_hw?
@
dataclass
class
DepthAnnotation
:
# path to png file, relative w.r.t. dataset_root, storing `depth / scale_adjustment`
path
:
str
# a factor to convert png values to actual depth: `depth = png * scale_adjustment`
scale_adjustment
:
float
# path to png file, relative w.r.t. dataset_root, storing binary `depth` mask
mask_path
:
Optional
[
str
]
@
dataclass
class
MaskAnnotation
:
# path to png file storing (Prob(fg | pixel) * 255)
path
:
str
# (soft) number of pixels in the mask; sum(Prob(fg | pixel))
mass
:
Optional
[
float
]
=
None
@
dataclass
class
ViewpointAnnotation
:
# In right-multiply (PyTorch3D) format. X_cam = X_world @ R + T
R
:
Tuple
[
TF3
,
TF3
,
TF3
]
T
:
TF3
focal_length
:
Tuple
[
float
,
float
]
principal_point
:
Tuple
[
float
,
float
]
intrinsics_format
:
str
=
"ndc_norm_image_bounds"
# Defines the co-ordinate system where focal_length and principal_point live.
# Possible values: ndc_isotropic | ndc_norm_image_bounds (default)
# ndc_norm_image_bounds: legacy PyTorch3D NDC format, where image boundaries
# correspond to [-1, 1] x [-1, 1], and the scale along x and y may differ
# ndc_isotropic: PyTorch3D 0.5+ NDC convention where the shorter side has
# the range [-1, 1], and the longer one has the range [-s, s]; s >= 1,
# where s is the aspect ratio. The scale is same along x and y.
@
dataclass
class
FrameAnnotation
:
"""A dataclass used to load annotations from json."""
# can be used to join with `SequenceAnnotation`
sequence_name
:
str
# 0-based, continuous frame number within sequence
frame_number
:
int
# timestamp in seconds from the video start
frame_timestamp
:
float
image
:
ImageAnnotation
depth
:
Optional
[
DepthAnnotation
]
=
None
mask
:
Optional
[
MaskAnnotation
]
=
None
viewpoint
:
Optional
[
ViewpointAnnotation
]
=
None
@
dataclass
class
PointCloudAnnotation
:
# path to ply file with points only, relative w.r.t. dataset_root
path
:
str
# the bigger the better
quality_score
:
float
n_points
:
Optional
[
int
]
@
dataclass
class
VideoAnnotation
:
# path to the original video file, relative w.r.t. dataset_root
path
:
str
# length of the video in seconds
length
:
float
@
dataclass
class
SequenceAnnotation
:
sequence_name
:
str
category
:
str
video
:
Optional
[
VideoAnnotation
]
=
None
point_cloud
:
Optional
[
PointCloudAnnotation
]
=
None
# the bigger the better
viewpoint_quality_score
:
Optional
[
float
]
=
None
def
dump_dataclass
(
obj
:
Any
,
f
:
IO
,
binary
:
bool
=
False
)
->
None
:
"""
Args:
f: Either a path to a file, or a file opened for writing.
obj: A @dataclass or collection hierarchy including dataclasses.
binary: Set to True if `f` is a file handle, else False.
"""
if
binary
:
f
.
write
(
json
.
dumps
(
_asdict_rec
(
obj
)).
encode
(
"utf8"
))
else
:
json
.
dump
(
_asdict_rec
(
obj
),
f
)
def
load_dataclass
(
f
:
IO
,
cls
:
Type
[
_X
],
binary
:
bool
=
False
)
->
_X
:
"""
Loads to a @dataclass or collection hierarchy including dataclasses
from a json recursively.
Call it like load_dataclass(f, typing.List[FrameAnnotation]).
raises KeyError if json has keys not mapping to the dataclass fields.
Args:
f: Either a path to a file, or a file opened for reading.
cls: The class of the loaded dataclass.
binary: Set to True if `f` is a file handle, else False.
"""
if
binary
:
asdict
=
json
.
loads
(
f
.
read
().
decode
(
"utf8"
))
else
:
asdict
=
json
.
load
(
f
)
if
isinstance
(
asdict
,
list
):
# in the list case, run a faster "vectorized" version
cls
=
get_args
(
cls
)[
0
]
res
=
list
(
_dataclass_list_from_dict_list
(
asdict
,
cls
))
else
:
res
=
_dataclass_from_dict
(
asdict
,
cls
)
return
res
def
_dataclass_list_from_dict_list
(
dlist
,
typeannot
):
"""
Vectorised version of `_dataclass_from_dict`.
The output should be equivalent to
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
Args:
dlist: list of objects to convert.
typeannot: type of each of those objects.
Returns:
iterator or list over converted objects of the same length as `dlist`.
Raises:
ValueError: it assumes the objects have None's in consistent places across
objects, otherwise it would ignore some values. This generally holds for
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
"""
cls
=
get_origin
(
typeannot
)
or
typeannot
if
all
(
obj
is
None
for
obj
in
dlist
):
# 1st recursion base: all None nodes
return
dlist
elif
any
(
obj
is
None
for
obj
in
dlist
):
# filter out Nones and recurse on the resulting list
idx_notnone
=
[(
i
,
obj
)
for
i
,
obj
in
enumerate
(
dlist
)
if
obj
is
not
None
]
idx
,
notnone
=
zip
(
*
idx_notnone
)
converted
=
_dataclass_list_from_dict_list
(
notnone
,
typeannot
)
res
=
[
None
]
*
len
(
dlist
)
for
i
,
obj
in
zip
(
idx
,
converted
):
res
[
i
]
=
obj
return
res
# otherwise, we dispatch by the type of the provided annotation to convert to
elif
issubclass
(
cls
,
tuple
)
and
hasattr
(
cls
,
"_fields"
):
# namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types
=
cls
.
_field_types
.
values
()
dlist_T
=
zip
(
*
dlist
)
res_T
=
[
_dataclass_list_from_dict_list
(
key_list
,
tp
)
for
key_list
,
tp
in
zip
(
dlist_T
,
types
)
]
return
[
cls
(
*
converted_as_tuple
)
for
converted_as_tuple
in
zip
(
*
res_T
)]
elif
issubclass
(
cls
,
(
list
,
tuple
)):
# For list/tuple, call the function recursively on the lists of corresponding positions
types
=
get_args
(
typeannot
)
if
len
(
types
)
==
1
:
# probably List; replicate for all items
types
=
types
*
len
(
dlist
[
0
])
dlist_T
=
zip
(
*
dlist
)
res_T
=
(
_dataclass_list_from_dict_list
(
pos_list
,
tp
)
for
pos_list
,
tp
in
zip
(
dlist_T
,
types
)
)
if
issubclass
(
cls
,
tuple
):
return
list
(
zip
(
*
res_T
))
else
:
return
[
cls
(
converted_as_tuple
)
for
converted_as_tuple
in
zip
(
*
res_T
)]
elif
issubclass
(
cls
,
dict
):
# For the dictionary, call the function recursively on concatenated keys and vertices
key_t
,
val_t
=
get_args
(
typeannot
)
all_keys_res
=
_dataclass_list_from_dict_list
(
[
k
for
obj
in
dlist
for
k
in
obj
.
keys
()],
key_t
)
all_vals_res
=
_dataclass_list_from_dict_list
(
[
k
for
obj
in
dlist
for
k
in
obj
.
values
()],
val_t
)
indices
=
np
.
cumsum
([
len
(
obj
)
for
obj
in
dlist
])
assert
indices
[
-
1
]
==
len
(
all_keys_res
)
keys
=
np
.
split
(
list
(
all_keys_res
),
indices
[:
-
1
])
vals
=
np
.
split
(
list
(
all_vals_res
),
indices
[:
-
1
])
return
[
cls
(
zip
(
*
k
,
v
))
for
k
,
v
in
zip
(
keys
,
vals
)]
elif
not
dataclasses
.
is_dataclass
(
typeannot
):
return
dlist
# dataclass node: 2nd recursion base; call the function recursively on the lists
# of the corresponding fields
assert
dataclasses
.
is_dataclass
(
cls
)
fieldtypes
=
{
f
.
name
:
(
_unwrap_type
(
f
.
type
),
_get_dataclass_field_default
(
f
))
for
f
in
dataclasses
.
fields
(
typeannot
)
}
# NOTE the default object is shared here
key_lists
=
(
_dataclass_list_from_dict_list
([
obj
.
get
(
k
,
default
)
for
obj
in
dlist
],
type_
)
for
k
,
(
type_
,
default
)
in
fieldtypes
.
items
()
)
transposed
=
zip
(
*
key_lists
)
return
[
cls
(
*
vals_as_tuple
)
for
vals_as_tuple
in
transposed
]
def
_dataclass_from_dict
(
d
,
typeannot
):
cls
=
get_origin
(
typeannot
)
or
typeannot
if
d
is
None
:
return
d
elif
issubclass
(
cls
,
tuple
)
and
hasattr
(
cls
,
"_fields"
):
# namedtuple
types
=
cls
.
_field_types
.
values
()
return
cls
(
*
[
_dataclass_from_dict
(
v
,
tp
)
for
v
,
tp
in
zip
(
d
,
types
)])
elif
issubclass
(
cls
,
(
list
,
tuple
)):
types
=
get_args
(
typeannot
)
if
len
(
types
)
==
1
:
# probably List; replicate for all items
types
=
types
*
len
(
d
)
return
cls
(
_dataclass_from_dict
(
v
,
tp
)
for
v
,
tp
in
zip
(
d
,
types
))
elif
issubclass
(
cls
,
dict
):
key_t
,
val_t
=
get_args
(
typeannot
)
return
cls
(
(
_dataclass_from_dict
(
k
,
key_t
),
_dataclass_from_dict
(
v
,
val_t
))
for
k
,
v
in
d
.
items
()
)
elif
not
dataclasses
.
is_dataclass
(
typeannot
):
return
d
assert
dataclasses
.
is_dataclass
(
cls
)
fieldtypes
=
{
f
.
name
:
_unwrap_type
(
f
.
type
)
for
f
in
dataclasses
.
fields
(
typeannot
)}
return
cls
(
**
{
k
:
_dataclass_from_dict
(
v
,
fieldtypes
[
k
])
for
k
,
v
in
d
.
items
()})
def
_unwrap_type
(
tp
):
# strips Optional wrapper, if any
if
get_origin
(
tp
)
is
Union
:
args
=
get_args
(
tp
)
if
len
(
args
)
==
2
and
any
(
a
is
type
(
None
)
for
a
in
args
):
# noqa: E721
# this is typing.Optional
return
args
[
0
]
if
args
[
1
]
is
type
(
None
)
else
args
[
1
]
# noqa: E721
return
tp
def
_get_dataclass_field_default
(
field
:
Field
)
->
Any
:
if
field
.
default_factory
is
not
MISSING
:
return
field
.
default_factory
()
elif
field
.
default
is
not
MISSING
:
return
field
.
default
else
:
return
None
def
_asdict_rec
(
obj
):
return
dataclasses
.
_asdict_inner
(
obj
,
dict
)
def
dump_dataclass_jgzip
(
outfile
:
str
,
obj
:
Any
)
->
None
:
"""
Dumps obj to a gzipped json outfile.
Args:
obj: A @dataclass or collection hierarchy including dataclasses.
outfile: The path to the output file.
"""
with
gzip
.
GzipFile
(
outfile
,
"wb"
)
as
f
:
dump_dataclass
(
obj
,
cast
(
IO
,
f
),
binary
=
True
)
def
load_dataclass_jgzip
(
outfile
,
cls
):
"""
Loads a dataclass from a gzipped json outfile.
Args:
outfile: The path to the loaded file.
cls: The type annotation of the loaded dataclass.
Returns:
loaded_dataclass: The loaded dataclass.
"""
with
gzip
.
GzipFile
(
outfile
,
"rb"
)
as
f
:
return
load_dataclass
(
cast
(
IO
,
f
),
cls
,
binary
=
True
)
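A small round-trip sketch (illustrative only, not part of this commit) of how these helpers serialize and reload the annotation dataclasses; the annotation values are made up.

from typing import List

from pytorch3d.implicitron.dataset import types

annots = [
    types.FrameAnnotation(
        sequence_name="seq_000",   # hypothetical values
        frame_number=0,
        frame_timestamp=0.0,
        image=types.ImageAnnotation(
            path="seq_000/images/frame000000.jpg", size=(600, 800)
        ),
    )
]
types.dump_dataclass_jgzip("frame_annotations.jgz", annots)
reloaded = types.load_dataclass_jgzip("frame_annotations.jgz", List[types.FrameAnnotation])
assert reloaded[0].sequence_name == "seq_000"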
pytorch3d/implicitron/dataset/utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional

import torch

DATASET_TYPE_TRAIN = "train"
DATASET_TYPE_TEST = "test"
DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"


def is_known_frame(
    frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
    """
    Given a list `frame_type` of frame types in a batch, return a tensor
    of boolean flags expressing whether the corresponding frame is a known frame.
    """
    return torch.tensor(
        [ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
        dtype=torch.bool,
        device=device,
    )


def is_train_frame(
    frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
    """
    Given a list `frame_type` of frame types in a batch, return a tensor
    of boolean flags expressing whether the corresponding frame is a training frame.
    """
    return torch.tensor(
        [ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type],
        dtype=torch.bool,
        device=device,
    )
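For illustration (not part of this commit), these helpers are typically applied to the `frame_type` list of a collated FrameData batch to split source and target views; the frame-type strings below are hypothetical.

frame_type = ["train_known", "train_known", "train_unseen", "test_unseen"]

known = is_known_frame(frame_type)    # tensor([ True,  True, False, False])
train = is_train_frame(frame_type)    # tensor([ True,  True,  True, False])
target_idx = torch.where(~known)[0]   # indices of the frames to synthesize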
pytorch3d/implicitron/dataset/visualize.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
Optional
,
Tuple
,
cast
import
torch
from
pytorch3d.implicitron.tools.point_cloud_utils
import
get_rgbd_point_cloud
from
pytorch3d.structures
import
Pointclouds
from
.implicitron_dataset
import
FrameData
,
ImplicitronDataset
def
get_implicitron_sequence_pointcloud
(
dataset
:
ImplicitronDataset
,
sequence_name
:
Optional
[
str
]
=
None
,
mask_points
:
bool
=
True
,
max_frames
:
int
=
-
1
,
num_workers
:
int
=
0
,
load_dataset_point_cloud
:
bool
=
False
,
)
->
Tuple
[
Pointclouds
,
FrameData
]:
"""
Make a point cloud by sampling random points from each frame of the dataset.
"""
if
len
(
dataset
)
==
0
:
raise
ValueError
(
"The dataset is empty."
)
if
not
dataset
.
load_depths
:
raise
ValueError
(
"The dataset has to load depths (dataset.load_depths=True)."
)
if
mask_points
and
not
dataset
.
load_masks
:
raise
ValueError
(
"For mask_points=True, the dataset has to load masks"
+
" (dataset.load_masks=True)."
)
# setup the indices of frames loaded from the dataset db
sequence_entries
=
list
(
range
(
len
(
dataset
)))
if
sequence_name
is
not
None
:
sequence_entries
=
[
ei
for
ei
in
sequence_entries
if
dataset
.
frame_annots
[
ei
][
"frame_annotation"
].
sequence_name
==
sequence_name
]
if
len
(
sequence_entries
)
==
0
:
raise
ValueError
(
f
'There are no dataset entries for sequence name "
{
sequence_name
}
".'
)
# subsample loaded frames if needed
if
(
max_frames
>
0
)
and
(
len
(
sequence_entries
)
>
max_frames
):
sequence_entries
=
[
sequence_entries
[
i
]
for
i
in
torch
.
randperm
(
len
(
sequence_entries
))[:
max_frames
].
sort
().
values
]
# take only the part of the dataset corresponding to the sequence entries
sequence_dataset
=
torch
.
utils
.
data
.
Subset
(
dataset
,
sequence_entries
)
# load the required part of the dataset
loader
=
torch
.
utils
.
data
.
DataLoader
(
sequence_dataset
,
batch_size
=
len
(
sequence_dataset
),
shuffle
=
False
,
num_workers
=
num_workers
,
collate_fn
=
FrameData
.
collate
,
)
frame_data
=
next
(
iter
(
loader
))
# there's only one batch
# scene point cloud
if
load_dataset_point_cloud
:
if
not
dataset
.
load_point_clouds
:
raise
ValueError
(
"For load_dataset_point_cloud=True, the dataset has to"
+
" load point clouds (dataset.load_point_clouds=True)."
)
point_cloud
=
frame_data
.
sequence_point_cloud
else
:
point_cloud
=
get_rgbd_point_cloud
(
frame_data
.
camera
,
frame_data
.
image_rgb
,
frame_data
.
depth_map
,
(
cast
(
torch
.
Tensor
,
frame_data
.
fg_probability
)
>
0.5
).
float
()
if
frame_data
.
fg_probability
is
not
None
else
None
,
mask_points
=
mask_points
,
)
return
point_cloud
,
frame_data
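Illustrative usage (not part of this commit), assuming `dataset` is an ImplicitronDataset created with load_depths=True and load_masks=True; the sequence name is a hypothetical placeholder.

point_cloud, frame_data = get_implicitron_sequence_pointcloud(
    dataset,
    sequence_name="106_12648_23157",   # hypothetical CO3D sequence name
    mask_points=True,
    max_frames=50,
    num_workers=4,
)
print(point_cloud.num_points_per_cloud(), frame_data.camera)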
pytorch3d/implicitron/eval_demo.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
copy
import
dataclasses
import
os
from
typing
import
Optional
,
cast
import
lpips
import
torch
from
pytorch3d.implicitron.dataset.dataloader_zoo
import
dataloader_zoo
from
pytorch3d.implicitron.dataset.dataset_zoo
import
CO3D_CATEGORIES
,
dataset_zoo
from
pytorch3d.implicitron.dataset.implicitron_dataset
import
(
FrameData
,
ImplicitronDataset
,
ImplicitronDatasetBase
,
)
from
pytorch3d.implicitron.dataset.utils
import
is_known_frame
from
pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis
import
(
aggregate_nvs_results
,
eval_batch
,
pretty_print_nvs_metrics
,
summarize_nvs_eval_results
,
)
from
pytorch3d.implicitron.models.model_dbir
import
ModelDBIR
from
pytorch3d.implicitron.tools.utils
import
dataclass_to_cuda_
from
tqdm
import
tqdm
def
main
()
->
None
:
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for multisequence/singlesequence tasks for several categories.
The evaluation is conducted on the same data as in [1] and, hence, the results
are directly comparable to the numbers reported in [1].
References:
[1] J. Reizenstein, R. Shapovalov, P. Henzler, L. Sbordone,
P. Labatut, D. Novotny:
Common Objects in 3D: Large-Scale Learning
and Evaluation of Real-life 3D Category Reconstruction
"""
task_results
=
{}
for
task
in
(
"singlesequence"
,
"multisequence"
):
task_results
[
task
]
=
[]
for
category
in
CO3D_CATEGORIES
[:
(
20
if
task
==
"singlesequence"
else
10
)]:
for
single_sequence_id
in
(
0
,
1
)
if
task
==
"singlesequence"
else
(
None
,):
category_result
=
evaluate_dbir_for_category
(
category
,
task
=
task
,
single_sequence_id
=
single_sequence_id
)
print
(
""
)
print
(
f
"Results for task=
{
task
}
; category=
{
category
}
;"
+
(
f
" sequence=
{
single_sequence_id
}
:"
if
single_sequence_id
is
not
None
else
":"
)
)
pretty_print_nvs_metrics
(
category_result
)
print
(
""
)
task_results
[
task
].
append
(
category_result
)
_print_aggregate_results
(
task
,
task_results
)
for
task
in
task_results
:
_print_aggregate_results
(
task
,
task_results
)
def
evaluate_dbir_for_category
(
category
:
str
=
"apple"
,
bg_color
:
float
=
0.0
,
task
:
str
=
"singlesequence"
,
single_sequence_id
:
Optional
[
int
]
=
None
,
num_workers
:
int
=
16
,
):
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for a given task, category, and sequence (in case task=='singlesequence').
Args:
category: Object category.
bg_color: Background color of the renders.
task: Evaluation task. Either singlesequence or multisequence.
single_sequence_id: The ID of the evaluation sequence for the singlesequence task.
num_workers: The number of workers for the employed dataloaders.
Returns:
category_result: A dictionary of quantitative metrics.
"""
single_sequence_id
=
single_sequence_id
if
single_sequence_id
is
not
None
else
-
1
torch
.
manual_seed
(
42
)
if
task
not
in
[
"multisequence"
,
"singlesequence"
]:
raise
ValueError
(
"'task' has to be either 'multisequence' or 'singlesequence'"
)
datasets
=
dataset_zoo
(
category
=
category
,
dataset_root
=
os
.
environ
[
"CO3D_DATASET_ROOT"
],
assert_single_seq
=
task
==
"singlesequence"
,
dataset_name
=
f
"co3d_
{
task
}
"
,
test_on_train
=
False
,
load_point_clouds
=
True
,
test_restrict_sequence_id
=
single_sequence_id
,
)
dataloaders
=
dataloader_zoo
(
datasets
,
dataset_name
=
f
"co3d_
{
task
}
"
,
)
test_dataset
=
datasets
[
"test"
]
test_dataloader
=
dataloaders
[
"test"
]
if
task
==
"singlesequence"
:
# all_source_cameras are needed for evaluation of the
# target camera difficulty
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
sequence_name
=
test_dataset
.
frame_annots
[
0
][
"frame_annotation"
].
sequence_name
all_source_cameras
=
_get_all_source_cameras
(
test_dataset
,
sequence_name
,
num_workers
=
num_workers
)
else
:
all_source_cameras
=
None
image_size
=
cast
(
ImplicitronDataset
,
test_dataset
).
image_width
if
image_size
is
None
:
raise
ValueError
(
"Image size should be set in the dataset"
)
# init the simple DBIR model
model
=
ModelDBIR
(
image_size
=
image_size
,
bg_color
=
bg_color
,
max_points
=
int
(
1e5
),
)
model
.
cuda
()
# init the lpips model for eval
lpips_model
=
lpips
.
LPIPS
(
net
=
"vgg"
)
lpips_model
=
lpips_model
.
cuda
()
per_batch_eval_results
=
[]
print
(
"Evaluating DBIR model ..."
)
for
frame_data
in
tqdm
(
test_dataloader
):
frame_data
=
dataclass_to_cuda_
(
frame_data
)
preds
=
model
(
**
dataclasses
.
asdict
(
frame_data
))
nvs_prediction
=
copy
.
deepcopy
(
preds
[
"nvs_prediction"
])
per_batch_eval_results
.
append
(
eval_batch
(
frame_data
,
nvs_prediction
,
bg_color
=
bg_color
,
lpips_model
=
lpips_model
,
source_cameras
=
all_source_cameras
,
)
)
category_result_flat
,
category_result
=
summarize_nvs_eval_results
(
per_batch_eval_results
,
task
)
return
category_result
[
"results"
]
def
_print_aggregate_results
(
task
,
task_results
)
->
None
:
"""
Prints the aggregate metrics for a given task.
"""
aggregate_task_result
=
aggregate_nvs_results
(
task_results
[
task
])
print
(
""
)
print
(
f
"Aggregate results for task=
{
task
}
:"
)
pretty_print_nvs_metrics
(
aggregate_task_result
)
print
(
""
)
def
_get_all_source_cameras
(
dataset
:
ImplicitronDatasetBase
,
sequence_name
:
str
,
num_workers
:
int
=
8
):
"""
Loads all training cameras of a given sequence.
The set of all seen cameras is needed for evaluating the viewpoint difficulty
for the singlescene evaluation.
Args:
dataset: Co3D dataset object.
sequence_name: The name of the sequence.
num_workers: The number of workers for the utilized dataloader.
"""
# load all source cameras of the sequence
seq_idx
=
dataset
.
seq_to_idx
[
sequence_name
]
dataset_for_loader
=
torch
.
utils
.
data
.
Subset
(
dataset
,
seq_idx
)
(
all_frame_data
,)
=
torch
.
utils
.
data
.
DataLoader
(
dataset_for_loader
,
shuffle
=
False
,
batch_size
=
len
(
dataset_for_loader
),
num_workers
=
num_workers
,
collate_fn
=
FrameData
.
collate
,
)
is_known
=
is_known_frame
(
all_frame_data
.
frame_type
)
source_cameras
=
all_frame_data
.
camera
[
torch
.
where
(
is_known
)[
0
]]
return
source_cameras
if
__name__
==
"__main__"
:
main
()
pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
from pytorch3d.implicitron.tools.image_utils import mask_background
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth, iou, rgb_l1
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.implicitron.tools.vis_utils import make_depth_image
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene
from tabulate import tabulate
from visdom import Visdom

EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]


@dataclass
class NewViewSynthesisPrediction:
    """
    Holds the tensors that describe a result of synthesizing new views.
    """

    depth_render: Optional[torch.Tensor] = None
    image_render: Optional[torch.Tensor] = None
    mask_render: Optional[torch.Tensor] = None
    camera_distance: Optional[torch.Tensor] = None


@dataclass
class _Visualizer:
    image_render: torch.Tensor
    image_rgb_masked: torch.Tensor
    depth_render: torch.Tensor
    depth_map: torch.Tensor
    depth_mask: torch.Tensor

    visdom_env: str = "eval_debug"

    _viz: Visdom = field(init=False)

    def __post_init__(self):
        self._viz = vis_utils.get_visdom_connection()

    def show_rgb(
        self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
    ):
        self._viz.images(
            torch.cat(
                (
                    self.image_render,
                    self.image_rgb_masked,
                    loss_mask_now.repeat(1, 3, 1, 1),
                ),
                dim=3,
            ),
            env=self.visdom_env,
            win=metric_name,
            opts={"title": f"{metric_name}_{loss_value:1.2f}"},
        )

    def show_depth(
        self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
    ):
        self._viz.images(
            torch.cat(
                (
                    make_depth_image(self.depth_render, loss_mask_now),
                    make_depth_image(self.depth_map, loss_mask_now),
                ),
                dim=3,
            ),
            env=self.visdom_env,
            win="depth_abs" + name_postfix,
            opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
        )
        self._viz.images(
            loss_mask_now,
            env=self.visdom_env,
            win="depth_abs" + name_postfix + "_mask",
            opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
        )
        self._viz.images(
            self.depth_mask,
            env=self.visdom_env,
            win="depth_abs" + name_postfix + "_maskd",
            opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
        )

        # show the 3D plot
        # pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
        #  `TensorProperties`.
        viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
            loss_mask_now.device
        )
        pcl_pred = get_rgbd_point_cloud(
            viewpoint_trivial,
            self.image_render,
            self.depth_render,
            # mask_crop,
            torch.ones_like(self.depth_render),
            # loss_mask_now,
        )
        pcl_gt = get_rgbd_point_cloud(
            viewpoint_trivial,
            self.image_rgb_masked,
            self.depth_map,
            # mask_crop,
            torch.ones_like(self.depth_map),
            # loss_mask_now,
        )
        _pcls = {
            pn: p
            for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
            if int(p.num_points_per_cloud()) > 0
        }
        plotlyplot = plot_scene(
            {f"pcl{name_postfix}": _pcls},
            camera_scale=1.0,
            pointcloud_max_points=10000,
            pointcloud_marker_size=1,
        )
        self._viz.plotlyplot(
            plotlyplot,
            env=self.visdom_env,
            win=f"pcl{name_postfix}",
        )
def eval_batch(
    frame_data: FrameData,
    nvs_prediction: NewViewSynthesisPrediction,
    bg_color: Union[torch.Tensor, str, float] = "black",
    mask_thr: float = 0.5,
    lpips_model=None,
    visualize: bool = False,
    visualize_visdom_env: str = "eval_debug",
    break_after_visualising: bool = True,
    source_cameras: Optional[List[CamerasBase]] = None,
) -> Dict[str, Any]:
    """
    Produce performance metrics for a single batch of new-view synthesis
    predictions.

    Given a set of known views (for which frame_data.frame_type.endswith('known')
    is True), a new-view synthesis method (NVS) is tasked to generate new views
    of the scene from the viewpoint of the target views (for which
    frame_data.frame_type.endswith('known') is False). The resulting
    synthesized new views, stored in `nvs_prediction`, are compared to the
    target ground truth in `frame_data` in terms of geometry and appearance,
    resulting in a dictionary of metrics returned by the `eval_batch` function.

    Args:
        frame_data: A FrameData object containing the input to the new view
            synthesis method.
        nvs_prediction: The data describing the synthesized new views.
        bg_color: The background color of the generated new views and the
            ground truth.
        lpips_model: A pre-trained model for evaluating the LPIPS metric.
        visualize: If True, visualizes the results to Visdom.
        source_cameras: A list of all training cameras for evaluating the
            difficulty of the target views.

    Returns:
        results: A dictionary holding evaluation metrics.

    Throws:
        ValueError if frame_data does not have frame_type, camera, or image_rgb
        ValueError if the batch has a mix of training and test samples
        ValueError if the batch frames are not [unseen, known, known, ...]
        ValueError if one of the required fields in nvs_prediction is missing
    """
    REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]

    frame_type = frame_data.frame_type
    if frame_type is None:
        raise ValueError("Frame type has not been set.")

    # we check that all those fields are not None but Pyre can't infer that properly
    # TODO: assign to local variables
    if frame_data.image_rgb is None:
        raise ValueError("Image is not in the evaluation batch.")

    if frame_data.camera is None:
        raise ValueError("Camera is not in the evaluation batch.")

    if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
        raise ValueError("One of the required predicted fields is missing")

    # obtain copies to make sure we dont edit the original data
    nvs_prediction = copy.deepcopy(nvs_prediction)
    frame_data = copy.deepcopy(frame_data)

    # mask the ground truth depth in case frame_data contains the depth mask
    if frame_data.depth_map is not None and frame_data.depth_mask is not None:
        frame_data.depth_map *= frame_data.depth_mask

    if not isinstance(frame_type, list):  # not batch FrameData
        frame_type = [frame_type]

    is_train = is_train_frame(frame_type)
    if not (is_train[0] == is_train).all():
        raise ValueError("All frames in the eval batch have to be either train/test.")

    # pyre-fixme[16]: `Optional` has no attribute `device`.
    is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device)

    if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()):
        raise ValueError(
            "For evaluation the first element of the batch has to be"
            + " a target view while the rest should be source views."
        )  # TODO: do we need to enforce this?

    # take only the first (target image)
    for k in REQUIRED_NVS_PREDICTION_FIELDS:
        setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
    for k in [
        "depth_map",
        "image_rgb",
        "fg_probability",
        "mask_crop",
    ]:
        if not hasattr(frame_data, k) or getattr(frame_data, k) is None:
            continue
        setattr(frame_data, k, getattr(frame_data, k)[:1])

    if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
        warnings.warn("Empty or missing depth map in evaluation!")

    # eval all results in the resolution of the frame_data image
    # pyre-fixme[16]: `Optional` has no attribute `shape`.
    image_resol = list(frame_data.image_rgb.shape[2:])

    # threshold the masks to make ground truth binary masks
    mask_fg, mask_crop = [
        (getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
    ]
    image_rgb_masked = mask_background(
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
        #  `Optional[torch.Tensor]`.
        frame_data.image_rgb,
        mask_fg,
        bg_color=bg_color,
    )

    # resize to the target resolution
    for k in REQUIRED_NVS_PREDICTION_FIELDS:
        imode = "bilinear" if k == "image_render" else "nearest"
        val = getattr(nvs_prediction, k)
        setattr(
            nvs_prediction,
            k,
            # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
            #  `List[typing.Any]`.
            torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
        )

    # clamp predicted images
    # pyre-fixme[16]: `Optional` has no attribute `clamp`.
    image_render = nvs_prediction.image_render.clamp(0.0, 1.0)

    if visualize:
        visualizer = _Visualizer(
            image_render=image_render,
            image_rgb_masked=image_rgb_masked,
            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
            #  `Optional[torch.Tensor]`.
            depth_render=nvs_prediction.depth_render,
            # pyre-fixme[6]: Expected `Tensor` for 4th param but got
            #  `Optional[torch.Tensor]`.
            depth_map=frame_data.depth_map,
            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
            depth_mask=frame_data.depth_mask[:1],
            visdom_env=visualize_visdom_env,
        )

    results: Dict[str, Any] = {}

    results["iou"] = iou(
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
        #  `Optional[torch.Tensor]`.
        nvs_prediction.mask_render,
        mask_fg,
        mask=mask_crop,
    )

    for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("", "_fg")):

        loss_mask_now = mask_crop * loss_fg_mask

        for rgb_metric_name, rgb_metric_fun in zip(
            ("psnr", "rgb_l1"), (calc_psnr, rgb_l1)
        ):
            metric_name = rgb_metric_name + name_postfix
            results[metric_name] = rgb_metric_fun(
                image_render,
                image_rgb_masked,
                mask=loss_mask_now,
            )

            if visualize:
                visualizer.show_rgb(
                    results[metric_name].item(), metric_name, loss_mask_now
                )

        if name_postfix == "_fg":
            # only record depth metrics for the foreground
            _, abs_ = eval_depth(
                # pyre-fixme[6]: Expected `Tensor` for 1st param but got
                #  `Optional[torch.Tensor]`.
                nvs_prediction.depth_render,
                # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
                #  `Optional[torch.Tensor]`.
                frame_data.depth_map,
                get_best_scale=True,
                mask=loss_mask_now,
                crop=5,
            )
            results["depth_abs" + name_postfix] = abs_.mean()

            if visualize:
                visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
                if break_after_visualising:
                    import pdb

                    pdb.set_trace()

    if lpips_model is not None:
        im1, im2 = [
            2.0 * im.clamp(0.0, 1.0) - 1.0
            for im in (image_rgb_masked, nvs_prediction.image_render)
        ]
        results["lpips"] = lpips_model.forward(im1, im2).item()

    # convert all metrics to floats
    results = {k: float(v) for k, v in results.items()}

    if source_cameras is None:
        # pyre-fixme[16]: Optional has no attribute __getitem__
        source_cameras = frame_data.camera[torch.where(is_known)[0]]

    results["meta"] = {
        # calculate the camera difficulties and add to results
        "camera_difficulty": calculate_camera_difficulties(
            frame_data.camera[0],
            source_cameras,
        )[0].item(),
        # store the size of the batch (corresponds to n_src_views+1)
        "batch_size": int(is_known.numel()),
        # store the type of the target frame
        # pyre-fixme[16]: `None` has no attribute `__getitem__`.
        "frame_type": str(frame_data.frame_type[0]),
    }

    return results
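# --- illustrative sketch (not part of the original file) ---------------------
# The IoU and PSNR numbers reported above boil down to computations like the
# following. `calc_psnr` and `iou` actually live in
# pytorch3d.implicitron.tools.metric_utils; the toy versions below are hedged
# stand-ins that only show the idea.
def _toy_masked_metrics_example():
    pred = torch.rand(1, 3, 8, 8)
    gt = pred.clone()
    pred_mask = torch.ones(1, 1, 8, 8)
    gt_mask = torch.ones(1, 1, 8, 8)
    valid = torch.ones(1, 1, 8, 8)

    # binary intersection-over-union restricted to the valid crop region
    p = (pred_mask > 0.5) & (valid > 0.5)
    g = (gt_mask > 0.5) & (valid > 0.5)
    toy_iou = (p & g).float().sum() / (p | g).float().sum().clamp(min=1.0)

    # masked MSE -> PSNR for images with values in [0, 1]
    n_valid = valid.sum() * pred.shape[1]
    mse = (((pred - gt) ** 2) * valid).sum() / n_valid.clamp(min=1.0)
    toy_psnr = -10.0 * torch.log10(mse.clamp(min=1e-10))
    return toy_iou, toy_psnr  # identical inputs -> IoU 1.0, PSNR ~100 dB
# -----------------------------------------------------------------------------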
def average_per_batch_results(
    results_per_batch: List[Dict[str, Any]],
    idx: Optional[torch.Tensor] = None,
) -> dict:
    """
    Average a list of per-batch metrics `results_per_batch`.
    Optionally, if `idx` is given, only a subset of the per-batch
    metrics, indexed by `idx`, is averaged.
    """
    result_keys = list(results_per_batch[0].keys())
    result_keys.remove("meta")
    if idx is not None:
        results_per_batch = [results_per_batch[i] for i in idx]
    if len(results_per_batch) == 0:
        return {k: float("NaN") for k in result_keys}
    return {
        k: float(np.array([r[k] for r in results_per_batch]).mean())
        for k in result_keys
    }


def calculate_camera_difficulties(
    cameras_target: CamerasBase,
    cameras_source: CamerasBase,
) -> torch.Tensor:
    """
    Calculate the difficulties of the target cameras, given a set of known
    cameras `cameras_source`.

    Returns:
        a tensor of shape (len(cameras_target),)
    """
    ious = [
        volumetric_camera_overlaps(
            join_cameras_as_batch(
                # pyre-fixme[6]: Expected `CamerasBase` for 1st param but got
                #  `Optional[pytorch3d.renderer.utils.TensorProperties]`.
                [cameras_target[cami], cameras_source.to(cameras_target.device)]
            )
        )[0, :]
        for cami in range(cameras_target.R.shape[0])
    ]
    camera_difficulties = torch.stack(
        [_reduce_camera_iou_overlap(iou[1:]) for iou in ious]
    )
    return camera_difficulties


def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
    """
    Calculate the final camera difficulty by computing the average of the
    ious of the two most similar cameras.

    Returns:
        single-element Tensor
    """
    # pyre-ignore[16] topk not recognized
    return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
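# --- illustrative sketch (not part of the original file) ---------------------
# Toy numbers for the reduction above: the difficulty of a target view is the
# mean IoU of its two most-overlapping source cameras, so a higher value means
# an easier view. Note that k = min(topk, len(ious) - 1), so a batch with a
# single source view would reduce over zero values.
def _toy_camera_difficulty_example():
    ious = torch.tensor([0.91, 0.40, 0.75, 0.10])  # overlap with each source view
    k = min(2, len(ious) - 1)
    return ious.topk(k=k).values.mean()  # (0.91 + 0.75) / 2 = 0.83
# -----------------------------------------------------------------------------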
def get_camera_difficulty_bin_edges(task: str):
    """
    Get the edges of camera difficulty bins.
    """
    _eps = 1e-5
    if task == "multisequence":
        # TODO: extract those to constants
        diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
        diff_bin_edges[0] = 0.0 - _eps
    elif task == "singlesequence":
        diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
    else:
        raise ValueError(f"No such eval task {task}.")
    diff_bin_names = ["hard", "medium", "easy"]
    return diff_bin_edges, diff_bin_names


def summarize_nvs_eval_results(
    per_batch_eval_results: List[Dict[str, Any]],
    task: str = "singlesequence",
):
    """
    Compile the per-batch evaluation results `per_batch_eval_results` into
    a set of aggregate metrics. The produced metrics depend on the task.

    Args:
        per_batch_eval_results: Metrics of each per-batch evaluation.
        task: The type of the new-view synthesis task.
            Either 'singlesequence' or 'multisequence'.

    Returns:
        nvs_results_flat: A flattened dict of all aggregate metrics.
        aux_out: A dictionary holding a set of auxiliary results.
    """
    n_batches = len(per_batch_eval_results)
    eval_sets: List[Optional[str]] = []
    if task == "singlesequence":
        eval_sets = [None]
        # assert n_batches==100
    elif task == "multisequence":
        eval_sets = ["train", "test"]
        # assert n_batches==1000
    else:
        raise ValueError(task)
    batch_sizes = torch.tensor(
        [r["meta"]["batch_size"] for r in per_batch_eval_results]
    ).long()
    camera_difficulty = torch.tensor(
        [r["meta"]["camera_difficulty"] for r in per_batch_eval_results]
    ).float()
    is_train = is_train_frame(
        [r["meta"]["frame_type"] for r in per_batch_eval_results]
    )

    # init the result database dict
    results = []

    diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task)
    n_diff_edges = diff_bin_edges.numel()

    # add per set averages
    for SET in eval_sets:
        if SET is None:
            # task=='singlesequence'
            ok_set = torch.ones(n_batches, dtype=torch.bool)
            set_name = "test"
        else:
            # task=='multisequence'
            ok_set = is_train == int(SET == "train")
            set_name = SET

        # eval each difficulty bin, including a full average result (diff_bin=None)
        for diff_bin in [None, *list(range(n_diff_edges - 1))]:
            if diff_bin is None:
                # average over all results
                in_bin = ok_set
                diff_bin_name = "all"
            else:
                b1, b2 = diff_bin_edges[diff_bin : (diff_bin + 2)]
                in_bin = ok_set & (camera_difficulty > b1) & (camera_difficulty <= b2)
                diff_bin_name = diff_bin_names[diff_bin]
            bin_results = average_per_batch_results(
                per_batch_eval_results, idx=torch.where(in_bin)[0]
            )
            results.append(
                {
                    "subset": set_name,
                    "subsubset": f"diff={diff_bin_name}",
                    "metrics": bin_results,
                }
            )

        if task == "multisequence":
            # split based on n_src_views
            n_src_views = batch_sizes - 1
            for n_src in EVAL_N_SRC_VIEWS:
                ok_src = ok_set & (n_src_views == n_src)
                n_src_results = average_per_batch_results(
                    per_batch_eval_results,
                    idx=torch.where(ok_src)[0],
                )
                results.append(
                    {
                        "subset": set_name,
                        "subsubset": f"n_src={int(n_src)}",
                        "metrics": n_src_results,
                    }
                )

    aux_out = {"results": results}
    return flatten_nvs_results(results), aux_out
def _get_flat_nvs_metric_key(result, metric_name) -> str:
    metric_key_postfix = f"|subset={result['subset']}|{result['subsubset']}"
    metric_key = f"{metric_name}{metric_key_postfix}"
    return metric_key


def flatten_nvs_results(results):
    """
    Takes input `results` list of dicts of the form:
    ```
        [
            {
                'subset':'train/test/...',
                'subsubset': 'src=1/src=2/...',
                'metrics': nvs_eval_metrics}
            },
            ...
        ]
    ```
    And converts to a flat dict as follows:
        {
            'subset=train/test/...|subsubset=src=1/src=2/...': nvs_eval_metrics,
            ...
        }
    """
    results_flat = {}
    for result in results:
        for metric_name, metric_val in result["metrics"].items():
            metric_key = _get_flat_nvs_metric_key(result, metric_name)
            assert metric_key not in results_flat
            results_flat[metric_key] = metric_val
    return results_flat
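# --- illustrative sketch (not part of the original file) ---------------------
# Example of the flattening above, with made-up metric values:
def _toy_flatten_example():
    results = [
        {"subset": "test", "subsubset": "diff=all", "metrics": {"psnr": 24.1}},
        {"subset": "test", "subsubset": "diff=hard", "metrics": {"psnr": 21.3}},
    ]
    return flatten_nvs_results(results)
    # -> {'psnr|subset=test|diff=all': 24.1, 'psnr|subset=test|diff=hard': 21.3}
# -----------------------------------------------------------------------------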
def pretty_print_nvs_metrics(results) -> None:
    subsets, subsubsets = [
        _ordered_set([r[k] for r in results]) for k in ("subset", "subsubset")
    ]
    metrics = _ordered_set([metric for r in results for metric in r["metrics"]])
    for subset in subsets:
        tab = {}
        for metric in metrics:
            tab[metric] = []
            header = ["metric"]
            for subsubset in subsubsets:
                metric_vals = [
                    r["metrics"][metric]
                    for r in results
                    if r["subsubset"] == subsubset and r["subset"] == subset
                ]
                if len(metric_vals) > 0:
                    tab[metric].extend(metric_vals)
                    header.extend(subsubsets)
        if any(len(v) > 0 for v in tab.values()):
            print(f"===== NVS results; subset={subset} =====")
            print(
                tabulate(
                    [[metric, *v] for metric, v in tab.items()],
                    # pyre-fixme[61]: `header` is undefined, or not always defined.
                    headers=header,
                )
            )


def _ordered_set(list_):
    return list(OrderedDict((i, 0) for i in list_).keys())


def aggregate_nvs_results(task_results):
    """
    Aggregate nvs results.
    For singlescene, this averages over all categories and scenes,
    for multiscene, the average is over all per-category results.
    """
    task_results_cat = [r_ for r in task_results for r_ in r]
    subsets, subsubsets = [
        _ordered_set([r[k] for r in task_results_cat]) for k in ("subset", "subsubset")
    ]
    metrics = _ordered_set(
        [metric for r in task_results_cat for metric in r["metrics"]]
    )
    average_results = []
    for subset in subsets:
        for subsubset in subsubsets:
            metrics_lists = [
                r["metrics"]
                for r in task_results_cat
                if r["subsubset"] == subsubset and r["subset"] == subset
            ]
            avg_metrics = {}
            for metric in metrics:
                avg_metrics[metric] = float(
                    np.nanmean(
                        np.array(
                            [metric_list[metric] for metric_list in metrics_lists]
                        )
                    )
                )
            average_results.append(
                {
                    "subset": subset,
                    "subsubset": subsubset,
                    "metrics": avg_metrics,
                }
            )
    return average_results
pytorch3d/implicitron/models/autodecoder.py (new file, mode 100644, commit cdd2142d)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union

import torch
from pytorch3d.implicitron.tools.config import Configurable


# TODO: probabilistic embeddings?
class Autodecoder(Configurable, torch.nn.Module):
    """
    Autodecoder module

    Settings:
        encoding_dim: Embedding dimension for the decoder.
        n_instances: The maximum number of instances stored by the autodecoder.
        init_scale: Scale factor for the initial autodecoder weights.
        ignore_input: If `True`, optimizes a single code for any input.
    """

    encoding_dim: int = 0
    n_instances: int = 0
    init_scale: float = 1.0
    ignore_input: bool = False

    def __post_init__(self):
        super().__init__()
        if self.n_instances <= 0:
            # Do not init the codes at all in case we have 0 instances.
            return
        self._autodecoder_codes = torch.nn.Embedding(
            self.n_instances,
            self.encoding_dim,
            scale_grad_by_freq=True,
        )
        with torch.no_grad():
            # weight has been initialised from Normal(0, 1)
            self._autodecoder_codes.weight *= self.init_scale

        self._sequence_map = self._build_sequence_map()
        # Make sure to register hooks for correct handling of saving/loading
        # the module's _sequence_map.
        self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
        self._register_state_dict_hook(_save_sequence_map_hook)

    def _build_sequence_map(
        self, sequence_map_dict: Optional[Dict[str, int]] = None
    ) -> Dict[str, int]:
        """
        Args:
            sequence_map_dict: A dictionary used to initialize the sequence_map.

        Returns:
            sequence_map: a dictionary of key: id pairs.
        """
        # increments the counter when asked for a new value
        sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
        if sequence_map_dict is not None:
            # Assign all keys from the loaded sequence_map_dict to self._sequence_map.
            # Since this is done in the original order, it should generate
            # the same set of key:id pairs. We check this with an assert to be sure.
            for x, x_id in sequence_map_dict.items():
                x_id_ = sequence_map[x]
                assert x_id == x_id_
        return sequence_map

    def calc_squared_encoding_norm(self):
        if self.n_instances <= 0:
            return None
        return (self._autodecoder_codes.weight ** 2).mean()

    def get_encoding_dim(self) -> int:
        if self.n_instances <= 0:
            return 0
        return self.encoding_dim

    def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
        """
        Args:
            x: A batch of `N` sequence identifiers. Either a long tensor of size
                `(N,)` keys in [0, n_instances), or a list of `N` string keys that
                are hashed to codes (without collisions).

        Returns:
            codes: A tensor of shape `(N, self.encoding_dim)` containing the
                sequence-specific autodecoder codes.
        """
        if self.n_instances == 0:
            return None

        if self.ignore_input:
            x = ["singleton"]

        if isinstance(x[0], str):
            try:
                x = torch.tensor(
                    # pyre-ignore[29]
                    [self._sequence_map[elem] for elem in x],
                    dtype=torch.long,
                    device=next(self.parameters()).device,
                )
            except StopIteration:
                raise ValueError("Not enough n_instances in the autodecoder")

        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        return self._autodecoder_codes(x)

    def _load_sequence_map_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`

        Returns:
            Constructed sequence_map if it exists in the state_dict
            else raises a warning only.
        """
        sequence_map_key = prefix + "_sequence_map"
        if sequence_map_key in state_dict:
            sequence_map_dict = state_dict.pop(sequence_map_key)
            self._sequence_map = self._build_sequence_map(
                sequence_map_dict=sequence_map_dict
            )
        else:
            warnings.warn("No sequence map in Autodecoder state dict!")


def _save_sequence_map_hook(
    self,
    state_dict,
    prefix,
    local_metadata,
) -> None:
    """
    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this module.
    """
    sequence_map_key = prefix + "_sequence_map"
    sequence_map_dict = dict(self._sequence_map.items())
    state_dict[sequence_map_key] = sequence_map_dict
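The lazy id assignment in `_build_sequence_map` above relies on a `defaultdict` whose default factory is the `__next__` method of a range iterator, so each previously unseen sequence name receives the next free slot. A standalone sketch of that idiom (the toy keys are made up for illustration):

from collections import defaultdict

n_instances = 3
sequence_map = defaultdict(iter(range(n_instances)).__next__)

print(sequence_map["teddybear_001"])  # 0 -> first unseen key gets the next free id
print(sequence_map["plant_042"])      # 1
print(sequence_map["teddybear_001"])  # 0 -> repeated keys keep their id
# A fourth distinct key would exhaust the iterator and raise StopIteration,
# which Autodecoder.forward converts into a ValueError.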
pytorch3d/implicitron/models/base.py (new file, mode 100644, commit cdd2142d)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple

import torch
import tqdm
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
    NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom

from .autodecoder import Autodecoder
from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField  # noqa
from .implicit_function.neural_radiance_field import (  # noqa
    NeRFormerImplicitFunction,
    NeuralRadianceFieldImplicitFunction,
)
from .implicit_function.scene_representation_networks import (  # noqa
    SRNHyperNetImplicitFunction,
    SRNImplicitFunction,
)
from .metrics import ViewMetrics
from .renderer.base import (
    BaseRenderer,
    EvaluationMode,
    ImplicitFunctionWrapper,
    RendererOutput,
    RenderSamplingMode,
)
from .renderer.lstm_renderer import LSTMRenderer  # noqa
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer  # noqa
from .renderer.ray_sampler import RaySampler
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer  # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooling.feature_aggregation import FeatureAggregatorBase
from .view_pooling.view_sampling import ViewSampler

STD_LOG_VARS = ["objective", "epoch", "sec/it"]
# pyre-ignore: 13
class
GenericModel
(
Configurable
,
torch
.
nn
.
Module
):
"""
GenericModel is a wrapper for the neural implicit
rendering and reconstruction pipeline which consists
of the following sequence of 7 steps (steps 2–4 are normally
skipped in overfitting scenario, since conditioning on source views
does not add much information; otherwise they should be present altogether):
(1) Ray Sampling
------------------
Rays are sampled from an image grid based on the target view(s).
│_____________
│ │
│ ▼
│ (2) Feature Extraction (optional)
│ -----------------------
│ A feature extractor (e.g. a convolutional
│ neural net) is used to extract image features
│ from the source view(s).
│ │
│ ▼
│ (3) View Sampling (optional)
│ ------------------
│ Image features are sampled at the 2D projections
│ of a set of 3D points along each of the sampled
│ target rays from (1).
│ │
│ ▼
│ (4) Feature Aggregation (optional)
│ ------------------
│ Aggregate features and masks sampled from
│ image view(s) in (3).
│ │
│____________▼
│
▼
(5) Implicit Function Evaluation
------------------
Evaluate the implicit function(s) at the sampled ray points
(optionally pass in the aggregated image features from (4)).
│
▼
(6) Rendering
------------------
Render the image into the target cameras by raymarching along
the sampled rays and aggregating the colors and densities
output by the implicit function in (5).
│
▼
(7) Loss Computation
------------------
Compute losses based on the predicted target image(s).
The `forward` function of GenericModel executes
this sequence of steps. Currently, steps 1, 3, 4, 5, 6
can be customized by intializing a subclass of the appropriate
baseclass and adding the newly created module to the registry.
Please see https://github.com/fairinternal/pytorch3d/blob/co3d/projects/implicitron_trainer/README.md#custom-plugins
for more details on how to create and register a custom component.
In the config .yaml files for experiments, the parameters below are
contained in the `generic_model_args` node. As GenericModel
derives from Configurable, the input arguments are
parsed by the run_auto_creation function to initialize the
necessary member modules. Please see implicitron_trainer/README.md
for more details on this process.
Args:
mask_images: Whether or not to mask the RGB image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
mask_depths: Whether or not to mask the depth image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
render_image_width: Width of the output image to render
render_image_height: Height of the output image to render
mask_threshold: If greater than 0.0, the foreground mask is
thresholded by this value before being applied to the RGB/Depth images
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
splatting onto an image grid. Default: False.
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
view_pool: If True, features are sampled from the source image(s)
at the projected 2d locations of the sampled 3d ray points from the target
view(s), i.e. this activates step (3) above.
num_passes: The specified implicit_function is initialized num_passes
times and run sequentially.
chunk_size_grid: The total number of points which can be rendered
per chunk. This is used to compute the number of rays used
per chunk when the chunked version of the renderer is used (in order
to fit rendering on all rays in memory)
render_features_dimensions: The number of output features to render.
Defaults to 3, corresponding to RGB images.
n_train_target_views: The number of cameras to render into at training
time; first `n_train_target_views` in the batch are considered targets,
the rest are sources.
sampling_mode_training: The sampling method to use during training. Must be
a value from the RenderSamplingMode Enum.
sampling_mode_evaluation: Same as above but for evaluation.
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
of the image (referred to as the global_code) that can be used to model aspects of
the scene such as multiple objects or morphing objects. It is up to the implicit
function definition how to use it, but the most typical way is to broadcast and
concatenate to the other inputs for the implicit function.
raysampler: An instance of RaySampler which is used to emit
rays from the target view(s).
renderer_class_type: The name of the renderer class which is available in the global
registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s).
image_feature_extractor: A module for extrating features from an input image.
view_sampler: An instance of ViewSampler which is used for sampling of
image-based features at the 2D projections of a set
of 3D points.
feature_aggregator_class_type: The name of the feature aggregator class which
is available in the global registry.
feature_aggregator: A feature aggregator class which inherits from
FeatureAggregatorBase. Typically, the aggregated features and their
masks are output by a `ViewSampler` which samples feature tensors extracted
from a set of source images. FeatureAggregator executes step (4) above.
implicit_function_class_type: The type of implicit function to use which
is available in the global registry.
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
are initialised to be in self._implicit_functions.
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
for `ViewMetrics` class for available loss functions.
log_vars: A list of variable names which should be logged.
The names should correspond to a subset of the keys of the
dict `preds` output by the `forward` function.
"""
mask_images
:
bool
=
True
mask_depths
:
bool
=
True
render_image_width
:
int
=
400
render_image_height
:
int
=
400
mask_threshold
:
float
=
0.5
output_rasterized_mc
:
bool
=
False
bg_color
:
Tuple
[
float
,
float
,
float
]
=
(
0.0
,
0.0
,
0.0
)
view_pool
:
bool
=
False
num_passes
:
int
=
1
chunk_size_grid
:
int
=
4096
render_features_dimensions
:
int
=
3
tqdm_trigger_threshold
:
int
=
16
n_train_target_views
:
int
=
1
sampling_mode_training
:
str
=
"mask_sample"
sampling_mode_evaluation
:
str
=
"full_grid"
# ---- autodecoder settings
sequence_autodecoder
:
Autodecoder
# ---- raysampler
raysampler
:
RaySampler
# ---- renderer configs
renderer_class_type
:
str
=
"MultiPassEmissionAbsorptionRenderer"
renderer
:
BaseRenderer
# ---- view sampling settings - used if view_pool=True
# (This is only created if view_pool is False)
image_feature_extractor
:
ResNetFeatureExtractor
view_sampler
:
ViewSampler
# ---- ---- view sampling feature aggregator settings
feature_aggregator_class_type
:
str
=
"AngleWeightedReductionFeatureAggregator"
feature_aggregator
:
FeatureAggregatorBase
# ---- implicit function settings
implicit_function_class_type
:
str
=
"NeuralRadianceFieldImplicitFunction"
# This is just a model, never constructed.
# The actual implicit functions live in self._implicit_functions
implicit_function
:
ImplicitFunctionBase
# ---- loss weights
loss_weights
:
Dict
[
str
,
float
]
=
field
(
default_factory
=
lambda
:
{
"loss_rgb_mse"
:
1.0
,
"loss_prev_stage_rgb_mse"
:
1.0
,
"loss_mask_bce"
:
0.0
,
"loss_prev_stage_mask_bce"
:
0.0
,
}
)
# ---- variables to be logged (logger automatically ignores if not computed)
log_vars
:
List
[
str
]
=
field
(
default_factory
=
lambda
:
[
"loss_rgb_psnr_fg"
,
"loss_rgb_psnr"
,
"loss_rgb_mse"
,
"loss_rgb_huber"
,
"loss_depth_abs"
,
"loss_depth_abs_fg"
,
"loss_mask_neg_iou"
,
"loss_mask_bce"
,
"loss_mask_beta_prior"
,
"loss_eikonal"
,
"loss_density_tv"
,
"loss_depth_neg_penalty"
,
"loss_autodecoder_norm"
,
# metrics that are only logged in 2+stage renderes
"loss_prev_stage_rgb_mse"
,
"loss_prev_stage_rgb_psnr_fg"
,
"loss_prev_stage_rgb_psnr"
,
"loss_prev_stage_mask_bce"
,
*
STD_LOG_VARS
,
]
)
def
__post_init__
(
self
):
super
().
__init__
()
self
.
view_metrics
=
ViewMetrics
()
self
.
_check_and_preprocess_renderer_configs
()
self
.
raysampler_args
[
"sampling_mode_training"
]
=
self
.
sampling_mode_training
self
.
raysampler_args
[
"sampling_mode_evaluation"
]
=
self
.
sampling_mode_evaluation
self
.
raysampler_args
[
"image_width"
]
=
self
.
render_image_width
self
.
raysampler_args
[
"image_height"
]
=
self
.
render_image_height
run_auto_creation
(
self
)
self
.
_implicit_functions
=
self
.
_construct_implicit_functions
()
self
.
print_loss_weights
()
def
forward
(
self
,
*
,
# force keyword-only arguments
image_rgb
:
Optional
[
torch
.
Tensor
],
camera
:
CamerasBase
,
fg_probability
:
Optional
[
torch
.
Tensor
],
mask_crop
:
Optional
[
torch
.
Tensor
],
depth_map
:
Optional
[
torch
.
Tensor
],
sequence_name
:
Optional
[
List
[
str
]],
evaluation_mode
:
EvaluationMode
=
EvaluationMode
.
EVALUATION
,
**
kwargs
,
)
->
Dict
[
str
,
Any
]:
"""
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
the first `min(B, n_train_target_views)` images are considered targets and
are used to supervise the renders; the rest corresponding to the source
viewpoints from which features will be extracted.
camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
to the viewpoints of target images, from which the rays will be sampled,
and source images, which will be used for intersecting with target rays.
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
foreground masks.
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
regions in the input images (i.e. regions that do not correspond
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
"mask_sample", rays will be sampled in the non zero regions.
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
sequence_name: A list of `B` strings corresponding to the sequence names
from which images `image_rgb` were extracted. They are used to match
target frames with relevant source frames.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
Returns:
preds: A dictionary containing all outputs of the forward pass including the
rendered images, depths, masks, losses and other metrics.
"""
image_rgb
,
fg_probability
,
depth_map
=
self
.
_preprocess_input
(
image_rgb
,
fg_probability
,
depth_map
)
# Obtain the batch size from the camera as this is the only required input.
batch_size
=
camera
.
R
.
shape
[
0
]
# Determine the number of target views, i.e. cameras we render into.
n_targets
=
(
1
if
evaluation_mode
==
EvaluationMode
.
EVALUATION
else
batch_size
if
self
.
n_train_target_views
<=
0
else
min
(
self
.
n_train_target_views
,
batch_size
)
)
# Select the target cameras.
target_cameras
=
camera
[
list
(
range
(
n_targets
))]
# Determine the used ray sampling mode.
sampling_mode
=
RenderSamplingMode
(
self
.
sampling_mode_training
if
evaluation_mode
==
EvaluationMode
.
TRAINING
else
self
.
sampling_mode_evaluation
)
# (1) Sample rendering rays with the ray sampler.
ray_bundle
:
RayBundle
=
self
.
raysampler
(
target_cameras
,
evaluation_mode
,
mask
=
mask_crop
[:
n_targets
]
if
mask_crop
is
not
None
and
sampling_mode
==
RenderSamplingMode
.
MASK_SAMPLE
else
None
,
)
# custom_args hold additional arguments to the implicit function.
custom_args
=
{}
if
self
.
view_pool
:
if
sequence_name
is
None
:
raise
ValueError
(
"sequence_name must be provided for view pooling"
)
# (2) Extract features for the image
img_feats
=
self
.
image_feature_extractor
(
image_rgb
,
fg_probability
)
# (3) Sample features and masks at the ray points
curried_view_sampler
=
lambda
pts
:
self
.
view_sampler
(
# noqa: E731
pts
=
pts
,
seq_id_pts
=
sequence_name
[:
n_targets
],
camera
=
camera
,
seq_id_camera
=
sequence_name
,
feats
=
img_feats
,
masks
=
mask_crop
,
)
# returns feats_sampled, masks_sampled
# (4) Aggregate features from multiple views
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
curried_view_pool
=
lambda
pts
:
self
.
feature_aggregator
(
# noqa: E731
*
curried_view_sampler
(
pts
=
pts
),
pts
=
pts
,
camera
=
camera
,
)
# TODO: do we need to pass a callback rather than compute here?
# precomputing will be faster for 2 passes
# -> but this is important for non-nerf
custom_args
[
"fun_viewpool"
]
=
curried_view_pool
global_code
=
None
if
self
.
sequence_autodecoder
.
n_instances
>
0
:
if
sequence_name
is
None
:
raise
ValueError
(
"sequence_name must be provided for autodecoder."
)
global_code
=
self
.
sequence_autodecoder
(
sequence_name
[:
n_targets
])
custom_args
[
"global_code"
]
=
global_code
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
# torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
# torch.nn.Module]` is not a function.
for
func
in
self
.
_implicit_functions
:
func
.
bind_args
(
**
custom_args
)
object_mask
:
Optional
[
torch
.
Tensor
]
=
None
if
fg_probability
is
not
None
:
sampled_fb_prob
=
rend_utils
.
ndc_grid_sample
(
fg_probability
[:
n_targets
],
ray_bundle
.
xys
,
mode
=
"nearest"
)
object_mask
=
sampled_fb_prob
>
0.5
# (5)-(6) Implicit function evaluation and Rendering
rendered
=
self
.
_render
(
ray_bundle
=
ray_bundle
,
sampling_mode
=
sampling_mode
,
evaluation_mode
=
evaluation_mode
,
implicit_functions
=
self
.
_implicit_functions
,
object_mask
=
object_mask
,
)
# Unbind the custom arguments to prevent pytorch from storing
# large buffers of intermediate results due to points in the
# bound arguments.
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
# torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
# torch.nn.Module]` is not a function.
for
func
in
self
.
_implicit_functions
:
func
.
unbind_args
()
preds
=
self
.
_get_view_metrics
(
raymarched
=
rendered
,
xys
=
ray_bundle
.
xys
,
image_rgb
=
None
if
image_rgb
is
None
else
image_rgb
[:
n_targets
],
depth_map
=
None
if
depth_map
is
None
else
depth_map
[:
n_targets
],
fg_probability
=
None
if
fg_probability
is
None
else
fg_probability
[:
n_targets
],
mask_crop
=
None
if
mask_crop
is
None
else
mask_crop
[:
n_targets
],
)
if
sampling_mode
==
RenderSamplingMode
.
MASK_SAMPLE
:
if
self
.
output_rasterized_mc
:
# Visualize the monte-carlo pixel renders by splatting onto
# an image grid.
(
preds
[
"images_render"
],
preds
[
"depths_render"
],
preds
[
"masks_render"
],
)
=
self
.
_rasterize_mc_samples
(
ray_bundle
.
xys
,
rendered
.
features
,
rendered
.
depths
,
masks
=
rendered
.
masks
,
)
elif
sampling_mode
==
RenderSamplingMode
.
FULL_GRID
:
preds
[
"images_render"
]
=
rendered
.
features
.
permute
(
0
,
3
,
1
,
2
)
preds
[
"depths_render"
]
=
rendered
.
depths
.
permute
(
0
,
3
,
1
,
2
)
preds
[
"masks_render"
]
=
rendered
.
masks
.
permute
(
0
,
3
,
1
,
2
)
preds
[
"nvs_prediction"
]
=
NewViewSynthesisPrediction
(
image_render
=
preds
[
"images_render"
],
depth_render
=
preds
[
"depths_render"
],
mask_render
=
preds
[
"masks_render"
],
)
else
:
raise
AssertionError
(
"Unreachable state"
)
# calc the AD penalty, returns None if autodecoder is not active
ad_penalty
=
self
.
sequence_autodecoder
.
calc_squared_encoding_norm
()
if
ad_penalty
is
not
None
:
preds
[
"loss_autodecoder_norm"
]
=
ad_penalty
# (7) Compute losses
# finally get the optimization objective using self.loss_weights
objective
=
self
.
_get_objective
(
preds
)
if
objective
is
not
None
:
preds
[
"objective"
]
=
objective
return
preds
def
_get_objective
(
self
,
preds
)
->
Optional
[
torch
.
Tensor
]:
"""
A helper function to compute the overall loss as the dot product
of individual loss functions with the corresponding weights.
"""
losses_weighted
=
[
preds
[
k
]
*
float
(
w
)
for
k
,
w
in
self
.
loss_weights
.
items
()
if
(
k
in
preds
and
w
!=
0.0
)
]
if
len
(
losses_weighted
)
==
0
:
warnings
.
warn
(
"No main objective found."
)
return
None
loss
=
sum
(
losses_weighted
)
assert
torch
.
is_tensor
(
loss
)
return
loss
def
visualize
(
self
,
viz
:
Visdom
,
visdom_env_imgs
:
str
,
preds
:
Dict
[
str
,
Any
],
prefix
:
str
,
)
->
None
:
"""
Helper function to visualize the predictions generated
in the forward pass.
Args:
viz: Visdom connection object
visdom_env_imgs: name of visdom environment for the images.
preds: predictions dict like returned by forward()
prefix: prepended to the names of images
"""
if
not
viz
.
check_connection
():
print
(
"no visdom server! -> skipping batch vis"
)
return
idx_image
=
0
title
=
f
"
{
prefix
}
_im
{
idx_image
}
"
vis_utils
.
visualize_basics
(
viz
,
preds
,
visdom_env_imgs
,
title
=
title
)
def
_render
(
self
,
*
,
ray_bundle
:
RayBundle
,
object_mask
:
Optional
[
torch
.
Tensor
],
sampling_mode
:
RenderSamplingMode
,
**
kwargs
,
)
->
RendererOutput
:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object
in the image. This is required for the SignedDistanceFunctionRenderer.
sampling_mode: The sampling method to use. Must be a value from the
RenderSamplingMode Enum.
Returns:
An instance of RendererOutput
"""
if
sampling_mode
==
RenderSamplingMode
.
FULL_GRID
and
self
.
chunk_size_grid
>
0
:
return
_apply_chunked
(
self
.
renderer
,
_chunk_generator
(
self
.
chunk_size_grid
,
ray_bundle
,
object_mask
,
self
.
tqdm_trigger_threshold
,
**
kwargs
,
),
lambda
batch
:
_tensor_collator
(
batch
,
ray_bundle
.
lengths
.
shape
[:
-
1
]),
)
else
:
# pyre-fixme[29]: `BaseRenderer` is not a function.
return
self
.
renderer
(
ray_bundle
=
ray_bundle
,
object_mask
=
object_mask
,
**
kwargs
,
)
def
_get_viewpooled_feature_dim
(
self
):
return
(
self
.
feature_aggregator
.
get_aggregated_feature_dim
(
self
.
image_feature_extractor
.
get_feat_dims
()
)
if
self
.
view_pool
else
0
)
def
_check_and_preprocess_renderer_configs
(
self
):
self
.
renderer_MultiPassEmissionAbsorptionRenderer_args
[
"stratified_sampling_coarse_training"
]
=
self
.
raysampler_args
[
"stratified_point_sampling_training"
]
self
.
renderer_MultiPassEmissionAbsorptionRenderer_args
[
"stratified_sampling_coarse_evaluation"
]
=
self
.
raysampler_args
[
"stratified_point_sampling_evaluation"
]
self
.
renderer_SignedDistanceFunctionRenderer_args
[
"render_features_dimensions"
]
=
self
.
render_features_dimensions
self
.
renderer_SignedDistanceFunctionRenderer_args
.
ray_tracer_args
[
"object_bounding_sphere"
]
=
self
.
raysampler_args
[
"scene_extent"
]
def
create_image_feature_extractor
(
self
):
"""
Custom creation function called by run_auto_creation so that the
image_feature_extractor is not created if it is not be needed.
"""
if
self
.
view_pool
:
self
.
image_feature_extractor
=
ResNetFeatureExtractor
(
**
self
.
image_feature_extractor_args
)
def
create_implicit_function
(
self
)
->
None
:
"""
No-op called by run_auto_creation so that self.implicit_function
does not get created. __post_init__ creates the implicit function(s)
in wrappers explicitly in self._implicit_functions.
"""
pass
def
_construct_implicit_functions
(
self
):
"""
After run_auto_creation has been called, the arguments
for each of the possible implicit function methods are
available. `GenericModel` arguments are first validated
based on the custom requirements for each specific
implicit function method. Then the required implicit
function(s) are initialized.
"""
# nerf preprocessing
nerf_args
=
self
.
implicit_function_NeuralRadianceFieldImplicitFunction_args
nerformer_args
=
self
.
implicit_function_NeRFormerImplicitFunction_args
nerf_args
[
"latent_dim"
]
=
nerformer_args
[
"latent_dim"
]
=
(
self
.
_get_viewpooled_feature_dim
()
+
self
.
sequence_autodecoder
.
get_encoding_dim
()
)
nerf_args
[
"color_dim"
]
=
nerformer_args
[
"color_dim"
]
=
self
.
render_features_dimensions
# idr preprocessing
idr
=
self
.
implicit_function_IdrFeatureField_args
idr
[
"feature_vector_size"
]
=
self
.
render_features_dimensions
idr
[
"encoding_dim"
]
=
self
.
sequence_autodecoder
.
get_encoding_dim
()
# srn preprocessing
srn
=
self
.
implicit_function_SRNImplicitFunction_args
srn
.
raymarch_function_args
.
latent_dim
=
(
self
.
_get_viewpooled_feature_dim
()
+
self
.
sequence_autodecoder
.
get_encoding_dim
()
)
# srn_hypernet preprocessing
srn_hypernet
=
self
.
implicit_function_SRNHyperNetImplicitFunction_args
srn_hypernet_args
=
srn_hypernet
.
hypernet_args
srn_hypernet_args
.
latent_dim_hypernet
=
(
self
.
sequence_autodecoder
.
get_encoding_dim
()
)
srn_hypernet_args
.
latent_dim
=
self
.
_get_viewpooled_feature_dim
()
# check that for srn, srn_hypernet, idr we have self.num_passes=1
implicit_function_type
=
registry
.
get
(
ImplicitFunctionBase
,
self
.
implicit_function_class_type
)
if
self
.
num_passes
!=
1
and
not
implicit_function_type
.
allows_multiple_passes
():
raise
ValueError
(
self
.
implicit_function_class_type
+
f
"requires num_passes=1 not
{
self
.
num_passes
}
"
)
if
implicit_function_type
.
requires_pooling_without_aggregation
():
has_aggregation
=
hasattr
(
self
.
feature_aggregator
,
"reduction_functions"
)
if
not
self
.
view_pool
or
has_aggregation
:
raise
ValueError
(
"Chosen implicit function requires view pooling without aggregation."
)
config_name
=
f
"implicit_function_
{
self
.
implicit_function_class_type
}
_args"
config
=
getattr
(
self
,
config_name
,
None
)
if
config
is
None
:
raise
ValueError
(
f
"
{
config_name
}
not present"
)
implicit_functions_list
=
[
ImplicitFunctionWrapper
(
implicit_function_type
(
**
config
))
for
_
in
range
(
self
.
num_passes
)
]
return
torch
.
nn
.
ModuleList
(
implicit_functions_list
)
def
print_loss_weights
(
self
)
->
None
:
"""
Print a table of the loss weights.
"""
print
(
"-------
\n
loss_weights:"
)
for
k
,
w
in
self
.
loss_weights
.
items
():
print
(
f
"
{
k
:
40
s
}
:
{
w
:
1.2
e
}
"
)
print
(
"-------"
)
def
_preprocess_input
(
self
,
image_rgb
:
Optional
[
torch
.
Tensor
],
fg_probability
:
Optional
[
torch
.
Tensor
],
depth_map
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""
Helper function to preprocess the input images and optional depth maps
to apply masking if required.
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
corresponding to the source viewpoints from which features will be extracted
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
of foreground masks with values in [0, 1].
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
Returns:
Modified image_rgb, fg_mask, depth_map
"""
fg_mask
=
fg_probability
if
fg_mask
is
not
None
and
self
.
mask_threshold
>
0.0
:
# threshold masks
warnings
.
warn
(
"Thresholding masks!"
)
fg_mask
=
(
fg_mask
>=
self
.
mask_threshold
).
type_as
(
fg_mask
)
if
self
.
mask_images
and
fg_mask
is
not
None
and
image_rgb
is
not
None
:
# mask the image
warnings
.
warn
(
"Masking images!"
)
image_rgb
=
image_utils
.
mask_background
(
image_rgb
,
fg_mask
,
dim_color
=
1
,
bg_color
=
torch
.
tensor
(
self
.
bg_color
)
)
if
self
.
mask_depths
and
fg_mask
is
not
None
and
depth_map
is
not
None
:
# mask the depths
assert
(
self
.
mask_threshold
>
0.0
),
"Depths should be masked only with thresholded masks"
warnings
.
warn
(
"Masking depths!"
)
depth_map
=
depth_map
*
fg_mask
return
image_rgb
,
fg_mask
,
depth_map
def
_get_view_metrics
(
self
,
raymarched
:
RendererOutput
,
xys
:
torch
.
Tensor
,
image_rgb
:
Optional
[
torch
.
Tensor
]
=
None
,
depth_map
:
Optional
[
torch
.
Tensor
]
=
None
,
fg_probability
:
Optional
[
torch
.
Tensor
]
=
None
,
mask_crop
:
Optional
[
torch
.
Tensor
]
=
None
,
keys_prefix
:
str
=
"loss_"
,
):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
metrics
=
self
.
view_metrics
(
image_sampling_grid
=
xys
,
images_pred
=
raymarched
.
features
,
images
=
image_rgb
,
depths_pred
=
raymarched
.
depths
,
depths
=
depth_map
,
masks_pred
=
raymarched
.
masks
,
masks
=
fg_probability
,
masks_crop
=
mask_crop
,
keys_prefix
=
keys_prefix
,
**
raymarched
.
aux
,
)
if
raymarched
.
prev_stage
:
metrics
.
update
(
self
.
_get_view_metrics
(
raymarched
.
prev_stage
,
xys
,
image_rgb
,
depth_map
,
fg_probability
,
mask_crop
,
keys_prefix
=
(
keys_prefix
+
"prev_stage_"
),
)
)
return
metrics
@
torch
.
no_grad
()
def
_rasterize_mc_samples
(
self
,
xys
:
torch
.
Tensor
,
features
:
torch
.
Tensor
,
depth
:
torch
.
Tensor
,
masks
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Rasterizes Monte-Carlo features back onto the image.
Args:
xys: B x ... x 2 2D point locations in PyTorch3D NDC convention
features: B x ... x C tensor containing per-point rendered features.
depth: B x ... x 1 tensor containing per-point rendered depth.
"""
ba
=
xys
.
shape
[
0
]
# Flatten the features and xy locations.
features_depth_ras
=
torch
.
cat
(
(
features
.
reshape
(
ba
,
-
1
,
features
.
shape
[
-
1
]),
depth
.
reshape
(
ba
,
-
1
,
1
),
),
dim
=-
1
,
)
xys_ras
=
xys
.
reshape
(
ba
,
-
1
,
2
)
if
masks
is
not
None
:
masks_ras
=
masks
.
reshape
(
ba
,
-
1
,
1
)
else
:
masks_ras
=
None
if
min
(
self
.
render_image_height
,
self
.
render_image_width
)
<=
0
:
raise
ValueError
(
"Need to specify a positive"
" self.render_image_height and self.render_image_width"
" for MC rasterisation."
)
# Estimate the rasterization point radius so that we approximately fill
# the whole image given the number of rasterized points.
pt_radius
=
2.0
*
math
.
sqrt
(
xys
.
shape
[
1
])
# Rasterize the samples.
features_depth_render
,
masks_render
=
rasterize_mc_samples
(
xys_ras
,
features_depth_ras
,
(
self
.
render_image_height
,
self
.
render_image_width
),
radius
=
pt_radius
,
masks
=
masks_ras
,
)
images_render
=
features_depth_render
[:,
:
-
1
]
depths_render
=
features_depth_render
[:,
-
1
:]
return
images_render
,
depths_render
,
masks_render
def _apply_chunked(func, chunk_generator, tensor_collator):
    """
    Helper function to apply a function on a sequence of
    chunked inputs yielded by a generator and collate
    the result.
    """
    processed_chunks = [
        func(*chunk_args, **chunk_kwargs)
        for chunk_args, chunk_kwargs in chunk_generator
    ]

    return cat_dataclass(processed_chunks, tensor_collator)


def _tensor_collator(batch, new_dims) -> torch.Tensor:
    """
    Helper function to reshape the batch to the desired shape
    """
    return torch.cat(batch, dim=1).reshape(*new_dims, -1)


def _chunk_generator(
    chunk_size: int,
    ray_bundle: RayBundle,
    object_mask: Optional[torch.Tensor],
    tqdm_trigger_threshold: int,
    *args,
    **kwargs,
):
    """
    Helper function which yields chunks of rays from the
    input ray_bundle, to be used when the number of rays is
    large and will not fit in memory for rendering.
    """
    (
        batch_size,
        *spatial_dim,
        n_pts_per_ray,
    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray

    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError(
            f"chunk_size_grid ({chunk_size}) should be divisible "
            f"by n_pts_per_ray ({n_pts_per_ray})"
        )

    n_rays = math.prod(spatial_dim)
    # special handling for raytracing-based methods
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)

    iter = range(0, n_rays, chunk_size_in_rays)
    if len(iter) >= tqdm_trigger_threshold:
        iter = tqdm.tqdm(iter)

    for start_idx in iter:
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
        ray_bundle_chunk = RayBundle(
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            lengths=ray_bundle.lengths.reshape(
                batch_size, math.prod(spatial_dim), n_pts_per_ray
            )[:, start_idx:end_idx],
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
        )
        extra_args = kwargs.copy()
        if object_mask is not None:
            extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[
                :, start_idx:end_idx
            ]
        yield [ray_bundle_chunk, *args], extra_args
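The chunking arithmetic in `_chunk_generator` uses the `-(-a // b)` ceiling-division idiom. A small worked example with toy numbers (not taken from any shipped config) shows how the ray chunks come out for a full-grid render:

import math

height, width, n_pts_per_ray = 400, 400, 64
chunk_size = 4096 * 64  # chunk_size_grid must be divisible by n_pts_per_ray

n_rays = height * width  # 160_000 rays in the full image grid
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)  # ceil division -> 40
chunk_size_in_rays = -(-n_rays // n_chunks)  # ceil division -> 4_000 rays per chunk
assert n_chunks * chunk_size_in_rays >= n_rays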
pytorch3d/implicitron/models/implicit_function/__init__.py (new file, mode 100644, commit cdd2142d)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
pytorch3d/implicitron/models/implicit_function/base.py (new file, mode 100644, commit cdd2142d)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Optional

from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle


class ImplicitFunctionBase(ABC, ReplaceableBase):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        raise NotImplementedError()

    @staticmethod
    def allows_multiple_passes() -> bool:
        """
        Returns True if this implicit function allows
        multiple passes.
        """
        return False

    @staticmethod
    def requires_pooling_without_aggregation() -> bool:
        """
        Returns True if this implicit function needs
        pooling without aggregation.
        """
        return False

    def on_bind_args(self) -> None:
        """
        Called when the custom args are fixed in the main model forward pass.
        """
        pass
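To add a custom implicit function to Implicitron, a subclass of `ImplicitFunctionBase` is registered with the config registry and then selected through the `implicit_function_class_type` setting of `generic_model_args`. The sketch below is hedged: it assumes the `registry.register` decorator from `pytorch3d.implicitron.tools.config` and the field-plus-`__post_init__` style used by the other configurables, and the tuple return convention is an assumption rather than a documented contract.

# Hedged sketch of a custom implicit-function plugin; see the note above for
# which names are assumptions. It returns a constant density and a constant
# gray color for every sampled ray point.
from typing import Optional

import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle


@registry.register
class ConstantColorImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    density: float = 1.0  # configurable field, exposed to the config system

    def __post_init__(self):
        super().__init__()

    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        # one density value and one RGB color per sampled point along each ray
        shape = ray_bundle.lengths.shape
        densities = torch.full((*shape, 1), self.density)
        colors = torch.full((*shape, 3), 0.5)
        return densities, colors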