Unverified Commit cdd2142d authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by GitHub
Browse files

implicitron v0 (#1133)


Co-authored-by: default avatarJeremy Francis Reizenstein <bottler@users.noreply.github.com>
parent 0e377c68
defaults:
- repro_singleseq_base.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: false
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05
defaults:
- repro_singleseq_srn.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0
defaults:
- repro_singleseq_wce_base
- repro_feat_extractor_normed.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: true
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05
defaults:
- repro_singleseq_srn_wce.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0
defaults:
- repro_singleseq_base
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
""""
This file is the entry point for launching experiments with Implicitron.
Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
pass, visualizations and metric printing
Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:
```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```
To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.
Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
- Checkpoints:
Model, optimizer and stats are stored in the directory
named by the `exp_dir` key from the config file / CLI parameters.
- Stats
Stats are logged and plotted to the file "train_stats.pdf" in the
same directory. The stats are also saved as part of the checkpoint file.
- Visualizations
Prredictions are plotted to a visdom server running at the
port specified by the `visdom_server` and `visdom_port` keys in the
config file.
"""
import copy
import json
import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import hydra
import lpips
import numpy as np
import torch
import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
ImplicitronDataset,
FrameData,
)
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
get_default_args_field,
remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
logger = logging.getLogger(__name__)
if version.parse(hydra.__version__) < version.Version("1.1"):
raise ValueError(
f"Hydra version {hydra.__version__} is too old."
" (Implicitron requires version 1.1 or later.)"
)
try:
# only makes sense in FAIR cluster
import pytorch3d.implicitron.fair_cluster.slurm # noqa: F401
except ModuleNotFoundError:
pass
def init_model(
cfg: DictConfig,
force_load: bool = False,
clear_stats: bool = False,
load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
"""
Returns an instance of `GenericModel`.
If `cfg.resume` is set or `force_load` is true,
attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so
will return the model with initial weights, unless `force_load` is passed,
in which case a FileNotFoundError is raised.
Args:
force_load: If true, force load model from checkpoint even if
cfg.resume is false.
clear_stats: If true, clear the stats object loaded from checkpoint
load_model_only: If true, load only the model weights from checkpoint
and do not load the state of the optimizer and stats.
Returns:
model: The model with optionally loaded weights from checkpoint
stats: The stats structure (optionally loaded from checkpoint)
optimizer_state: The optimizer state dict containing
`state` and `param_groups` keys (optionally loaded from checkpoint)
Raise:
FileNotFoundError if `force_load` is passed but checkpoint is not found.
"""
# Initialize the model
if cfg.architecture == "generic":
model = GenericModel(**cfg.generic_model_args)
else:
raise ValueError(f"No such arch {cfg.architecture}.")
# Determine the network outputs that should be logged
if hasattr(model, "log_vars"):
log_vars = copy.deepcopy(list(model.log_vars))
else:
log_vars = ["objective"]
visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"
# Init the stats struct
stats = Stats(
log_vars,
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=cfg.visdom_server,
visdom_port=cfg.visdom_port,
)
# Retrieve the last checkpoint
if cfg.resume_epoch > 0:
model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(cfg.exp_dir)
optimizer_state = None
if model_path is not None:
logger.info("found previous model %s" % model_path)
if force_load or cfg.resume:
logger.info(" -> resuming")
if load_model_only:
model_state_dict = torch.load(model_io.get_model_path(model_path))
stats_load, optimizer_state = None, None
else:
model_state_dict, stats_load, optimizer_state = model_io.load_model(
model_path
)
# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load
# Update stats properties incase it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = cfg.visdom_server
stats.visdom_port = cfg.visdom_port
stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")
try:
# TODO: fix on creation of the buffers
# after the hack above, this will not pass in most cases
# ... but this is fine for now
model.load_state_dict(model_state_dict, strict=True)
except RuntimeError as e:
logger.error(e)
logger.info("Cant load state dict in strict mode! -> trying non-strict")
model.load_state_dict(model_state_dict, strict=False)
model.log_vars = log_vars
else:
logger.info(" -> but not resuming -> starting from scratch")
elif force_load:
raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")
return model, stats, optimizer_state
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: bool = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep:
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raise:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Intialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
# Add the max epochs here
scheduler.max_epochs = max_epochs
optimizer.zero_grad()
return optimizer, scheduler
def trainvalidate(
model,
stats,
epoch,
loader,
optimizer,
validation,
bp_var: str = "objective",
metric_print_interval: int = 5,
visualize_interval: int = 100,
visdom_env_root: str = "trainvalidate",
clip_grad: float = 0.0,
device: str = "cuda:0",
**kwargs,
) -> None:
"""
This is the main loop for training and evaluation including:
model forward pass, loss computation, backward pass and visualization.
Args:
model: The model module optionally loaded from checkpoint
stats: The stats struct, also optionally loaded from checkpoint
epoch: The index of the current epoch
loader: The dataloader to use for the loop
optimizer: The optimizer module optionally loaded from checkpoint
validation: If true, run the loop with the model in eval mode
and skip the backward pass
bp_var: The name of the key in the model output `preds` dict which
should be used as the loss for the backward pass.
metric_print_interval: The batch interval at which the stats should be
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
visdom_env_root: The name of the visdom environment to use for plotting
clip_grad: Optionally clip the gradient norms.
If set to a value <=0.0, no clipping
device: The device on which to run the model.
Returns:
None
"""
if validation:
model.eval()
trainmode = "val"
else:
model.train()
trainmode = "train"
t_start = time.time()
# get the visdom env name
visdom_env_imgs = visdom_env_root + "_images_" + trainmode
viz = vis_utils.get_visdom_connection(
server=stats.visdom_server,
port=stats.visdom_port,
)
# Iterate through the batches
n_batches = len(loader)
for it, batch in enumerate(loader):
last_iter = it == n_batches - 1
# move to gpu where possible (in place)
net_input = batch.to(device)
# run the forward pass
if not validation:
optimizer.zero_grad()
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
else:
with torch.no_grad():
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
)
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
# merge everything into one big dict
preds.update(net_input)
# update the stats logger
stats.update(preds, time_start=t_start, stat_set=trainmode)
assert stats.it[trainmode] == it, "inconsistent stat iteration number!"
# print textual status update
if it % metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results
if visualize_interval > 0 and it % visualize_interval == 0:
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
model.visualize(
viz,
visdom_env_imgs,
preds,
prefix,
)
# optimizer step
if not validation:
loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop
loss.backward()
if clip_grad > 0.0:
# Optionally clip the gradient norms.
total_norm = torch.nn.utils.clip_grad_norm(
model.parameters(), clip_grad
)
if total_norm > clip_grad:
logger.info(
f"Clipping gradient: {total_norm}"
+ f" with coef {clip_grad / total_norm}."
)
optimizer.step()
def run_training(cfg: DictConfig, device: str = "cpu"):
"""
Entry point to run the training and validation loops
based on the specified config file.
"""
# set the debug mode
if cfg.detect_anomaly:
logger.info("Anomaly detection!")
torch.autograd.set_detect_anomaly(cfg.detect_anomaly)
# create the output folder
os.makedirs(cfg.exp_dir, exist_ok=True)
_seed_all_random_engines(cfg.seed)
remove_unused_components(cfg)
# dump the exp config to the exp dir
try:
cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
OmegaConf.save(config=cfg, f=cfg_filename)
except PermissionError:
warnings.warn("Cant dump config due to insufficient permissions!")
# setup datasets
datasets = dataset_zoo(**cfg.dataset_args)
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
# init the model
model, stats, optimizer_state = init_model(cfg)
start_epoch = stats.epoch + 1
# move model to gpu
model.to(device)
# only run evaluation on the test dataloader
if cfg.eval_only:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
return
# init the optimizer
optimizer, scheduler = init_optimizer(
model,
optimizer_state=optimizer_state,
last_epoch=start_epoch,
**cfg.solver_args,
)
# check the scheduler and stats have been initialized correctly
assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch
past_scheduler_lrs = []
# loop through epochs
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
# automatic new_epoch and plotting of stats at every epoch start
with stats:
# Make sure to re-seed random generators to ensure reproducibility
# even after restart.
_seed_all_random_engines(cfg.seed + epoch)
cur_lr = float(scheduler.get_last_lr()[-1])
logger.info(f"scheduler lr = {cur_lr:1.2e}")
past_scheduler_lrs.append(cur_lr)
# train loop
trainvalidate(
model,
stats,
epoch,
dataloaders["train"],
optimizer,
False,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
**cfg,
)
# val loop (optional)
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
trainvalidate(
model,
stats,
epoch,
dataloaders["val"],
optimizer,
True,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
**cfg,
)
# eval loop (optional)
if (
"test" in dataloaders
and cfg.test_interval > 0
and epoch % cfg.test_interval == 0
):
run_eval(cfg, model, stats, dataloaders["test"], device=device)
assert stats.epoch == epoch, "inconsistent stats!"
# delete previous models if required
# save model
if cfg.store_checkpoints:
if cfg.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)
scheduler.step()
new_lr = float(scheduler.get_last_lr()[-1])
if new_lr != cur_lr:
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if cfg.test_when_finished:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
"""
Run the evaluation loop with the test data loader and
save the predictions to the `exp_dir`.
"""
if "test" not in dataloaders:
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
all_source_cameras = (
_get_all_source_cameras(datasets["train"])
if eval_task == "singlesequence"
else None
)
results = run_eval(
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
)
# add the evaluation epoch to the results
for r in results:
r["eval_epoch"] = int(stats.epoch)
logger.info("Evaluation results")
evaluate.pretty_print_nvs_metrics(results)
with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
json.dump(results, f)
def _get_eval_frame_data(frame_data):
"""
Masks the unknown image data to make sure we cannot use it at model evaluation time.
"""
frame_data_for_eval = copy.deepcopy(frame_data)
is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
frame_data.image_rgb
)[:, None, None, None]
for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
value_masked = getattr(frame_data_for_eval, k).clone() * is_known
setattr(frame_data_for_eval, k, value_masked)
return frame_data_for_eval
def run_eval(cfg, model, all_source_cameras, loader, task, device):
"""
Run the evaluation loop on the test dataloader
"""
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.to(device)
model.eval()
per_batch_eval_results = []
logger.info("Evaluating model ...")
for frame_data in tqdm.tqdm(loader):
frame_data = frame_data.to(device)
# mask out the unknown images so that the model does not see them
frame_data_for_eval = _get_eval_frame_data(frame_data)
with torch.no_grad():
preds = model(
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
)
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
per_batch_eval_results.append(
evaluate.eval_batch(
frame_data,
nvs_prediction,
bg_color="black",
lpips_model=lpips_model,
source_cameras=all_source_cameras,
)
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task
)
return category_result["results"]
def _get_all_source_cameras(
dataset: ImplicitronDataset,
num_workers: int = 8,
) -> CamerasBase:
"""
Load and return all the source cameras in the training dataset
"""
all_frame_data = next(
iter(
torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=len(dataset),
num_workers=num_workers,
collate_fn=FrameData.collate,
)
)
)
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
return source_cameras
def _seed_all_random_engines(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
@dataclass(eq=False)
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)
@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
# Set the device
device = "cpu"
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
device = f"cuda:{cfg.gpu_idx}"
logger.info(f"Running experiment on device: {device}")
run_training(cfg, device)
if __name__ == "__main__":
experiment()
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Script to visualize a previously trained model. Example call:
projects/implicitron_trainer/visualize_reconstruction.py
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""
import math
import os
import random
import sys
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.tools.configurable import get_default_args
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
get_visdom_connection,
make_depth_image,
)
from tqdm import tqdm
def render_sequence(
dataset: ImplicitronDataset,
sequence_name: str,
model: torch.nn.Module,
video_path,
n_eval_cameras=40,
fps=20,
max_angle=2 * math.pi,
trajectory_type="circular_lsq_fit",
trajectory_scale=1.1,
scene_center=(0.0, 0.0, 0.0),
up=(0.0, -1.0, 0.0),
traj_offset=0.0,
n_source_views=9,
viz_env="debug",
visdom_show_preds=False,
visdom_server="http://127.0.0.1",
visdom_port=8097,
num_workers=10,
seed=None,
video_resize=None,
):
if seed is None:
seed = hash(sequence_name)
print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = dataset.seq_to_idx[sequence_name]
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
print(f"Sequence set = {sequence_set_name}.")
train_cameras = train_data.camera
time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
test_cameras = generate_eval_video_cameras(
train_cameras,
time=time,
n_eval_cams=n_eval_cameras,
trajectory_type=trajectory_type,
trajectory_scale=trajectory_scale,
scene_center=scene_center,
up=up,
focal_length=None,
principal_point=torch.zeros(n_eval_cameras, 2),
traj_offset_canonical=[0.0, 0.0, traj_offset],
)
# sample the source views reproducibly
with torch.random.fork_rng():
torch.manual_seed(seed)
source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
# add the first dummy view that will get replaced with the target camera
source_views_i = Fu.pad(source_views_i, [1, 0])
source_views = [seq_idx[i] for i in source_views_i.tolist()]
batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
preds_total = []
for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
# set the first batch camera to the target camera
for k in ("R", "T", "focal_length", "principal_point"):
getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
# Move to cuda
net_input = batch.cuda()
with torch.no_grad():
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
preds.update(net_input) # merge everything into one big dict
# Render the predictions to images
rendered_pred = images_from_preds(preds)
preds_total.append(rendered_pred)
# show the preds every 5% of the export iterations
if visdom_show_preds and (
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
):
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
show_predictions(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
)
print(f"Exporting videos for sequence {sequence_name} ...")
generate_prediction_videos(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
fps=fps,
video_path=video_path,
resize=video_resize,
)
def _load_whole_dataset(dataset, idx, num_workers=10):
load_all_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, idx),
batch_size=len(idx),
num_workers=num_workers,
shuffle=False,
collate_fn=FrameData.collate,
)
return next(iter(load_all_dataloader))
def images_from_preds(preds):
imout = {}
for k in (
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
"_all_source_images",
):
if k == "_all_source_images" and "image_rgb" in preds:
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
v = _stack_images(src_ims, None)[None]
else:
if k not in preds or preds[k] is None:
print(f"cant show {k}")
continue
v = preds[k].cpu().detach().clone()
if k.startswith("depth"):
mask_resize = Fu.interpolate(
preds["masks_render"],
size=preds[k].shape[2:],
mode="nearest",
)
v = make_depth_image(preds[k], mask_resize)
if v.shape[1] == 1:
v = v.repeat(1, 3, 1, 1)
imout[k] = v.detach().cpu()
return imout
def _stack_images(ims, size):
ba = ims.shape[0]
H = int(np.ceil(np.sqrt(ba)))
W = H
n_add = H * W - ba
if n_add > 0:
ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
ims = ims.view(H, W, *ims.shape[1:])
cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
if size is not None:
cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
return cated.clamp(0.0, 1.0)
def show_predictions(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
n_samples=10,
one_image_width=200,
):
"""Given a list of predictions visualize them into a single image using visdom."""
assert isinstance(preds, list)
pred_all = []
# Randomly choose a subset of the rendered images, sort by ordr in the sequence
n_samples = min(n_samples, len(preds))
pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
for predi in pred_idx:
# Make the concatentation for the same camera vertically
pred_all.append(
torch.cat(
[
torch.nn.functional.interpolate(
preds[predi][k].cpu(),
scale_factor=one_image_width / preds[predi][k].shape[3],
mode="bilinear",
).clamp(0.0, 1.0)
for k in predicted_keys
],
dim=2,
)
)
# Concatenate the images horizontally
pred_all_cat = torch.cat(pred_all, dim=3)[0]
viz.image(
pred_all_cat,
win="show_predictions",
env=viz_env,
opts={"title": f"pred_{sequence_name}"},
)
def generate_prediction_videos(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
fps=20,
video_path="/tmp/video",
resize=None,
):
"""Given a list of predictions create and visualize rotating videos of the
objects using visdom.
"""
assert isinstance(preds, list)
# make sure the target video directory exists
os.makedirs(os.path.dirname(video_path), exist_ok=True)
# init a video writer for each predicted key
vws = {}
for k in predicted_keys:
vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
for rendered_pred in tqdm(preds):
for k in predicted_keys:
vws[k].write_frame(
rendered_pred[k][0].detach().cpu().numpy(),
resize=resize,
)
for k in predicted_keys:
vws[k].get_video(quiet=True)
print(f"Generated {vws[k].out_path}.")
viz.video(
videofile=vws[k].out_path,
env=viz_env,
win=k, # we reuse the same window otherwise visdom dies
opts={"title": sequence_name + " " + k},
)
def export_scenes(
exp_dir: str = "",
restrict_sequence_name: Optional[str] = None,
output_directory: Optional[str] = None,
render_size: Tuple[int, int] = (512, 512),
video_size: Optional[Tuple[int, int]] = None,
split: str = "train", # train | test
n_source_views: int = 9,
n_eval_cameras: int = 40,
visdom_server="http://127.0.0.1",
visdom_port=8097,
visdom_show_preds: bool = False,
visdom_env: Optional[str] = None,
gpu_idx: int = 0,
):
# In case an output directory is specified use it. If no output_directory
# is specified create a vis folder inside the experiment directory
if output_directory is None:
output_directory = os.path.join(exp_dir, "vis")
else:
output_directory = output_directory
if not os.path.exists(output_directory):
os.makedirs(output_directory)
# Set the random seeds
torch.manual_seed(0)
np.random.seed(0)
# Get the config from the experiment_directory,
# and overwrite relevant fields
config = _get_config_from_experiment_directory(exp_dir)
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.dataset_args.restrict_sequence_name = restrict_sequence_name
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
# Load the previously trained model
model, _, _ = init_model(config, force_load=True, load_model_only=True)
model.cuda()
model.eval()
# Setup the dataset
dataset = dataset_zoo(**config.dataset_args)[split]
# iterate over the sequences in the dataset
for sequence_name in dataset.seq_to_idx.keys():
with torch.no_grad():
render_sequence(
dataset,
sequence_name,
model,
video_path="{}/video".format(output_directory),
n_source_views=n_source_views,
visdom_show_preds=visdom_show_preds,
n_eval_cameras=n_eval_cameras,
visdom_server=visdom_server,
visdom_port=visdom_port,
viz_env=f"visualizer_{config.visdom_env}"
if visdom_env is None
else visdom_env,
video_resize=video_size,
)
def _get_config_from_experiment_directory(experiment_directory):
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
config = OmegaConf.load(cfg_file)
return config
def main(argv):
# automatically parses arguments of export_scenes
cfg = OmegaConf.create(get_default_args(export_scenes))
cfg.update(OmegaConf.from_cli())
with torch.no_grad():
export_scenes(**cfg)
if __name__ == "__main__":
main(sys.argv)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Sequence
import torch
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
from .scene_batch_sampler import SceneBatchSampler
def dataloader_zoo(
datasets: Dict[str, ImplicitronDatasetBase],
dataset_name: str = "co3d_singlesequence",
batch_size: int = 1,
num_workers: int = 0,
dataset_len: int = 1000,
dataset_len_val: int = 1,
images_per_seq_options: Sequence[int] = (2,),
sample_consecutive_frames: bool = False,
consecutive_frames_max_gap: int = 0,
consecutive_frames_max_gap_seconds: float = 0.1,
) -> Dict[str, torch.utils.data.DataLoader]:
"""
Returns a set of dataloaders for a given set of datasets.
Args:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
dataset_name: The name of the returned dataset.
batch_size: The size of the batch of the dataloader.
num_workers: Number data-loading threads.
dataset_len: The number of batches in a training epoch.
dataset_len_val: The number of batches in a validation epoch.
images_per_seq_options: Possible numbers of images sampled per sequence.
sample_consecutive_frames: if True, will sample a contiguous interval of frames
in the sequence. It first sorts the frames by timestimps when available,
otherwise by frame numbers, finds the connected segments within the sequence
of sufficient length, then samples a random pivot element among them and
ideally uses it as a middle of the temporal window, shifting the borders
where necessary. This strategy mitigates the bias against shorter segments
and their boundaries.
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
difference in frame_number of neighbouring frames when forming connected
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
the whole sequence is considered a segment regardless of frame numbers.
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
maximum difference in frame_timestamp of neighbouring frames when forming
connected segments; if both this and consecutive_frames_max_gap are 0s,
the whole sequence is considered a segment regardless of frame timestamps.
Returns:
dataloaders: A dictionary containing the
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
"""
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
raise ValueError(f"Unsupported dataset: {dataset_name}")
dataloaders = {}
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
for dataset_set, dataset in datasets.items():
num_samples = {
"train": dataset_len,
"val": dataset_len_val,
"test": None,
}[dataset_set]
if dataset_set == "test":
batch_sampler = dataset.get_eval_batches()
else:
assert num_samples is not None
num_samples = len(dataset) if num_samples <= 0 else num_samples
batch_sampler = SceneBatchSampler(
dataset,
batch_size,
num_batches=num_samples,
images_per_seq_options=images_per_seq_options,
sample_consecutive_frames=sample_consecutive_frames,
consecutive_frames_max_gap=consecutive_frames_max_gap,
)
dataloaders[dataset_set] = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=FrameData.collate,
)
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")
return dataloaders
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import json
import os
from typing import Any, Dict, List, Optional, Sequence
from iopath.common.file_io import PathManager
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
from .utils import (
DATASET_TYPE_KNOWN,
DATASET_TYPE_TEST,
DATASET_TYPE_TRAIN,
DATASET_TYPE_UNKNOWN,
)
# TODO from dataset.dataset_configs import DATASET_CONFIGS
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
"default": {
"box_crop": True,
"box_crop_context": 0.3,
"image_width": 800,
"image_height": 800,
"remove_empty_masks": True,
}
}
# fmt: off
CO3D_CATEGORIES: List[str] = list(reversed([
"baseballbat", "banana", "bicycle", "microwave", "tv",
"cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
"umbrella", "wineglass", "laptop",
"hotdog", "stopsign", "frisbee", "baseballglove",
"cup", "parkingmeter", "backpack", "toyplane", "toybus",
"handbag", "chair", "keyboard", "car", "motorcycle",
"carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
"toaster", "mouse", "toytrain", "book", "toytruck",
"orange", "broccoli", "plant", "teddybear",
"suitcase", "bench", "ball", "cake",
"vase", "hydrant", "apple", "donut",
]))
# fmt: on
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
def dataset_zoo(
dataset_name: str = "co3d_singlesequence",
dataset_root: str = _CO3D_DATASET_ROOT,
category: str = "DEFAULT",
limit_to: int = -1,
limit_sequences_to: int = -1,
n_frames_per_sequence: int = -1,
test_on_train: bool = False,
load_point_clouds: bool = False,
mask_images: bool = False,
mask_depths: bool = False,
restrict_sequence_name: Sequence[str] = (),
test_restrict_sequence_id: int = -1,
assert_single_seq: bool = False,
only_test_set: bool = False,
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
path_manager: Optional[PathManager] = None,
) -> Dict[str, ImplicitronDatasetBase]:
"""
Generates the training / validation and testing dataset objects.
Args:
dataset_name: The name of the returned dataset.
dataset_root: The root folder of the dataset.
category: The object category of the dataset.
limit_to: Limit the dataset to the first #limit_to frames.
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences.
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
in each sequence.
test_on_train: Construct validation and test datasets from
the training subset.
load_point_clouds: Enable returning scene point clouds from the dataset.
mask_images: Mask the loaded images with segmentation masks.
mask_depths: Mask the loaded depths with segmentation masks.
restrict_sequence_name: Restrict the dataset sequences to the ones
present in the given list of names.
test_restrict_sequence_id: The ID of the loaded sequence.
Active for dataset_name='co3d_singlesequence'.
assert_single_seq: Assert that only frames from a single sequence
are present in all generated datasets.
only_test_set: Load only the test set.
aux_dataset_kwargs: Specifies additional arguments to the
ImplicitronDataset constructor call.
Returns:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
"""
datasets = {}
# TODO:
# - implement loading multiple categories
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
# This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping(
dataset_name,
test_on_train,
only_test_set,
)
# load the evaluation batches
task = dataset_name.split("_")[-1]
batch_indices_path = os.path.join(
dataset_root,
category,
f"eval_batches_{task}.json",
)
if not os.path.isfile(batch_indices_path):
# The batch indices file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError("Please specify a correct dataset_root folder.")
with open(batch_indices_path, "r") as f:
eval_batch_index = json.load(f)
if task == "singlesequence":
assert (
test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0
), (
"Please specify an integer id 'test_restrict_sequence_id'"
+ " of the sequence considered for 'singlesequence'"
+ " training and evaluation."
)
assert len(restrict_sequence_name) == 0, (
"For the 'singlesequence' task, the restrict_sequence_name has"
" to be unset while test_restrict_sequence_id has to be set to an"
" integer defining the order of the evaluation sequence."
)
# a sort-stable set() equivalent:
eval_batches_sequence_names = list(
{b[0][0]: None for b in eval_batch_index}.keys()
)
eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
eval_batch_index = [
b for b in eval_batch_index if b[0][0] == eval_sequence_name
]
# overwrite the restrict_sequence_name
restrict_sequence_name = [eval_sequence_name]
for dataset, subsets in set_names_mapping.items():
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
assert os.path.isfile(frame_file)
sequence_file = os.path.join(
dataset_root, category, "sequence_annotations.jgz"
)
assert os.path.isfile(sequence_file)
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
assert os.path.isfile(subset_lists_file)
# TODO: maybe directly in param list
params = {
**copy.deepcopy(aux_dataset_kwargs),
"frame_annotations_file": frame_file,
"sequence_annotations_file": sequence_file,
"subset_lists_file": subset_lists_file,
"dataset_root": dataset_root,
"limit_to": limit_to,
"limit_sequences_to": limit_sequences_to,
"n_frames_per_sequence": n_frames_per_sequence
if dataset == "train"
else -1,
"subsets": subsets,
"load_point_clouds": load_point_clouds,
"mask_images": mask_images,
"mask_depths": mask_depths,
"pick_sequence": restrict_sequence_name,
"path_manager": path_manager,
}
datasets[dataset] = ImplicitronDataset(**params)
if dataset == "test":
if len(restrict_sequence_name) > 0:
eval_batch_index = [
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
]
datasets[dataset].eval_batches = datasets[
dataset
].seq_frame_index_to_dataset_index(eval_batch_index)
if assert_single_seq:
# check theres only one sequence in all datasets
assert (
len(
{
e["frame_annotation"].sequence_name
for dset in datasets.values()
for e in dset.frame_annots
}
)
<= 1
), "Multiple sequences loaded but expected one"
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")
if test_on_train:
datasets["val"] = datasets["train"]
datasets["test"] = datasets["train"]
return datasets
def _get_co3d_set_names_mapping(
dataset_name: str,
test_on_train: bool,
only_test: bool,
) -> Dict[str, List[str]]:
"""
Returns the mapping of the common dataset subset names ("train"/"val"/"test")
to the names of the corresponding subsets in the CO3D dataset
("test_known"/"test_unseen"/"train_known"/"train_unseen").
"""
single_seq = dataset_name == "co3d_singlesequence"
if only_test:
set_names_mapping = {}
else:
set_names_mapping = {
"train": [
(DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN)
+ "_"
+ DATASET_TYPE_KNOWN
]
}
if not test_on_train:
prefixes = [DATASET_TYPE_TEST]
if not single_seq:
prefixes.append(DATASET_TYPE_TRAIN)
set_names_mapping.update(
{
dset: [
p + "_" + t
for p in prefixes
for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
]
for dset in ["val", "test"]
}
)
return set_names_mapping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import gzip
import hashlib
import json
import os
import random
import warnings
from collections import defaultdict
from dataclasses import dataclass, field, fields
from itertools import islice
from pathlib import Path
from typing import (
ClassVar,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
)
import numpy as np
import torch
from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.io import IO
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
from . import types
@dataclass
class FrameData:
"""
A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches of thereof;
in this documentation, the sizes of tensors refer to single frames;
add the first batch dimension for the collation result.
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
frame_timestamp: The time elapsed since the start of a sequence in sec.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
image_size_hw: The size of the image in pixels; (height, width) tuple.
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
are a result of zero-padding of the image after cropping around
the object bounding box; elements are floats in {0.0, 1.0}.
depth_path: The qualified path to the frame's depth map.
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
of the frame; values correspond to distances from the camera;
use `depth_mask` and `mask_crop` to filter for valid pixels.
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
depth map that are valid for evaluation, they have been checked for
consistency across views; elements are floats in {0.0, 1.0}.
mask_path: A qualified path to the foreground probability mask.
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
pixels belonging to the captured object; elements are floats
in [0, 1].
bbox_xywh: The bounding box capturing the object in the
format (x0, y0, width, height).
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
corrected for cropping if it happened.
camera_quality_score: The score proportional to the confidence of the
frame's camera estimation (the higher the more accurate).
point_cloud_quality_score: The score proportional to the accuracy of the
frame's sequence point cloud (the higher the more accurate).
sequence_point_cloud_path: The path to the sequence's point cloud.
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
point cloud corresponding to the frame's sequence. When the object
represents a batch of frames, point clouds may be deduplicated;
see `sequence_point_cloud_idx`.
sequence_point_cloud_idx: Integer indices mapping frame indices to the
corresponding point clouds in `sequence_point_cloud`; to get the
corresponding point cloud to `image_rgb[i]`, use
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
"""
frame_number: Optional[torch.LongTensor]
frame_timestamp: Optional[torch.Tensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
image_size_hw: Optional[torch.Tensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
# masks out padding added due to cropping the square bit
mask_crop: Optional[torch.Tensor] = None
depth_path: Union[str, List[str], None] = None
depth_map: Optional[torch.Tensor] = None
depth_mask: Optional[torch.Tensor] = None
mask_path: Union[str, List[str], None] = None
fg_probability: Optional[torch.Tensor] = None
bbox_xywh: Optional[torch.Tensor] = None
camera: Optional[PerspectiveCameras] = None
camera_quality_score: Optional[torch.Tensor] = None
point_cloud_quality_score: Optional[torch.Tensor] = None
sequence_point_cloud_path: Union[str, List[str], None] = None
sequence_point_cloud: Optional[Pointclouds] = None
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # seen | unseen
meta: dict = field(default_factory=lambda: {})
def to(self, *args, **kwargs):
new_params = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
new_params[f.name] = value.to(*args, **kwargs)
else:
new_params[f.name] = value
return type(self)(**new_params)
def cpu(self):
return self.to(device=torch.device("cpu"))
def cuda(self):
return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions
def keys(self):
for f in fields(self):
yield f.name
def __getitem__(self, key):
return getattr(self, key)
@classmethod
def collate(cls, batch):
"""
Given a list objects `batch` of class `cls`, collates them into a batched
representation suitable for processing with deep networks.
"""
elem = batch[0]
if isinstance(elem, cls):
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
id_to_idx = defaultdict(list)
for i, pc_id in enumerate(pointcloud_ids):
id_to_idx[pc_id].append(i)
sequence_point_cloud = []
sequence_point_cloud_idx = -np.ones((len(batch),))
for i, ind in enumerate(id_to_idx.values()):
sequence_point_cloud_idx[ind] = i
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
assert (sequence_point_cloud_idx >= 0).all()
override_fields = {
"sequence_point_cloud": sequence_point_cloud,
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
}
# note that the pre-collate value of sequence_point_cloud_idx is unused
collated = {}
for f in fields(elem):
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
collated[f.name] = (
cls.collate(list_values)
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
else:
return torch.utils.data._utils.collate.default_collate(batch)
@dataclass(eq=False)
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
"""
Base class to describe a dataset to be used with Implicitron.
The dataset is made up of frames, and the frames are grouped into sequences.
Each sequence has a name (a string).
(A sequence could be a video, or a set of images of one scene.)
This means they have a __getitem__ which returns an instance of a FrameData,
which will describe one frame in one sequence.
Members:
seq_to_idx: For each sequence, the indices of its frames.
"""
seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int:
raise NotImplementedError
def get_frame_numbers_and_timestamps(
self, idxs: Sequence[int]
) -> List[Tuple[int, float]]:
"""
If the sequences in the dataset are videos rather than
unordered views, then the dataset should override this method to
return the index and timestamp in their videos of the frames whose
indices are given in `idxs`. In addition,
the values in seq_to_idx should be in ascending order.
If timestamps are absent, they should be replaced with a constant.
This is used for letting SceneBatchSampler identify consecutive
frames.
Args:
idx: frame index in self
Returns:
tuple of
- frame index in video
- timestamp of frame in video
"""
raise ValueError("This dataset does not contain videos.")
def get_eval_batches(self) -> Optional[List[List[int]]]:
return None
class FrameAnnotsEntry(TypedDict):
subset: Optional[str]
frame_annotation: types.FrameAnnotation
@dataclass(eq=False)
class ImplicitronDataset(ImplicitronDatasetBase):
"""
A class for the Common Objects in 3D (CO3D) dataset.
Args:
frame_annotations_file: A zipped json file containing metadata of the
frames in the dataset, serialized List[types.FrameAnnotation].
sequence_annotations_file: A zipped json file containing metadata of the
sequences in the dataset, serialized List[types.SequenceAnnotation].
subset_lists_file: A json file containing the lists of frames corresponding
corresponding to different subsets (e.g. train/val/test) of the dataset;
format: {subset: (sequence_name, frame_id, file_path)}.
subsets: Restrict frames/sequences only to the given list of subsets
as defined in subset_lists_file (see above).
limit_to: Limit the dataset to the first #limit_to frames (after other
filters have been applied).
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences (after other sequence filters have been
applied but before frame-based filters).
pick_sequence: A list of sequence names to restrict the dataset to.
exclude_sequence: A list of the names of the sequences to exclude.
limit_category_to: Restrict the dataset to the given list of categories.
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
depth values used for evaluation (the points consistent across views).
load_masks: Enable loading frame foreground masks.
load_point_clouds: Enable loading sequence-level point clouds.
max_points: Cap on the number of loaded points in the point cloud;
if reached, they are randomly sampled without replacement.
mask_images: Whether to mask the images with the loaded foreground masks;
0 value is used for background.
mask_depths: Whether to mask the depth maps with the loaded foreground
masks; 0 value is used for background.
image_height: The height of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
image_width: The width of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
box_crop: Enable cropping of the image around the bounding box inferred
from the foreground region of the loaded segmentation mask; masks
and depth maps are cropped accordingly; cameras are corrected.
box_crop_mask_thr: The threshold used to separate pixels into foreground
and background based on the foreground_probability mask; if no value
is greater than this threshold, the loader lowers it and repeats.
box_crop_context: The amount of additional padding added to each
dimension of the cropping bounding box, relative to box size.
remove_empty_masks: Removes the frames with no active foreground pixels
in the segmentation mask after thresholding (see box_crop_mask_thr).
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
frames in each sequences uniformly without replacement if it has
more frames than that; applied before other frame-level filters.
seed: The seed of the random generator sampling #n_frames_per_sequence
random frames per sequence.
sort_frames: Enable frame annotations sorting to group frames from the
same sequences together and order them by timestamps
eval_batches: A list of batches that form the evaluation set;
list of batch-sized lists of indices corresponding to __getitem__
of this class, thus it can be used directly as a batch sampler.
"""
frame_annotations_type: ClassVar[
Type[types.FrameAnnotation]
] = types.FrameAnnotation
path_manager: Optional[PathManager] = None
frame_annotations_file: str = ""
sequence_annotations_file: str = ""
subset_lists_file: str = ""
subsets: Optional[List[str]] = None
limit_to: int = 0
limit_sequences_to: int = 0
pick_sequence: Sequence[str] = ()
exclude_sequence: Sequence[str] = ()
limit_category_to: Sequence[int] = ()
dataset_root: str = ""
load_images: bool = True
load_depths: bool = True
load_depth_masks: bool = True
load_masks: bool = True
load_point_clouds: bool = False
max_points: int = 0
mask_images: bool = False
mask_depths: bool = False
image_height: Optional[int] = 256
image_width: Optional[int] = 256
box_crop: bool = False
box_crop_mask_thr: float = 0.4
box_crop_context: float = 1.0
remove_empty_masks: bool = False
n_frames_per_sequence: int = -1
seed: int = 0
sort_frames: bool = False
eval_batches: Optional[List[List[int]]] = None
frame_annots: List[FrameAnnotsEntry] = field(init=False)
seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
def __post_init__(self) -> None:
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `subset_to_image_path`.
self.subset_to_image_path = None
self._load_frames()
self._load_sequences()
if self.sort_frames:
self._sort_frames()
self._load_subset_lists()
self._filter_db() # also computes sequence indices
print(str(self))
def seq_frame_index_to_dataset_index(
self,
seq_frame_index: Union[
List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
],
) -> List[List[int]]:
"""
Obtain indices into the dataset object given a list of frames specified as
`seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`.
"""
# TODO: check the frame numbers are unique
_dataset_seq_frame_n_index = {
seq: {
self.frame_annots[idx]["frame_annotation"].frame_number: idx
for idx in seq_idx
}
for seq, seq_idx in self.seq_to_idx.items()
}
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
idx = _dataset_seq_frame_n_index[seq_name][frame_no]
if path is not None:
# Check that the loaded frame path is consistent
# with the one stored in self.frame_annots.
assert os.path.normpath(
self.frame_annots[idx]["frame_annotation"].image.path
) == os.path.normpath(
path
), f"Inconsistent batch {seq_name, frame_no, path}."
return idx
batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index]
return batches_idx
def __str__(self) -> str:
return f"ImplicitronDataset #frames={len(self.frame_annots)}"
def __len__(self) -> int:
return len(self.frame_annots)
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
return entry["subset"]
def __getitem__(self, index) -> FrameData:
if index >= len(self.frame_annots):
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
entry = self.frame_annots[index]["frame_annotation"]
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
frame_data = FrameData(
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
sequence_name=entry.sequence_name,
sequence_category=self.seq_annots[entry.sequence_name].category,
camera_quality_score=_safe_as_tensor(
self.seq_annots[entry.sequence_name].viewpoint_quality_score,
torch.float,
),
point_cloud_quality_score=_safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)
# The rest of the fields are optional
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
clamp_bbox_xyxy,
) = self._load_crop_fg_probability(entry)
scale = 1.0
if self.load_images and entry.image is not None:
# original image size
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
(
frame_data.image_rgb,
frame_data.image_path,
frame_data.mask_crop,
scale,
) = self._load_crop_images(
entry, frame_data.fg_probability, clamp_bbox_xyxy
)
if self.load_depths and entry.depth is not None:
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
if entry.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(
entry,
scale,
clamp_bbox_xyxy,
)
if self.load_point_clouds and point_cloud is not None:
frame_data.sequence_point_cloud_path = pcl_path = os.path.join(
self.dataset_root, point_cloud.path
)
frame_data.sequence_point_cloud = _load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
return frame_data
def _load_crop_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[
Optional[torch.Tensor],
Optional[str],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = (
None,
None,
None,
None,
)
if (self.load_masks or self.box_crop) and entry.mask is not None:
full_path = os.path.join(self.dataset_root, entry.mask.path)
mask = _load_mask(self._local_path(full_path))
if mask.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
)
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
if self.box_crop:
clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context)
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy
def _load_crop_images(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
assert self.dataset_root is not None and entry.image is not None
path = os.path.join(self.dataset_root, entry.image.path)
image_rgb = _load_image(self._local_path(path))
if image_rgb.shape[-2:] != entry.image.size:
raise ValueError(
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
)
if self.box_crop:
assert clamp_bbox_xyxy is not None
image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
image_rgb, scale, mask_crop = self._resize_image(image_rgb)
if self.mask_images:
assert fg_probability is not None
image_rgb *= fg_probability
return image_rgb, path, mask_crop, scale
def _load_mask_depth(
self,
entry: types.FrameAnnotation,
clamp_bbox_xyxy: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
)
depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
if self.mask_depths:
assert fg_probability is not None
depth_map *= fg_probability
if self.load_depth_masks:
assert entry_depth.mask_path is not None
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = _load_depth_mask(self._local_path(mask_path))
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_mask_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
)
depth_mask = _crop_around_box(
depth_mask, depth_mask_bbox_xyxy, mask_path
)
depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
else:
depth_mask = torch.ones_like(depth_map)
return depth_map, path, depth_mask
def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
scale: float,
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
# principal point and focal length
principal_point = torch.tensor(
entry_viewpoint.principal_point, dtype=torch.float
)
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
half_image_size_wh_orig = (
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
)
# first, we convert from the dataset's NDC convention to pixels
format = entry_viewpoint.intrinsics_format
if format.lower() == "ndc_norm_image_bounds":
# this is e.g. currently used in CO3D for storing intrinsics
rescale = half_image_size_wh_orig
elif format.lower() == "ndc_isotropic":
rescale = half_image_size_wh_orig.min()
else:
raise ValueError(f"Unknown intrinsics format: {format}")
# principal point and focal length in pixels
principal_point_px = half_image_size_wh_orig - principal_point * rescale
focal_length_px = focal_length * rescale
if self.box_crop:
assert clamp_bbox_xyxy is not None
principal_point_px -= clamp_bbox_xyxy[:2]
# now, convert from pixels to PyTorch3D v0.5+ NDC convention
if self.image_height is None or self.image_width is None:
out_size = list(reversed(entry.image.size))
else:
out_size = [self.image_width, self.image_height]
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
half_min_image_size_output = half_image_size_output.min()
# rescaled principal point and focal length in ndc
principal_point = (
half_image_size_output - principal_point_px * scale
) / half_min_image_size_output
focal_length = focal_length_px * scale / half_min_image_size_output
return PerspectiveCameras(
focal_length=focal_length[None],
principal_point=principal_point[None],
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
def _load_frames(self) -> None:
print(f"Loading Co3D frames from {self.frame_annotations_file}.")
local_file = self._local_path(self.frame_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
frame_annots_list = types.load_dataclass(
zipfile, List[self.frame_annotations_type]
)
if not frame_annots_list:
raise ValueError("Empty dataset!")
self.frame_annots = [
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
]
def _load_sequences(self) -> None:
print(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
local_file = self._local_path(self.sequence_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
if not seq_annots:
raise ValueError("Empty sequences file!")
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
def _load_subset_lists(self) -> None:
print(f"Loading Co3D subset lists from {self.subset_lists_file}.")
if not self.subset_lists_file:
return
with open(self._local_path(self.subset_lists_file), "r") as f:
subset_to_seq_frame = json.load(f)
frame_path_to_subset = {
path: subset
for subset, frames in subset_to_seq_frame.items()
for _, _, path in frames
}
for frame in self.frame_annots:
frame["subset"] = frame_path_to_subset.get(
frame["frame_annotation"].image.path, None
)
if frame["subset"] is None:
warnings.warn(
"Subset lists are given but don't include "
+ frame["frame_annotation"].image.path
)
def _sort_frames(self) -> None:
# Sort frames to have them grouped by sequence, ordered by timestamp
self.frame_annots = sorted(
self.frame_annots,
key=lambda f: (
f["frame_annotation"].sequence_name,
f["frame_annotation"].frame_timestamp or 0,
),
)
def _filter_db(self) -> None:
if self.remove_empty_masks:
print("Removing images with empty masks.")
old_len = len(self.frame_annots)
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
mask = frame_annot.mask
if mask is None:
return False
if mask.mass is None:
raise ValueError(msg)
return mask.mass > 1
self.frame_annots = [
frame
for frame in self.frame_annots
if positive_mass(frame["frame_annotation"])
]
print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
# this has to be called after joining with categories!!
subsets = self.subsets
if subsets:
if not self.subset_lists_file:
raise ValueError(
"Subset filter is on but subset_lists_file was not given"
)
print(f"Limitting Co3D dataset to the '{subsets}' subsets.")
# truncate the list of subsets to the valid one
self.frame_annots = [
entry for entry in self.frame_annots if entry["subset"] in subsets
]
if len(self.frame_annots) == 0:
raise ValueError(f"There are no frames in the '{subsets}' subsets!")
self._invalidate_indexes(filter_seq_annots=True)
if len(self.limit_category_to) > 0:
print(f"Limitting dataset to categories: {self.limit_category_to}")
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
if entry.category in self.limit_category_to
}
# sequence filters
for prefix in ("pick", "exclude"):
orig_len = len(self.seq_annots)
attr = f"{prefix}_sequence"
arr = getattr(self, attr)
if len(arr) > 0:
print(f"{attr}: {str(arr)}")
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
if (name in arr) == (prefix == "pick")
}
print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
if self.limit_sequences_to > 0:
self.seq_annots = dict(
islice(self.seq_annots.items(), self.limit_sequences_to)
)
# retain only frames from retained sequences
self.frame_annots = [
f
for f in self.frame_annots
if f["frame_annotation"].sequence_name in self.seq_annots
]
self._invalidate_indexes()
if self.n_frames_per_sequence > 0:
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = []
for seq, seq_indices in self.seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible
# and makes the selection differ for different sequences
seed = _seq_name_to_seed(seq) + self.seed
seq_idx_shuffled = random.Random(seed).sample(
sorted(seq_indices), len(seq_indices)
)
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
self.frame_annots = [self.frame_annots[i] for i in keep_idx]
self._invalidate_indexes(filter_seq_annots=False)
# sequences are not decimated, so self.seq_annots is valid
if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
print(
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
)
self.frame_annots = self.frame_annots[: self.limit_to]
self._invalidate_indexes(filter_seq_annots=True)
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
# update seq_to_idx and filter seq_meta according to frame_annots change
# if filter_seq_annots, also uldates seq_annots based on the changed seq_to_idx
self._invalidate_seq_to_idx()
if filter_seq_annots:
self.seq_annots = {
k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx
}
def _invalidate_seq_to_idx(self) -> None:
seq_to_idx = defaultdict(list)
for idx, entry in enumerate(self.frame_annots):
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
self.seq_to_idx = seq_to_idx
def _resize_image(
self, image, mode="bilinear"
) -> Tuple[torch.Tensor, float, torch.Tensor]:
image_height, image_width = self.image_height, self.image_width
if image_height is None or image_width is None:
# skip the resizing
imre_ = torch.from_numpy(image)
return imre_, 1.0, torch.ones_like(imre_[:1])
# takes numpy array, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
torch.from_numpy(image)[None],
# pyre-ignore[6]
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, self.image_height, self.image_width)
mask[:, 0 : imre.shape[1] - 1, 0 : imre.shape[2] - 1] = 1.0
return imre_, minscale, mask
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
return self.path_manager.get_local_path(path)
def get_frame_numbers_and_timestamps(
self, idxs: Sequence[int]
) -> List[Tuple[int, float]]:
out: List[Tuple[int, float]] = []
for idx in idxs:
frame_annotation = self.frame_annots[idx]["frame_annotation"]
out.append(
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
)
return out
def get_eval_batches(self) -> Optional[List[List[int]]]:
return self.eval_batches
def _seq_name_to_seed(seq_name) -> int:
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)
def _load_image(path) -> np.ndarray:
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im
def _load_16big_png_depth(depth_png) -> np.ndarray:
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
def _load_1bit_png_mask(file: str) -> np.ndarray:
with Image.open(file) as pil_im:
mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
return mask
def _load_depth_mask(path) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path)
m = _load_1bit_png_mask(path)
return m[None] # fake feature channel
def _load_depth(path, scale_adjustment) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth file name "%s"' % path)
d = _load_16big_png_depth(path) * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d[None] # fake feature channel
def _load_mask(path) -> np.ndarray:
with Image.open(path) as pil_im:
mask = np.array(pil_im)
mask = mask.astype(np.float32) / 255.0
return mask[None] # fake feature channel
def _get_1d_bounds(arr) -> Tuple[int, int]:
nz = np.flatnonzero(arr)
return nz[0], nz[-1]
def _get_bbox_from_mask(
mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
masks_for_box = (mask > thr).astype(np.float32)
thr -= decrease_quant
if thr <= 0.0:
warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
return x0, y0, x1 - x0, y1 - y0
def _get_clamp_bbox(
bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = ""
) -> torch.Tensor:
# box_crop_context: rate of expansion for bbox
# returns possibly expanded bbox xyxy as float
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
if (bbox[2:] <= 1.0).any():
raise ValueError(
f"squashed image {impath}!! The bounding box contains no pixels."
)
bbox[2:] = torch.clamp(bbox[2:], 2)
bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax]
# +1 because upper bound is not inclusive
return bbox
def _crop_around_box(tensor, bbox, impath: str = ""):
# bbox is xyxy, where the upper bound is corrected with +1
bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1])
bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2])
bbox = bbox.round().long()
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
return tensor
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
assert bbox is not None
assert np.prod(orig_res) > 1e-8
# average ratio of dimensions
rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
return bbox * rel_size
def _safe_as_tensor(data, dtype):
if data is None:
return None
return torch.tensor(data, dtype=dtype)
# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
pcl = IO().load_pointcloud(pcl_path)
if max_points > 0:
pcl = pcl.subsample(max_points)
return pcl
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple
import numpy as np
from torch.utils.data.sampler import Sampler
from .implicitron_dataset import ImplicitronDatasetBase
@dataclass(eq=False) # TODO: do we need this if not init from config?
class SceneBatchSampler(Sampler[List[int]]):
"""
A class for sampling training batches with a controlled composition
of sequences.
"""
dataset: ImplicitronDatasetBase
batch_size: int
num_batches: int
# the sampler first samples a random element k from this list and then
# takes k random frames per sequence
images_per_seq_options: Sequence[int]
# if True, will sample a contiguous interval of frames in the sequence
# it first finds the connected segments within the sequence of sufficient length,
# then samples a random pivot element among them and ideally uses it as a middle
# of the temporal window, shifting the borders where necessary.
# This strategy mitigates the bias against shorter segments and their boundaries.
sample_consecutive_frames: bool = False
# if a number > 0, then used to define the maximum difference in frame_number
# of neighbouring frames when forming connected segments; otherwise the whole
# sequence is considered a segment regardless of frame numbers
consecutive_frames_max_gap: int = 0
# same but for timestamps if they are available
consecutive_frames_max_gap_seconds: float = 0.1
seq_names: List[str] = field(init=False)
def __post_init__(self) -> None:
if self.batch_size <= 0:
raise ValueError(
"batch_size should be a positive integral value, "
f"but got batch_size={self.batch_size}"
)
if len(self.images_per_seq_options) < 1:
raise ValueError("n_per_seq_posibilities list cannot be empty")
self.seq_names = list(self.dataset.seq_to_idx.keys())
def __len__(self) -> int:
return self.num_batches
def __iter__(self) -> Iterator[List[int]]:
for batch_idx in range(len(self)):
batch = self._sample_batch(batch_idx)
yield batch
def _sample_batch(self, batch_idx) -> List[int]:
n_per_seq = np.random.choice(self.images_per_seq_options)
n_seqs = -(-self.batch_size // n_per_seq) # round up
chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
if self.sample_consecutive_frames:
frame_idx = []
for seq in chosen_seq:
segment_index = self._build_segment_index(
list(self.dataset.seq_to_idx[seq]), n_per_seq
)
segment, idx = segment_index[np.random.randint(len(segment_index))]
if len(segment) <= n_per_seq:
frame_idx.append(segment)
else:
start = np.clip(idx - n_per_seq // 2, 0, len(segment) - n_per_seq)
frame_idx.append(segment[start : start + n_per_seq])
else:
frame_idx = [
_capped_random_choice(
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
)
for seq in chosen_seq
]
frame_idx = np.concatenate(frame_idx)[: self.batch_size].tolist()
if len(frame_idx) < self.batch_size:
warnings.warn(
"Batch size smaller than self.batch_size!"
+ " (This is fine for experiments with a single scene and viewpooling)"
)
return frame_idx
def _build_segment_index(
self, seq_frame_indices: List[int], size: int
) -> List[Tuple[List[int], int]]:
"""
Returns a list of (segment, index) tuples, one per eligible frame, where
segment is a list of frame indices in the contiguous segment the frame
belongs to index is the frame's index within that segment.
Segment references are repeated but the memory is shared.
"""
if (
self.consecutive_frames_max_gap > 0
or self.consecutive_frames_max_gap_seconds > 0.0
):
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
seq_frame_indices, self.dataset
)
# TODO: use new API to access frame numbers / timestamps
segments = self._split_to_segments(sequence_timestamps)
segments = _cull_short_segments(segments, size)
if not segments:
raise AssertionError("Empty segments after culling")
else:
segments = [seq_frame_indices]
# build an index of segment for random selection of a pivot frame
segment_index = [
(segment, i) for segment in segments for i in range(len(segment))
]
return segment_index
def _split_to_segments(
self, sequence_timestamps: List[Tuple[float, int, int]]
) -> List[List[int]]:
if (
self.consecutive_frames_max_gap <= 0
and self.consecutive_frames_max_gap_seconds <= 0.0
):
raise AssertionError("This function is only needed for non-trivial max_gap")
segments = []
last_no = -self.consecutive_frames_max_gap - 1 # will trigger a new segment
last_ts = -self.consecutive_frames_max_gap_seconds - 1.0
for ts, no, idx in sequence_timestamps:
if ts <= 0.0 and no <= last_no:
raise AssertionError(
"Frames are not ordered in seq_to_idx while timestamps are not given"
)
if (
no - last_no > self.consecutive_frames_max_gap > 0
or ts - last_ts > self.consecutive_frames_max_gap_seconds > 0.0
): # new group
segments.append([idx])
else:
segments[-1].append(idx)
last_no = no
last_ts = ts
return segments
def _sort_frames_by_timestamps_then_numbers(
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
We attempt to first sort by timestamp, then by frame number.
Timestamps are coalesced with 0s.
"""
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
return sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
lengths = [(len(segment), segment) for segment in segments]
max_len, longest_segment = max(lengths)
if max_len < min_size:
return [longest_segment]
return [segment for segment in segments if len(segment) >= min_size]
def _capped_random_choice(x, size, replace: bool = True):
"""
if replace==True
randomly chooses from x `size` elements without replacement if len(x)>size
else allows replacement and selects `size` elements again.
if replace==False
randomly chooses from x `min(len(x), size)` elements without replacement
"""
len_x = x if isinstance(x, int) else len(x)
if replace:
return np.random.choice(x, size=size, replace=len_x < size)
else:
return np.random.choice(x, size=min(size, len_x), replace=False)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import gzip
import json
import sys
from dataclasses import MISSING, Field, dataclass
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
import numpy as np
_X = TypeVar("_X")
if sys.version_info >= (3, 8, 0):
from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):
def get_origin(cls):
return getattr(cls, "__origin__", None)
def get_args(cls):
return getattr(cls, "__args__", None)
else:
raise ImportError("This module requires Python 3.7+")
TF3 = Tuple[float, float, float]
@dataclass
class ImageAnnotation:
# path to jpg file, relative w.r.t. dataset_root
path: str
# H x W
size: Tuple[int, int] # TODO: rename size_hw?
@dataclass
class DepthAnnotation:
# path to png file, relative w.r.t. dataset_root, storing `depth / scale_adjustment`
path: str
# a factor to convert png values to actual depth: `depth = png * scale_adjustment`
scale_adjustment: float
# path to png file, relative w.r.t. dataset_root, storing binary `depth` mask
mask_path: Optional[str]
@dataclass
class MaskAnnotation:
# path to png file storing (Prob(fg | pixel) * 255)
path: str
# (soft) number of pixels in the mask; sum(Prob(fg | pixel))
mass: Optional[float] = None
@dataclass
class ViewpointAnnotation:
# In right-multiply (PyTorch3D) format. X_cam = X_world @ R + T
R: Tuple[TF3, TF3, TF3]
T: TF3
focal_length: Tuple[float, float]
principal_point: Tuple[float, float]
intrinsics_format: str = "ndc_norm_image_bounds"
# Defines the co-ordinate system where focal_length and principal_point live.
# Possible values: ndc_isotropic | ndc_norm_image_bounds (default)
# ndc_norm_image_bounds: legacy PyTorch3D NDC format, where image boundaries
# correspond to [-1, 1] x [-1, 1], and the scale along x and y may differ
# ndc_isotropic: PyTorch3D 0.5+ NDC convention where the shorter side has
# the range [-1, 1], and the longer one has the range [-s, s]; s >= 1,
# where s is the aspect ratio. The scale is same along x and y.
@dataclass
class FrameAnnotation:
"""A dataclass used to load annotations from json."""
# can be used to join with `SequenceAnnotation`
sequence_name: str
# 0-based, continuous frame number within sequence
frame_number: int
# timestamp in seconds from the video start
frame_timestamp: float
image: ImageAnnotation
depth: Optional[DepthAnnotation] = None
mask: Optional[MaskAnnotation] = None
viewpoint: Optional[ViewpointAnnotation] = None
@dataclass
class PointCloudAnnotation:
# path to ply file with points only, relative w.r.t. dataset_root
path: str
# the bigger the better
quality_score: float
n_points: Optional[int]
@dataclass
class VideoAnnotation:
# path to the original video file, relative w.r.t. dataset_root
path: str
# length of the video in seconds
length: float
@dataclass
class SequenceAnnotation:
sequence_name: str
category: str
video: Optional[VideoAnnotation] = None
point_cloud: Optional[PointCloudAnnotation] = None
# the bigger the better
viewpoint_quality_score: Optional[float] = None
def dump_dataclass(obj: Any, f: IO, binary: bool = False) -> None:
"""
Args:
f: Either a path to a file, or a file opened for writing.
obj: A @dataclass or collection hierarchy including dataclasses.
binary: Set to True if `f` is a file handle, else False.
"""
if binary:
f.write(json.dumps(_asdict_rec(obj)).encode("utf8"))
else:
json.dump(_asdict_rec(obj), f)
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
"""
Loads to a @dataclass or collection hierarchy including dataclasses
from a json recursively.
Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
raises KeyError if json has keys not mapping to the dataclass fields.
Args:
f: Either a path to a file, or a file opened for writing.
cls: The class of the loaded dataclass.
binary: Set to True if `f` is a file handle, else False.
"""
if binary:
asdict = json.loads(f.read().decode("utf8"))
else:
asdict = json.load(f)
if isinstance(asdict, list):
# in the list case, run a faster "vectorized" version
cls = get_args(cls)[0]
res = list(_dataclass_list_from_dict_list(asdict, cls))
else:
res = _dataclass_from_dict(asdict, cls)
return res
def _dataclass_list_from_dict_list(dlist, typeannot):
"""
Vectorised version of `_dataclass_from_dict`.
The output should be equivalent to
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
Args:
dlist: list of objects to convert.
typeannot: type of each of those objects.
Returns:
iterator or list over converted objects of the same length as `dlist`.
Raises:
ValueError: it assumes the objects have None's in consistent places across
objects, otherwise it would ignore some values. This generally holds for
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
"""
cls = get_origin(typeannot) or typeannot
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
return dlist
elif any(obj is None for obj in dlist):
# filter out Nones and recurse on the resulting list
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
idx, notnone = zip(*idx_notnone)
converted = _dataclass_list_from_dict_list(notnone, typeannot)
res = [None] * len(dlist)
for i, obj in zip(idx, converted):
res[i] = obj
return res
# otherwise, we dispatch by the type of the provided annotation to convert to
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls._field_types.values()
dlist_T = zip(*dlist)
res_T = [
_dataclass_list_from_dict_list(key_list, tp)
for key_list, tp in zip(dlist_T, types)
]
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, (list, tuple)):
# For list/tuple, call the function recursively on the lists of corresponding positions
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(dlist[0])
dlist_T = zip(*dlist)
res_T = (
_dataclass_list_from_dict_list(pos_list, tp)
for pos_list, tp in zip(dlist_T, types)
)
if issubclass(cls, tuple):
return list(zip(*res_T))
else:
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, dict):
# For the dictionary, call the function recursively on concatenated keys and vertices
key_t, val_t = get_args(typeannot)
all_keys_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.keys()], key_t
)
all_vals_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.values()], val_t
)
indices = np.cumsum([len(obj) for obj in dlist])
assert indices[-1] == len(all_keys_res)
keys = np.split(list(all_keys_res), indices[:-1])
vals = np.split(list(all_vals_res), indices[:-1])
return [cls(zip(*k, v)) for k, v in zip(keys, vals)]
elif not dataclasses.is_dataclass(typeannot):
return dlist
# dataclass node: 2nd recursion base; call the function recursively on the lists
# of the corresponding fields
assert dataclasses.is_dataclass(cls)
fieldtypes = {
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
for f in dataclasses.fields(typeannot)
}
# NOTE the default object is shared here
key_lists = (
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
for k, (type_, default) in fieldtypes.items()
)
transposed = zip(*key_lists)
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
def _dataclass_from_dict(d, typeannot):
cls = get_origin(typeannot) or typeannot
if d is None:
return d
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
types = cls._field_types.values()
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
elif issubclass(cls, (list, tuple)):
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(d)
return cls(_dataclass_from_dict(v, tp) for v, tp in zip(d, types))
elif issubclass(cls, dict):
key_t, val_t = get_args(typeannot)
return cls(
(_dataclass_from_dict(k, key_t), _dataclass_from_dict(v, val_t))
for k, v in d.items()
)
elif not dataclasses.is_dataclass(typeannot):
return d
assert dataclasses.is_dataclass(cls)
fieldtypes = {f.name: _unwrap_type(f.type) for f in dataclasses.fields(typeannot)}
return cls(**{k: _dataclass_from_dict(v, fieldtypes[k]) for k, v in d.items()})
def _unwrap_type(tp):
# strips Optional wrapper, if any
if get_origin(tp) is Union:
args = get_args(tp)
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
# this is typing.Optional
return args[0] if args[1] is type(None) else args[1] # noqa: E721
return tp
def _get_dataclass_field_default(field: Field) -> Any:
if field.default_factory is not MISSING:
return field.default_factory()
elif field.default is not MISSING:
return field.default
else:
return None
def _asdict_rec(obj):
return dataclasses._asdict_inner(obj, dict)
def dump_dataclass_jgzip(outfile: str, obj: Any) -> None:
"""
Dumps obj to a gzipped json outfile.
Args:
obj: A @dataclass or collection hiererchy including dataclasses.
outfile: The path to the output file.
"""
with gzip.GzipFile(outfile, "wb") as f:
dump_dataclass(obj, cast(IO, f), binary=True)
def load_dataclass_jgzip(outfile, cls):
"""
Loads a dataclass from a gzipped json outfile.
Args:
outfile: The path to the loaded file.
cls: The type annotation of the loaded dataclass.
Returns:
loaded_dataclass: The loaded dataclass.
"""
with gzip.GzipFile(outfile, "rb") as f:
return load_dataclass(cast(IO, f), cls, binary=True)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import torch
DATASET_TYPE_TRAIN = "train"
DATASET_TYPE_TEST = "test"
DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"
def is_known_frame(
frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
"""
Given a list `frame_type` of frame types in a batch, return a tensor
of boolean flags expressing whether the corresponding frame is a known frame.
"""
return torch.tensor(
[ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
dtype=torch.bool,
device=device,
)
def is_train_frame(
frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
"""
Given a list `frame_type` of frame types in a batch, return a tensor
of boolean flags expressing whether the corresponding frame is a training frame.
"""
return torch.tensor(
[ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type],
dtype=torch.bool,
device=device,
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, cast
import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.structures import Pointclouds
from .implicitron_dataset import FrameData, ImplicitronDataset
def get_implicitron_sequence_pointcloud(
dataset: ImplicitronDataset,
sequence_name: Optional[str] = None,
mask_points: bool = True,
max_frames: int = -1,
num_workers: int = 0,
load_dataset_point_cloud: bool = False,
) -> Tuple[Pointclouds, FrameData]:
"""
Make a point cloud by sampling random points from each frame the dataset.
"""
if len(dataset) == 0:
raise ValueError("The dataset is empty.")
if not dataset.load_depths:
raise ValueError("The dataset has to load depths (dataset.load_depths=True).")
if mask_points and not dataset.load_masks:
raise ValueError(
"For mask_points=True, the dataset has to load masks"
+ " (dataset.load_masks=True)."
)
# setup the indices of frames loaded from the dataset db
sequence_entries = list(range(len(dataset)))
if sequence_name is not None:
sequence_entries = [
ei
for ei in sequence_entries
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
== sequence_name
]
if len(sequence_entries) == 0:
raise ValueError(
f'There are no dataset entries for sequence name "{sequence_name}".'
)
# subsample loaded frames if needed
if (max_frames > 0) and (len(sequence_entries) > max_frames):
sequence_entries = [
sequence_entries[i]
for i in torch.randperm(len(sequence_entries))[:max_frames].sort().values
]
# take only the part of the dataset corresponding to the sequence entries
sequence_dataset = torch.utils.data.Subset(dataset, sequence_entries)
# load the required part of the dataset
loader = torch.utils.data.DataLoader(
sequence_dataset,
batch_size=len(sequence_dataset),
shuffle=False,
num_workers=num_workers,
collate_fn=FrameData.collate,
)
frame_data = next(iter(loader)) # there's only one batch
# scene point cloud
if load_dataset_point_cloud:
if not dataset.load_point_clouds:
raise ValueError(
"For load_dataset_point_cloud=True, the dataset has to"
+ " load point clouds (dataset.load_point_clouds=True)."
)
point_cloud = frame_data.sequence_point_cloud
else:
point_cloud = get_rgbd_point_cloud(
frame_data.camera,
frame_data.image_rgb,
frame_data.depth_map,
(cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
if frame_data.fg_probability is not None
else None,
mask_points=mask_points,
)
return point_cloud, frame_data
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import os
from typing import Optional, cast
import lpips
import torch
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
ImplicitronDatasetBase,
)
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
aggregate_nvs_results,
eval_batch,
pretty_print_nvs_metrics,
summarize_nvs_eval_results,
)
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
from tqdm import tqdm
def main() -> None:
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for multisequence/singlesequence tasks for several categories.
The evaluation is conducted on the same data as in [1] and, hence, the results
are directly comparable to the numbers reported in [1].
References:
[1] J. Reizenstein, R. Shapovalov, P. Henzler, L. Sbordone,
P. Labatut, D. Novotny:
Common Objects in 3D: Large-Scale Learning
and Evaluation of Real-life 3D Category Reconstruction
"""
task_results = {}
for task in ("singlesequence", "multisequence"):
task_results[task] = []
for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
category_result = evaluate_dbir_for_category(
category, task=task, single_sequence_id=single_sequence_id
)
print("")
print(
f"Results for task={task}; category={category};"
+ (
f" sequence={single_sequence_id}:"
if single_sequence_id is not None
else ":"
)
)
pretty_print_nvs_metrics(category_result)
print("")
task_results[task].append(category_result)
_print_aggregate_results(task, task_results)
for task in task_results:
_print_aggregate_results(task, task_results)
def evaluate_dbir_for_category(
category: str = "apple",
bg_color: float = 0.0,
task: str = "singlesequence",
single_sequence_id: Optional[int] = None,
num_workers: int = 16,
):
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for a given task, category, and sequence (in case task=='singlesequence').
Args:
category: Object category.
bg_color: Background color of the renders.
task: Evaluation task. Either singlesequence or multisequence.
single_sequence_id: The ID of the evaluiation sequence for the singlesequence task.
num_workers: The number of workers for the employed dataloaders.
Returns:
category_result: A dictionary of quantitative metrics.
"""
single_sequence_id = single_sequence_id if single_sequence_id is not None else -1
torch.manual_seed(42)
if task not in ["multisequence", "singlesequence"]:
raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
datasets = dataset_zoo(
category=category,
dataset_root=os.environ["CO3D_DATASET_ROOT"],
assert_single_seq=task == "singlesequence",
dataset_name=f"co3d_{task}",
test_on_train=False,
load_point_clouds=True,
test_restrict_sequence_id=single_sequence_id,
)
dataloaders = dataloader_zoo(
datasets,
dataset_name=f"co3d_{task}",
)
test_dataset = datasets["test"]
test_dataloader = dataloaders["test"]
if task == "singlesequence":
# all_source_cameras are needed for evaluation of the
# target camera difficulty
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
all_source_cameras = _get_all_source_cameras(
test_dataset, sequence_name, num_workers=num_workers
)
else:
all_source_cameras = None
image_size = cast(ImplicitronDataset, test_dataset).image_width
if image_size is None:
raise ValueError("Image size should be set in the dataset")
# init the simple DBIR model
model = ModelDBIR(
image_size=image_size,
bg_color=bg_color,
max_points=int(1e5),
)
model.cuda()
# init the lpips model for eval
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.cuda()
per_batch_eval_results = []
print("Evaluating DBIR model ...")
for frame_data in tqdm(test_dataloader):
frame_data = dataclass_to_cuda_(frame_data)
preds = model(**dataclasses.asdict(frame_data))
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
per_batch_eval_results.append(
eval_batch(
frame_data,
nvs_prediction,
bg_color=bg_color,
lpips_model=lpips_model,
source_cameras=all_source_cameras,
)
)
category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task
)
return category_result["results"]
def _print_aggregate_results(task, task_results) -> None:
"""
Prints the aggregate metrics for a given task.
"""
aggregate_task_result = aggregate_nvs_results(task_results[task])
print("")
print(f"Aggregate results for task={task}:")
pretty_print_nvs_metrics(aggregate_task_result)
print("")
def _get_all_source_cameras(
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
):
"""
Loads all training cameras of a given sequence.
The set of all seen cameras is needed for evaluating the viewpoint difficulty
for the singlescene evaluation.
Args:
dataset: Co3D dataset object.
sequence_name: The name of the sequence.
num_workers: The number of for the utilized dataloader.
"""
# load all source cameras of the sequence
seq_idx = dataset.seq_to_idx[sequence_name]
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
(all_frame_data,) = torch.utils.data.DataLoader(
dataset_for_loader,
shuffle=False,
batch_size=len(dataset_for_loader),
num_workers=num_workers,
collate_fn=FrameData.collate,
)
is_known = is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
return source_cameras
if __name__ == "__main__":
main()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
from pytorch3d.implicitron.tools.image_utils import mask_background
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth, iou, rgb_l1
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.implicitron.tools.vis_utils import make_depth_image
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene
from tabulate import tabulate
from visdom import Visdom
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
@dataclass
class NewViewSynthesisPrediction:
"""
Holds the tensors that describe a result of synthesizing new views.
"""
depth_render: Optional[torch.Tensor] = None
image_render: Optional[torch.Tensor] = None
mask_render: Optional[torch.Tensor] = None
camera_distance: Optional[torch.Tensor] = None
@dataclass
class _Visualizer:
image_render: torch.Tensor
image_rgb_masked: torch.Tensor
depth_render: torch.Tensor
depth_map: torch.Tensor
depth_mask: torch.Tensor
visdom_env: str = "eval_debug"
_viz: Visdom = field(init=False)
def __post_init__(self):
self._viz = vis_utils.get_visdom_connection()
def show_rgb(
self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
):
self._viz.images(
torch.cat(
(
self.image_render,
self.image_rgb_masked,
loss_mask_now.repeat(1, 3, 1, 1),
),
dim=3,
),
env=self.visdom_env,
win=metric_name,
opts={"title": f"{metric_name}_{loss_value:1.2f}"},
)
def show_depth(
self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
):
self._viz.images(
torch.cat(
(
make_depth_image(self.depth_render, loss_mask_now),
make_depth_image(self.depth_map, loss_mask_now),
),
dim=3,
),
env=self.visdom_env,
win="depth_abs" + name_postfix,
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
)
self._viz.images(
loss_mask_now,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
)
self._viz.images(
self.depth_mask,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
)
# show the 3D plot
# pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
# `TensorProperties`.
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
loss_mask_now.device
)
pcl_pred = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_render,
self.depth_render,
# mask_crop,
torch.ones_like(self.depth_render),
# loss_mask_now,
)
pcl_gt = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_rgb_masked,
self.depth_map,
# mask_crop,
torch.ones_like(self.depth_map),
# loss_mask_now,
)
_pcls = {
pn: p
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
if int(p.num_points_per_cloud()) > 0
}
plotlyplot = plot_scene(
{f"pcl{name_postfix}": _pcls},
camera_scale=1.0,
pointcloud_max_points=10000,
pointcloud_marker_size=1,
)
self._viz.plotlyplot(
plotlyplot,
env=self.visdom_env,
win=f"pcl{name_postfix}",
)
def eval_batch(
frame_data: FrameData,
nvs_prediction: NewViewSynthesisPrediction,
bg_color: Union[torch.Tensor, str, float] = "black",
mask_thr: float = 0.5,
lpips_model=None,
visualize: bool = False,
visualize_visdom_env: str = "eval_debug",
break_after_visualising: bool = True,
source_cameras: Optional[List[CamerasBase]] = None,
) -> Dict[str, Any]:
"""
Produce performance metrics for a single batch of new-view synthesis
predictions.
Given a set of known views (for which frame_data.frame_type.endswith('known')
is True), a new-view synthesis method (NVS) is tasked to generate new views
of the scene from the viewpoint of the target views (for which
frame_data.frame_type.endswith('known') is False). The resulting
synthesized new views, stored in `nvs_prediction`, are compared to the
target ground truth in `frame_data` in terms of geometry and appearance
resulting in a dictionary of metrics returned by the `eval_batch` function.
Args:
frame_data: A FrameData object containing the input to the new view
synthesis method.
nvs_prediction: The data describing the synthesized new views.
bg_color: The background color of the generated new views and the
ground truth.
lpips_model: A pre-trained model for evaluating the LPIPS metric.
visualize: If True, visualizes the results to Visdom.
source_cameras: A list of all training cameras for evaluating the
difficulty of the target views.
Returns:
results: A dictionary holding evaluation metrics.
Throws:
ValueError if frame_data does not have frame_type, camera, or image_rgb
ValueError if the batch has a mix of training and test samples
ValueError if the batch frames are not [unseen, known, known, ...]
ValueError if one of the required fields in nvs_prediction is missing
"""
REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
frame_type = frame_data.frame_type
if frame_type is None:
raise ValueError("Frame type has not been set.")
# we check that all those fields are not None but Pyre can't infer that properly
# TODO: assign to local variables
if frame_data.image_rgb is None:
raise ValueError("Image is not in the evaluation batch.")
if frame_data.camera is None:
raise ValueError("Camera is not in the evaluation batch.")
if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
raise ValueError("One of the required predicted fields is missing")
# obtain copies to make sure we dont edit the original data
nvs_prediction = copy.deepcopy(nvs_prediction)
frame_data = copy.deepcopy(frame_data)
# mask the ground truth depth in case frame_data contains the depth mask
if frame_data.depth_map is not None and frame_data.depth_mask is not None:
frame_data.depth_map *= frame_data.depth_mask
if not isinstance(frame_type, list): # not batch FrameData
frame_type = [frame_type]
is_train = is_train_frame(frame_type)
if not (is_train[0] == is_train).all():
raise ValueError("All frames in the eval batch have to be either train/test.")
# pyre-fixme[16]: `Optional` has no attribute `device`.
is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device)
if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()):
raise ValueError(
"For evaluation the first element of the batch has to be"
+ " a target view while the rest should be source views."
) # TODO: do we need to enforce this?
# take only the first (target image)
for k in REQUIRED_NVS_PREDICTION_FIELDS:
setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
for k in [
"depth_map",
"image_rgb",
"fg_probability",
"mask_crop",
]:
if not hasattr(frame_data, k) or getattr(frame_data, k) is None:
continue
setattr(frame_data, k, getattr(frame_data, k)[:1])
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
warnings.warn("Empty or missing depth map in evaluation!")
# eval all results in the resolution of the frame_data image
# pyre-fixme[16]: `Optional` has no attribute `shape`.
image_resol = list(frame_data.image_rgb.shape[2:])
# threshold the masks to make ground truth binary masks
mask_fg, mask_crop = [
(getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
]
image_rgb_masked = mask_background(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
frame_data.image_rgb,
mask_fg,
bg_color=bg_color,
)
# resize to the target resolution
for k in REQUIRED_NVS_PREDICTION_FIELDS:
imode = "bilinear" if k == "image_render" else "nearest"
val = getattr(nvs_prediction, k)
setattr(
nvs_prediction,
k,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
# `List[typing.Any]`.
torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
)
# clamp predicted images
# pyre-fixme[16]: `Optional` has no attribute `clamp`.
image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
if visualize:
visualizer = _Visualizer(
image_render=image_render,
image_rgb_masked=image_rgb_masked,
# pyre-fixme[6]: Expected `Tensor` for 3rd param but got
# `Optional[torch.Tensor]`.
depth_render=nvs_prediction.depth_render,
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
# `Optional[torch.Tensor]`.
depth_map=frame_data.depth_map,
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
depth_mask=frame_data.depth_mask[:1],
visdom_env=visualize_visdom_env,
)
results: Dict[str, Any] = {}
results["iou"] = iou(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
nvs_prediction.mask_render,
mask_fg,
mask=mask_crop,
)
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("", "_fg")):
loss_mask_now = mask_crop * loss_fg_mask
for rgb_metric_name, rgb_metric_fun in zip(
("psnr", "rgb_l1"), (calc_psnr, rgb_l1)
):
metric_name = rgb_metric_name + name_postfix
results[metric_name] = rgb_metric_fun(
image_render,
image_rgb_masked,
mask=loss_mask_now,
)
if visualize:
visualizer.show_rgb(
results[metric_name].item(), metric_name, loss_mask_now
)
if name_postfix == "_fg":
# only record depth metrics for the foreground
_, abs_ = eval_depth(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
nvs_prediction.depth_render,
# pyre-fixme[6]: Expected `Tensor` for 2nd param but got
# `Optional[torch.Tensor]`.
frame_data.depth_map,
get_best_scale=True,
mask=loss_mask_now,
crop=5,
)
results["depth_abs" + name_postfix] = abs_.mean()
if visualize:
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
if break_after_visualising:
import pdb
pdb.set_trace()
if lpips_model is not None:
im1, im2 = [
2.0 * im.clamp(0.0, 1.0) - 1.0
for im in (image_rgb_masked, nvs_prediction.image_render)
]
results["lpips"] = lpips_model.forward(im1, im2).item()
# convert all metrics to floats
results = {k: float(v) for k, v in results.items()}
if source_cameras is None:
# pyre-fixme[16]: Optional has no attribute __getitem__
source_cameras = frame_data.camera[torch.where(is_known)[0]]
results["meta"] = {
# calculate the camera difficulties and add to results
"camera_difficulty": calculate_camera_difficulties(
frame_data.camera[0],
source_cameras,
)[0].item(),
# store the size of the batch (corresponds to n_src_views+1)
"batch_size": int(is_known.numel()),
# store the type of the target frame
# pyre-fixme[16]: `None` has no attribute `__getitem__`.
"frame_type": str(frame_data.frame_type[0]),
}
return results
def average_per_batch_results(
results_per_batch: List[Dict[str, Any]],
idx: Optional[torch.Tensor] = None,
) -> dict:
"""
Average a list of per-batch metrics `results_per_batch`.
Optionally, if `idx` is given, only a subset of the per-batch
metrics, indexed by `idx`, is averaged.
"""
result_keys = list(results_per_batch[0].keys())
result_keys.remove("meta")
if idx is not None:
results_per_batch = [results_per_batch[i] for i in idx]
if len(results_per_batch) == 0:
return {k: float("NaN") for k in result_keys}
return {
k: float(np.array([r[k] for r in results_per_batch]).mean())
for k in result_keys
}
def calculate_camera_difficulties(
cameras_target: CamerasBase,
cameras_source: CamerasBase,
) -> torch.Tensor:
"""
Calculate the difficulties of the target cameras, given a set of known
cameras `cameras_source`.
Returns:
a tensor of shape (len(cameras_target),)
"""
ious = [
volumetric_camera_overlaps(
join_cameras_as_batch(
# pyre-fixme[6]: Expected `CamerasBase` for 1st param but got
# `Optional[pytorch3d.renderer.utils.TensorProperties]`.
[cameras_target[cami], cameras_source.to(cameras_target.device)]
)
)[0, :]
for cami in range(cameras_target.R.shape[0])
]
camera_difficulties = torch.stack(
[_reduce_camera_iou_overlap(iou[1:]) for iou in ious]
)
return camera_difficulties
def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
"""
Calculate the final camera difficulty by computing the average of the
ious of the two most similar cameras.
Returns:
single-element Tensor
"""
# pyre-ignore[16] topk not recognized
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
def get_camera_difficulty_bin_edges(task: str):
"""
Get the edges of camera difficulty bins.
"""
_eps = 1e-5
if task == "multisequence":
# TODO: extract those to constants
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
diff_bin_edges[0] = 0.0 - _eps
elif task == "singlesequence":
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
else:
raise ValueError(f"No such eval task {task}.")
diff_bin_names = ["hard", "medium", "easy"]
return diff_bin_edges, diff_bin_names
def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]],
task: str = "singlesequence",
):
"""
Compile the per-batch evaluation results `per_batch_eval_results` into
a set of aggregate metrics. The produced metrics depend on the task.
Args:
per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task.
Either 'singlesequence' or 'multisequence'.
Returns:
nvs_results_flat: A flattened dict of all aggregate metrics.
aux_out: A dictionary holding a set of auxiliary results.
"""
n_batches = len(per_batch_eval_results)
eval_sets: List[Optional[str]] = []
if task == "singlesequence":
eval_sets = [None]
# assert n_batches==100
elif task == "multisequence":
eval_sets = ["train", "test"]
# assert n_batches==1000
else:
raise ValueError(task)
batch_sizes = torch.tensor(
[r["meta"]["batch_size"] for r in per_batch_eval_results]
).long()
camera_difficulty = torch.tensor(
[r["meta"]["camera_difficulty"] for r in per_batch_eval_results]
).float()
is_train = is_train_frame([r["meta"]["frame_type"] for r in per_batch_eval_results])
# init the result database dict
results = []
diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task)
n_diff_edges = diff_bin_edges.numel()
# add per set averages
for SET in eval_sets:
if SET is None:
# task=='singlesequence'
ok_set = torch.ones(n_batches, dtype=torch.bool)
set_name = "test"
else:
# task=='multisequence'
ok_set = is_train == int(SET == "train")
set_name = SET
# eval each difficulty bin, including a full average result (diff_bin=None)
for diff_bin in [None, *list(range(n_diff_edges - 1))]:
if diff_bin is None:
# average over all results
in_bin = ok_set
diff_bin_name = "all"
else:
b1, b2 = diff_bin_edges[diff_bin : (diff_bin + 2)]
in_bin = ok_set & (camera_difficulty > b1) & (camera_difficulty <= b2)
diff_bin_name = diff_bin_names[diff_bin]
bin_results = average_per_batch_results(
per_batch_eval_results, idx=torch.where(in_bin)[0]
)
results.append(
{
"subset": set_name,
"subsubset": f"diff={diff_bin_name}",
"metrics": bin_results,
}
)
if task == "multisequence":
# split based on n_src_views
n_src_views = batch_sizes - 1
for n_src in EVAL_N_SRC_VIEWS:
ok_src = ok_set & (n_src_views == n_src)
n_src_results = average_per_batch_results(
per_batch_eval_results,
idx=torch.where(ok_src)[0],
)
results.append(
{
"subset": set_name,
"subsubset": f"n_src={int(n_src)}",
"metrics": n_src_results,
}
)
aux_out = {"results": results}
return flatten_nvs_results(results), aux_out
def _get_flat_nvs_metric_key(result, metric_name) -> str:
metric_key_postfix = f"|subset={result['subset']}|{result['subsubset']}"
metric_key = f"{metric_name}{metric_key_postfix}"
return metric_key
def flatten_nvs_results(results):
"""
Takes input `results` list of dicts of the form:
```
[
{
'subset':'train/test/...',
'subsubset': 'src=1/src=2/...',
'metrics': nvs_eval_metrics}
},
...
]
```
And converts to a flat dict as follows:
{
'subset=train/test/...|subsubset=src=1/src=2/...': nvs_eval_metrics,
...
}
"""
results_flat = {}
for result in results:
for metric_name, metric_val in result["metrics"].items():
metric_key = _get_flat_nvs_metric_key(result, metric_name)
assert metric_key not in results_flat
results_flat[metric_key] = metric_val
return results_flat
def pretty_print_nvs_metrics(results) -> None:
subsets, subsubsets = [
_ordered_set([r[k] for r in results]) for k in ("subset", "subsubset")
]
metrics = _ordered_set([metric for r in results for metric in r["metrics"]])
for subset in subsets:
tab = {}
for metric in metrics:
tab[metric] = []
header = ["metric"]
for subsubset in subsubsets:
metric_vals = [
r["metrics"][metric]
for r in results
if r["subsubset"] == subsubset and r["subset"] == subset
]
if len(metric_vals) > 0:
tab[metric].extend(metric_vals)
header.extend(subsubsets)
if any(len(v) > 0 for v in tab.values()):
print(f"===== NVS results; subset={subset} =====")
print(
tabulate(
[[metric, *v] for metric, v in tab.items()],
# pyre-fixme[61]: `header` is undefined, or not always defined.
headers=header,
)
)
def _ordered_set(list_):
return list(OrderedDict((i, 0) for i in list_).keys())
def aggregate_nvs_results(task_results):
"""
Aggregate nvs results.
For singlescene, this averages over all categories and scenes,
for multiscene, the average is over all per-category results.
"""
task_results_cat = [r_ for r in task_results for r_ in r]
subsets, subsubsets = [
_ordered_set([r[k] for r in task_results_cat]) for k in ("subset", "subsubset")
]
metrics = _ordered_set(
[metric for r in task_results_cat for metric in r["metrics"]]
)
average_results = []
for subset in subsets:
for subsubset in subsubsets:
metrics_lists = [
r["metrics"]
for r in task_results_cat
if r["subsubset"] == subsubset and r["subset"] == subset
]
avg_metrics = {}
for metric in metrics:
avg_metrics[metric] = float(
np.nanmean(
np.array([metric_list[metric] for metric_list in metrics_lists])
)
)
average_results.append(
{
"subset": subset,
"subsubset": subsubset,
"metrics": avg_metrics,
}
)
return average_results
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union
import torch
from pytorch3d.implicitron.tools.config import Configurable
# TODO: probabilistic embeddings?
class Autodecoder(Configurable, torch.nn.Module):
"""
Autodecoder module
Settings:
encoding_dim: Embedding dimension for the decoder.
n_instances: The maximum number of instances stored by the autodecoder.
init_scale: Scale factor for the initial autodecoder weights.
ignore_input: If `True`, optimizes a single code for any input.
"""
encoding_dim: int = 0
n_instances: int = 0
init_scale: float = 1.0
ignore_input: bool = False
def __post_init__(self):
super().__init__()
if self.n_instances <= 0:
# Do not init the codes at all in case we have 0 instances.
return
self._autodecoder_codes = torch.nn.Embedding(
self.n_instances,
self.encoding_dim,
scale_grad_by_freq=True,
)
with torch.no_grad():
# weight has been initialised from Normal(0, 1)
self._autodecoder_codes.weight *= self.init_scale
self._sequence_map = self._build_sequence_map()
# Make sure to register hooks for correct handling of saving/loading
# the module's _sequence_map.
self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
self._register_state_dict_hook(_save_sequence_map_hook)
def _build_sequence_map(
self, sequence_map_dict: Optional[Dict[str, int]] = None
) -> Dict[str, int]:
"""
Args:
sequence_map_dict: A dictionary used to initialize the sequence_map.
Returns:
sequence_map: a dictionary of key: id pairs.
"""
# increments the counter when asked for a new value
sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
if sequence_map_dict is not None:
# Assign all keys from the loaded sequence_map_dict to self._sequence_map.
# Since this is done in the original order, it should generate
# the same set of key:id pairs. We check this with an assert to be sure.
for x, x_id in sequence_map_dict.items():
x_id_ = sequence_map[x]
assert x_id == x_id_
return sequence_map
def calc_squared_encoding_norm(self):
if self.n_instances <= 0:
return None
return (self._autodecoder_codes.weight ** 2).mean()
def get_encoding_dim(self) -> int:
if self.n_instances <= 0:
return 0
return self.encoding_dim
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
"""
Args:
x: A batch of `N` sequence identifiers. Either a long tensor of size
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
are hashed to codes (without collisions).
Returns:
codes: A tensor of shape `(N, self.encoding_dim)` containing the
sequence-specific autodecoder codes.
"""
if self.n_instances == 0:
return None
if self.ignore_input:
x = ["singleton"]
if isinstance(x[0], str):
try:
x = torch.tensor(
# pyre-ignore[29]
[self._sequence_map[elem] for elem in x],
dtype=torch.long,
device=next(self.parameters()).device,
)
except StopIteration:
raise ValueError("Not enough n_instances in the autodecoder")
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self._autodecoder_codes(x)
def _load_sequence_map_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
Returns:
Constructed sequence_map if it exists in the state_dict
else raises a warning only.
"""
sequence_map_key = prefix + "_sequence_map"
if sequence_map_key in state_dict:
sequence_map_dict = state_dict.pop(sequence_map_key)
self._sequence_map = self._build_sequence_map(
sequence_map_dict=sequence_map_dict
)
else:
warnings.warn("No sequence map in Autodecoder state dict!")
def _save_sequence_map_hook(
self,
state_dict,
prefix,
local_metadata,
) -> None:
"""
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
"""
sequence_map_key = prefix + "_sequence_map"
sequence_map_dict = dict(self._sequence_map.items())
state_dict[sequence_map_key] = sequence_map_dict
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple
import torch
import tqdm
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom
from .autodecoder import Autodecoder
from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
from .implicit_function.neural_radiance_field import ( # noqa
NeRFormerImplicitFunction,
NeuralRadianceFieldImplicitFunction,
)
from .implicit_function.scene_representation_networks import ( # noqa
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
)
from .metrics import ViewMetrics
from .renderer.base import (
BaseRenderer,
EvaluationMode,
ImplicitFunctionWrapper,
RendererOutput,
RenderSamplingMode,
)
from .renderer.lstm_renderer import LSTMRenderer # noqa
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySampler
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooling.feature_aggregation import FeatureAggregatorBase
from .view_pooling.view_sampling import ViewSampler
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
# pyre-ignore: 13
class GenericModel(Configurable, torch.nn.Module):
"""
GenericModel is a wrapper for the neural implicit
rendering and reconstruction pipeline which consists
of the following sequence of 7 steps (steps 2–4 are normally
skipped in overfitting scenario, since conditioning on source views
does not add much information; otherwise they should be present altogether):
(1) Ray Sampling
------------------
Rays are sampled from an image grid based on the target view(s).
│_____________
│ │
│ ▼
│ (2) Feature Extraction (optional)
│ -----------------------
│ A feature extractor (e.g. a convolutional
│ neural net) is used to extract image features
│ from the source view(s).
│ │
│ ▼
│ (3) View Sampling (optional)
│ ------------------
│ Image features are sampled at the 2D projections
│ of a set of 3D points along each of the sampled
│ target rays from (1).
│ │
│ ▼
│ (4) Feature Aggregation (optional)
│ ------------------
│ Aggregate features and masks sampled from
│ image view(s) in (3).
│ │
│____________▼
(5) Implicit Function Evaluation
------------------
Evaluate the implicit function(s) at the sampled ray points
(optionally pass in the aggregated image features from (4)).
(6) Rendering
------------------
Render the image into the target cameras by raymarching along
the sampled rays and aggregating the colors and densities
output by the implicit function in (5).
(7) Loss Computation
------------------
Compute losses based on the predicted target image(s).
The `forward` function of GenericModel executes
this sequence of steps. Currently, steps 1, 3, 4, 5, 6
can be customized by intializing a subclass of the appropriate
baseclass and adding the newly created module to the registry.
Please see https://github.com/fairinternal/pytorch3d/blob/co3d/projects/implicitron_trainer/README.md#custom-plugins
for more details on how to create and register a custom component.
In the config .yaml files for experiments, the parameters below are
contained in the `generic_model_args` node. As GenericModel
derives from Configurable, the input arguments are
parsed by the run_auto_creation function to initialize the
necessary member modules. Please see implicitron_trainer/README.md
for more details on this process.
Args:
mask_images: Whether or not to mask the RGB image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
mask_depths: Whether or not to mask the depth image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
render_image_width: Width of the output image to render
render_image_height: Height of the output image to render
mask_threshold: If greater than 0.0, the foreground mask is
thresholded by this value before being applied to the RGB/Depth images
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
splatting onto an image grid. Default: False.
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
view_pool: If True, features are sampled from the source image(s)
at the projected 2d locations of the sampled 3d ray points from the target
view(s), i.e. this activates step (3) above.
num_passes: The specified implicit_function is initialized num_passes
times and run sequentially.
chunk_size_grid: The total number of points which can be rendered
per chunk. This is used to compute the number of rays used
per chunk when the chunked version of the renderer is used (in order
to fit rendering on all rays in memory)
render_features_dimensions: The number of output features to render.
Defaults to 3, corresponding to RGB images.
n_train_target_views: The number of cameras to render into at training
time; first `n_train_target_views` in the batch are considered targets,
the rest are sources.
sampling_mode_training: The sampling method to use during training. Must be
a value from the RenderSamplingMode Enum.
sampling_mode_evaluation: Same as above but for evaluation.
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
of the image (referred to as the global_code) that can be used to model aspects of
the scene such as multiple objects or morphing objects. It is up to the implicit
function definition how to use it, but the most typical way is to broadcast and
concatenate to the other inputs for the implicit function.
raysampler: An instance of RaySampler which is used to emit
rays from the target view(s).
renderer_class_type: The name of the renderer class which is available in the global
registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s).
image_feature_extractor: A module for extrating features from an input image.
view_sampler: An instance of ViewSampler which is used for sampling of
image-based features at the 2D projections of a set
of 3D points.
feature_aggregator_class_type: The name of the feature aggregator class which
is available in the global registry.
feature_aggregator: A feature aggregator class which inherits from
FeatureAggregatorBase. Typically, the aggregated features and their
masks are output by a `ViewSampler` which samples feature tensors extracted
from a set of source images. FeatureAggregator executes step (4) above.
implicit_function_class_type: The type of implicit function to use which
is available in the global registry.
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
are initialised to be in self._implicit_functions.
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
for `ViewMetrics` class for available loss functions.
log_vars: A list of variable names which should be logged.
The names should correspond to a subset of the keys of the
dict `preds` output by the `forward` function.
"""
mask_images: bool = True
mask_depths: bool = True
render_image_width: int = 400
render_image_height: int = 400
mask_threshold: float = 0.5
output_rasterized_mc: bool = False
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
view_pool: bool = False
num_passes: int = 1
chunk_size_grid: int = 4096
render_features_dimensions: int = 3
tqdm_trigger_threshold: int = 16
n_train_target_views: int = 1
sampling_mode_training: str = "mask_sample"
sampling_mode_evaluation: str = "full_grid"
# ---- autodecoder settings
sequence_autodecoder: Autodecoder
# ---- raysampler
raysampler: RaySampler
# ---- renderer configs
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
renderer: BaseRenderer
# ---- view sampling settings - used if view_pool=True
# (This is only created if view_pool is False)
image_feature_extractor: ResNetFeatureExtractor
view_sampler: ViewSampler
# ---- ---- view sampling feature aggregator settings
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
feature_aggregator: FeatureAggregatorBase
# ---- implicit function settings
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
# This is just a model, never constructed.
# The actual implicit functions live in self._implicit_functions
implicit_function: ImplicitFunctionBase
# ---- loss weights
loss_weights: Dict[str, float] = field(
default_factory=lambda: {
"loss_rgb_mse": 1.0,
"loss_prev_stage_rgb_mse": 1.0,
"loss_mask_bce": 0.0,
"loss_prev_stage_mask_bce": 0.0,
}
)
# ---- variables to be logged (logger automatically ignores if not computed)
log_vars: List[str] = field(
default_factory=lambda: [
"loss_rgb_psnr_fg",
"loss_rgb_psnr",
"loss_rgb_mse",
"loss_rgb_huber",
"loss_depth_abs",
"loss_depth_abs_fg",
"loss_mask_neg_iou",
"loss_mask_bce",
"loss_mask_beta_prior",
"loss_eikonal",
"loss_density_tv",
"loss_depth_neg_penalty",
"loss_autodecoder_norm",
# metrics that are only logged in 2+stage renderes
"loss_prev_stage_rgb_mse",
"loss_prev_stage_rgb_psnr_fg",
"loss_prev_stage_rgb_psnr",
"loss_prev_stage_mask_bce",
*STD_LOG_VARS,
]
)
def __post_init__(self):
super().__init__()
self.view_metrics = ViewMetrics()
self._check_and_preprocess_renderer_configs()
self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training
self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation
self.raysampler_args["image_width"] = self.render_image_width
self.raysampler_args["image_height"] = self.render_image_height
run_auto_creation(self)
self._implicit_functions = self._construct_implicit_functions()
self.print_loss_weights()
def forward(
self,
*, # force keyword-only arguments
image_rgb: Optional[torch.Tensor],
camera: CamerasBase,
fg_probability: Optional[torch.Tensor],
mask_crop: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
sequence_name: Optional[List[str]],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
) -> Dict[str, Any]:
"""
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
the first `min(B, n_train_target_views)` images are considered targets and
are used to supervise the renders; the rest corresponding to the source
viewpoints from which features will be extracted.
camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
to the viewpoints of target images, from which the rays will be sampled,
and source images, which will be used for intersecting with target rays.
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
foreground masks.
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
regions in the input images (i.e. regions that do not correspond
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
"mask_sample", rays will be sampled in the non zero regions.
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
sequence_name: A list of `B` strings corresponding to the sequence names
from which images `image_rgb` were extracted. They are used to match
target frames with relevant source frames.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
Returns:
preds: A dictionary containing all outputs of the forward pass including the
rendered images, depths, masks, losses and other metrics.
"""
image_rgb, fg_probability, depth_map = self._preprocess_input(
image_rgb, fg_probability, depth_map
)
# Obtain the batch size from the camera as this is the only required input.
batch_size = camera.R.shape[0]
# Determine the number of target views, i.e. cameras we render into.
n_targets = (
1
if evaluation_mode == EvaluationMode.EVALUATION
else batch_size
if self.n_train_target_views <= 0
else min(self.n_train_target_views, batch_size)
)
# Select the target cameras.
target_cameras = camera[list(range(n_targets))]
# Determine the used ray sampling mode.
sampling_mode = RenderSamplingMode(
self.sampling_mode_training
if evaluation_mode == EvaluationMode.TRAINING
else self.sampling_mode_evaluation
)
# (1) Sample rendering rays with the ray sampler.
ray_bundle: RayBundle = self.raysampler(
target_cameras,
evaluation_mode,
mask=mask_crop[:n_targets]
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None,
)
# custom_args hold additional arguments to the implicit function.
custom_args = {}
if self.view_pool:
if sequence_name is None:
raise ValueError("sequence_name must be provided for view pooling")
# (2) Extract features for the image
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
# (3) Sample features and masks at the ray points
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731
pts=pts,
seq_id_pts=sequence_name[:n_targets],
camera=camera,
seq_id_camera=sequence_name,
feats=img_feats,
masks=mask_crop,
) # returns feats_sampled, masks_sampled
# (4) Aggregate features from multiple views
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731
*curried_view_sampler(pts=pts),
pts=pts,
camera=camera,
) # TODO: do we need to pass a callback rather than compute here?
# precomputing will be faster for 2 passes
# -> but this is important for non-nerf
custom_args["fun_viewpool"] = curried_view_pool
global_code = None
if self.sequence_autodecoder.n_instances > 0:
if sequence_name is None:
raise ValueError("sequence_name must be provided for autodecoder.")
global_code = self.sequence_autodecoder(sequence_name[:n_targets])
custom_args["global_code"] = global_code
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
# torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
# torch.nn.Module]` is not a function.
for func in self._implicit_functions:
func.bind_args(**custom_args)
object_mask: Optional[torch.Tensor] = None
if fg_probability is not None:
sampled_fb_prob = rend_utils.ndc_grid_sample(
fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
)
object_mask = sampled_fb_prob > 0.5
# (5)-(6) Implicit function evaluation and Rendering
rendered = self._render(
ray_bundle=ray_bundle,
sampling_mode=sampling_mode,
evaluation_mode=evaluation_mode,
implicit_functions=self._implicit_functions,
object_mask=object_mask,
)
# Unbind the custom arguments to prevent pytorch from storing
# large buffers of intermediate results due to points in the
# bound arguments.
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
# torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
# torch.nn.Module]` is not a function.
for func in self._implicit_functions:
func.unbind_args()
preds = self._get_view_metrics(
raymarched=rendered,
xys=ray_bundle.xys,
image_rgb=None if image_rgb is None else image_rgb[:n_targets],
depth_map=None if depth_map is None else depth_map[:n_targets],
fg_probability=None
if fg_probability is None
else fg_probability[:n_targets],
mask_crop=None if mask_crop is None else mask_crop[:n_targets],
)
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
if self.output_rasterized_mc:
# Visualize the monte-carlo pixel renders by splatting onto
# an image grid.
(
preds["images_render"],
preds["depths_render"],
preds["masks_render"],
) = self._rasterize_mc_samples(
ray_bundle.xys,
rendered.features,
rendered.depths,
masks=rendered.masks,
)
elif sampling_mode == RenderSamplingMode.FULL_GRID:
preds["images_render"] = rendered.features.permute(0, 3, 1, 2)
preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
preds["nvs_prediction"] = NewViewSynthesisPrediction(
image_render=preds["images_render"],
depth_render=preds["depths_render"],
mask_render=preds["masks_render"],
)
else:
raise AssertionError("Unreachable state")
# calc the AD penalty, returns None if autodecoder is not active
ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
if ad_penalty is not None:
preds["loss_autodecoder_norm"] = ad_penalty
# (7) Compute losses
# finally get the optimization objective using self.loss_weights
objective = self._get_objective(preds)
if objective is not None:
preds["objective"] = objective
return preds
def _get_objective(self, preds) -> Optional[torch.Tensor]:
"""
A helper function to compute the overall loss as the dot product
of individual loss functions with the corresponding weights.
"""
losses_weighted = [
preds[k] * float(w)
for k, w in self.loss_weights.items()
if (k in preds and w != 0.0)
]
if len(losses_weighted) == 0:
warnings.warn("No main objective found.")
return None
loss = sum(losses_weighted)
assert torch.is_tensor(loss)
return loss
def visualize(
self,
viz: Visdom,
visdom_env_imgs: str,
preds: Dict[str, Any],
prefix: str,
) -> None:
"""
Helper function to visualize the predictions generated
in the forward pass.
Args:
viz: Visdom connection object
visdom_env_imgs: name of visdom environment for the images.
preds: predictions dict like returned by forward()
prefix: prepended to the names of images
"""
if not viz.check_connection():
print("no visdom server! -> skipping batch vis")
return
idx_image = 0
title = f"{prefix}_im{idx_image}"
vis_utils.visualize_basics(viz, preds, visdom_env_imgs, title=title)
def _render(
self,
*,
ray_bundle: RayBundle,
object_mask: Optional[torch.Tensor],
sampling_mode: RenderSamplingMode,
**kwargs,
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object
in the image. This is required for the SignedDistanceFunctionRenderer.
sampling_mode: The sampling method to use. Must be a value from the
RenderSamplingMode Enum.
Returns:
An instance of RendererOutput
"""
if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
return _apply_chunked(
self.renderer,
_chunk_generator(
self.chunk_size_grid,
ray_bundle,
object_mask,
self.tqdm_trigger_threshold,
**kwargs,
),
lambda batch: _tensor_collator(batch, ray_bundle.lengths.shape[:-1]),
)
else:
# pyre-fixme[29]: `BaseRenderer` is not a function.
return self.renderer(
ray_bundle=ray_bundle,
object_mask=object_mask,
**kwargs,
)
def _get_viewpooled_feature_dim(self):
return (
self.feature_aggregator.get_aggregated_feature_dim(
self.image_feature_extractor.get_feat_dims()
)
if self.view_pool
else 0
)
def _check_and_preprocess_renderer_configs(self):
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_training"
] = self.raysampler_args["stratified_point_sampling_training"]
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_evaluation"
] = self.raysampler_args["stratified_point_sampling_evaluation"]
self.renderer_SignedDistanceFunctionRenderer_args[
"render_features_dimensions"
] = self.render_features_dimensions
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
"object_bounding_sphere"
] = self.raysampler_args["scene_extent"]
def create_image_feature_extractor(self):
"""
Custom creation function called by run_auto_creation so that the
image_feature_extractor is not created if it is not be needed.
"""
if self.view_pool:
self.image_feature_extractor = ResNetFeatureExtractor(
**self.image_feature_extractor_args
)
def create_implicit_function(self) -> None:
"""
No-op called by run_auto_creation so that self.implicit_function
does not get created. __post_init__ creates the implicit function(s)
in wrappers explicitly in self._implicit_functions.
"""
pass
def _construct_implicit_functions(self):
"""
After run_auto_creation has been called, the arguments
for each of the possible implicit function methods are
available. `GenericModel` arguments are first validated
based on the custom requirements for each specific
implicit function method. Then the required implicit
function(s) are initialized.
"""
# nerf preprocessing
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
self._get_viewpooled_feature_dim()
+ self.sequence_autodecoder.get_encoding_dim()
)
nerf_args["color_dim"] = nerformer_args[
"color_dim"
] = self.render_features_dimensions
# idr preprocessing
idr = self.implicit_function_IdrFeatureField_args
idr["feature_vector_size"] = self.render_features_dimensions
idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
# srn preprocessing
srn = self.implicit_function_SRNImplicitFunction_args
srn.raymarch_function_args.latent_dim = (
self._get_viewpooled_feature_dim()
+ self.sequence_autodecoder.get_encoding_dim()
)
# srn_hypernet preprocessing
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
srn_hypernet_args = srn_hypernet.hypernet_args
srn_hypernet_args.latent_dim_hypernet = (
self.sequence_autodecoder.get_encoding_dim()
)
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
# check that for srn, srn_hypernet, idr we have self.num_passes=1
implicit_function_type = registry.get(
ImplicitFunctionBase, self.implicit_function_class_type
)
if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes():
raise ValueError(
self.implicit_function_class_type
+ f"requires num_passes=1 not {self.num_passes}"
)
if implicit_function_type.requires_pooling_without_aggregation():
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions")
if not self.view_pool or has_aggregation:
raise ValueError(
"Chosen implicit function requires view pooling without aggregation."
)
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
config = getattr(self, config_name, None)
if config is None:
raise ValueError(f"{config_name} not present")
implicit_functions_list = [
ImplicitFunctionWrapper(implicit_function_type(**config))
for _ in range(self.num_passes)
]
return torch.nn.ModuleList(implicit_functions_list)
def print_loss_weights(self) -> None:
"""
Print a table of the loss weights.
"""
print("-------\nloss_weights:")
for k, w in self.loss_weights.items():
print(f"{k:40s}: {w:1.2e}")
print("-------")
def _preprocess_input(
self,
image_rgb: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Helper function to preprocess the input images and optional depth maps
to apply masking if required.
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
corresponding to the source viewpoints from which features will be extracted
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
of foreground masks with values in [0, 1].
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
Returns:
Modified image_rgb, fg_mask, depth_map
"""
fg_mask = fg_probability
if fg_mask is not None and self.mask_threshold > 0.0:
# threshold masks
warnings.warn("Thresholding masks!")
fg_mask = (fg_mask >= self.mask_threshold).type_as(fg_mask)
if self.mask_images and fg_mask is not None and image_rgb is not None:
# mask the image
warnings.warn("Masking images!")
image_rgb = image_utils.mask_background(
image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(self.bg_color)
)
if self.mask_depths and fg_mask is not None and depth_map is not None:
# mask the depths
assert (
self.mask_threshold > 0.0
), "Depths should be masked only with thresholded masks"
warnings.warn("Masking depths!")
depth_map = depth_map * fg_mask
return image_rgb, fg_mask, depth_map
def _get_view_metrics(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
metrics = self.view_metrics(
image_sampling_grid=xys,
images_pred=raymarched.features,
images=image_rgb,
depths_pred=raymarched.depths,
depths=depth_map,
masks_pred=raymarched.masks,
masks=fg_probability,
masks_crop=mask_crop,
keys_prefix=keys_prefix,
**raymarched.aux,
)
if raymarched.prev_stage:
metrics.update(
self._get_view_metrics(
raymarched.prev_stage,
xys,
image_rgb,
depth_map,
fg_probability,
mask_crop,
keys_prefix=(keys_prefix + "prev_stage_"),
)
)
return metrics
@torch.no_grad()
def _rasterize_mc_samples(
self,
xys: torch.Tensor,
features: torch.Tensor,
depth: torch.Tensor,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rasterizes Monte-Carlo features back onto the image.
Args:
xys: B x ... x 2 2D point locations in PyTorch3D NDC convention
features: B x ... x C tensor containing per-point rendered features.
depth: B x ... x 1 tensor containing per-point rendered depth.
"""
ba = xys.shape[0]
# Flatten the features and xy locations.
features_depth_ras = torch.cat(
(
features.reshape(ba, -1, features.shape[-1]),
depth.reshape(ba, -1, 1),
),
dim=-1,
)
xys_ras = xys.reshape(ba, -1, 2)
if masks is not None:
masks_ras = masks.reshape(ba, -1, 1)
else:
masks_ras = None
if min(self.render_image_height, self.render_image_width) <= 0:
raise ValueError(
"Need to specify a positive"
" self.render_image_height and self.render_image_width"
" for MC rasterisation."
)
# Estimate the rasterization point radius so that we approximately fill
# the whole image given the number of rasterized points.
pt_radius = 2.0 * math.sqrt(xys.shape[1])
# Rasterize the samples.
features_depth_render, masks_render = rasterize_mc_samples(
xys_ras,
features_depth_ras,
(self.render_image_height, self.render_image_width),
radius=pt_radius,
masks=masks_ras,
)
images_render = features_depth_render[:, :-1]
depths_render = features_depth_render[:, -1:]
return images_render, depths_render, masks_render
def _apply_chunked(func, chunk_generator, tensor_collator):
"""
Helper function to apply a function on a sequence of
chunked inputs yielded by a generator and collate
the result.
"""
processed_chunks = [
func(*chunk_args, **chunk_kwargs)
for chunk_args, chunk_kwargs in chunk_generator
]
return cat_dataclass(processed_chunks, tensor_collator)
def _tensor_collator(batch, new_dims) -> torch.Tensor:
"""
Helper function to reshape the batch to the desired shape
"""
return torch.cat(batch, dim=1).reshape(*new_dims, -1)
def _chunk_generator(
chunk_size: int,
ray_bundle: RayBundle,
object_mask: Optional[torch.Tensor],
tqdm_trigger_threshold: int,
*args,
**kwargs,
):
"""
Helper function which yields chunks of rays from the
input ray_bundle, to be used when the number of rays is
large and will not fit in memory for rendering.
"""
(
batch_size,
*spatial_dim,
n_pts_per_ray,
) = ray_bundle.lengths.shape # B x ... x n_pts_per_ray
if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
raise ValueError(
f"chunk_size_grid ({chunk_size}) should be divisible "
f"by n_pts_per_ray ({n_pts_per_ray})"
)
n_rays = math.prod(spatial_dim)
# special handling for raytracing-based methods
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
chunk_size_in_rays = -(-n_rays // n_chunks)
iter = range(0, n_rays, chunk_size_in_rays)
if len(iter) >= tqdm_trigger_threshold:
iter = tqdm.tqdm(iter)
for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = RayBundle(
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx
],
lengths=ray_bundle.lengths.reshape(
batch_size, math.prod(spatial_dim), n_pts_per_ray
)[:, start_idx:end_idx],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
)
extra_args = kwargs.copy()
if object_mask is not None:
extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[
:, start_idx:end_idx
]
yield [ray_bundle_chunk, *args], extra_args
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle
class ImplicitFunctionBase(ABC, ReplaceableBase):
def __init__(self):
super().__init__()
@abstractmethod
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
raise NotImplementedError()
@staticmethod
def allows_multiple_passes() -> bool:
"""
Returns True if this implicit function allows
multiple passes.
"""
return False
@staticmethod
def requires_pooling_without_aggregation() -> bool:
"""
Returns True if this implicit function needs
pooling without aggregation.
"""
return False
def on_bind_args(self) -> None:
"""
Called when the custom args are fixed in the main model forward pass.
"""
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment