Unverified commit cdd2142d authored by Jeremy Reizenstein, committed by GitHub

implicitron v0 (#1133)


Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
parent 0e377c68
defaults:
- repro_singleseq_base.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: false
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05
defaults:
- repro_singleseq_srn.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0
defaults:
- repro_singleseq_wce_base
- repro_feat_extractor_normed.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: true
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05
defaults:
- repro_singleseq_srn_wce.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0
defaults:
- repro_singleseq_base
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
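The dataloader settings above map one-to-one onto keyword arguments of dataloader_zoo, which is added later in this commit. A minimal sketch of the equivalent programmatic call, assuming datasets was already built by dataset_zoo and that the single-sequence task is intended (both are assumptions, not part of this config):

# minimal sketch only; `datasets` and the task name are assumptions
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo

dataloaders = dataloader_zoo(
    datasets,  # dict of dataset objects returned by dataset_zoo(...)
    dataset_name="co3d_singlesequence",
    batch_size=10,
    num_workers=8,
    dataset_len=1000,
    dataset_len_val=1,
    images_per_seq_options=tuple(range(2, 11)),  # 2 .. 10, as listed above
)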
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Script to visualize a previously trained model. Example call:
projects/implicitron_trainer/visualize_reconstruction.py
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""
import math
import os
import random
import sys
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.tools.configurable import get_default_args
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
get_visdom_connection,
make_depth_image,
)
from tqdm import tqdm
def render_sequence(
dataset: ImplicitronDataset,
sequence_name: str,
model: torch.nn.Module,
video_path,
n_eval_cameras=40,
fps=20,
max_angle=2 * math.pi,
trajectory_type="circular_lsq_fit",
trajectory_scale=1.1,
scene_center=(0.0, 0.0, 0.0),
up=(0.0, -1.0, 0.0),
traj_offset=0.0,
n_source_views=9,
viz_env="debug",
visdom_show_preds=False,
visdom_server="http://127.0.0.1",
visdom_port=8097,
num_workers=10,
seed=None,
video_resize=None,
):
if seed is None:
seed = hash(sequence_name)
print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = dataset.seq_to_idx[sequence_name]
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
print(f"Sequence set = {sequence_set_name}.")
train_cameras = train_data.camera
time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
test_cameras = generate_eval_video_cameras(
train_cameras,
time=time,
n_eval_cams=n_eval_cameras,
trajectory_type=trajectory_type,
trajectory_scale=trajectory_scale,
scene_center=scene_center,
up=up,
focal_length=None,
principal_point=torch.zeros(n_eval_cameras, 2),
traj_offset_canonical=[0.0, 0.0, traj_offset],
)
# sample the source views reproducibly
with torch.random.fork_rng():
torch.manual_seed(seed)
source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
# add the first dummy view that will get replaced with the target camera
source_views_i = Fu.pad(source_views_i, [1, 0])
source_views = [seq_idx[i] for i in source_views_i.tolist()]
batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
# get the visdom connection once; it is needed both for the in-loop previews
# and for the final video export below
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
preds_total = []
for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
# set the first batch camera to the target camera
for k in ("R", "T", "focal_length", "principal_point"):
getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
# Move to cuda
net_input = batch.cuda()
with torch.no_grad():
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
# make sure we don't overwrite anything already present in the predictions
assert all(k not in preds for k in net_input.keys())
preds.update(net_input) # merge everything into one big dict
# Render the predictions to images
rendered_pred = images_from_preds(preds)
preds_total.append(rendered_pred)
# show the preds every 5% of the export iterations
if visdom_show_preds and (
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
):
show_predictions(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
)
print(f"Exporting videos for sequence {sequence_name} ...")
generate_prediction_videos(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
fps=fps,
video_path=video_path,
resize=video_resize,
)
def _load_whole_dataset(dataset, idx, num_workers=10):
load_all_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, idx),
batch_size=len(idx),
num_workers=num_workers,
shuffle=False,
collate_fn=FrameData.collate,
)
return next(iter(load_all_dataloader))
def images_from_preds(preds):
imout = {}
for k in (
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
"_all_source_images",
):
if k == "_all_source_images" and "image_rgb" in preds:
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
v = _stack_images(src_ims, None)[None]
else:
if k not in preds or preds[k] is None:
print(f"cant show {k}")
continue
v = preds[k].cpu().detach().clone()
if k.startswith("depth"):
mask_resize = Fu.interpolate(
preds["masks_render"],
size=preds[k].shape[2:],
mode="nearest",
)
v = make_depth_image(preds[k], mask_resize)
if v.shape[1] == 1:
v = v.repeat(1, 3, 1, 1)
imout[k] = v.detach().cpu()
return imout
def _stack_images(ims, size):
ba = ims.shape[0]
H = int(np.ceil(np.sqrt(ba)))
W = H
n_add = H * W - ba
if n_add > 0:
ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
ims = ims.view(H, W, *ims.shape[1:])
cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
if size is not None:
cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
return cated.clamp(0.0, 1.0)
def show_predictions(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
n_samples=10,
one_image_width=200,
):
"""Given a list of predictions visualize them into a single image using visdom."""
assert isinstance(preds, list)
pred_all = []
# Randomly choose a subset of the rendered images, sorted by their order in the sequence
n_samples = min(n_samples, len(preds))
pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
for predi in pred_idx:
# Concatenate the different predictions for the same camera vertically
pred_all.append(
torch.cat(
[
torch.nn.functional.interpolate(
preds[predi][k].cpu(),
scale_factor=one_image_width / preds[predi][k].shape[3],
mode="bilinear",
).clamp(0.0, 1.0)
for k in predicted_keys
],
dim=2,
)
)
# Concatenate the images horizontally
pred_all_cat = torch.cat(pred_all, dim=3)[0]
viz.image(
pred_all_cat,
win="show_predictions",
env=viz_env,
opts={"title": f"pred_{sequence_name}"},
)
def generate_prediction_videos(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
fps=20,
video_path="/tmp/video",
resize=None,
):
"""Given a list of predictions create and visualize rotating videos of the
objects using visdom.
"""
assert isinstance(preds, list)
# make sure the target video directory exists
os.makedirs(os.path.dirname(video_path), exist_ok=True)
# init a video writer for each predicted key
vws = {}
for k in predicted_keys:
vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
for rendered_pred in tqdm(preds):
for k in predicted_keys:
vws[k].write_frame(
rendered_pred[k][0].detach().cpu().numpy(),
resize=resize,
)
for k in predicted_keys:
vws[k].get_video(quiet=True)
print(f"Generated {vws[k].out_path}.")
viz.video(
videofile=vws[k].out_path,
env=viz_env,
win=k, # we reuse the same window otherwise visdom dies
opts={"title": sequence_name + " " + k},
)
def export_scenes(
exp_dir: str = "",
restrict_sequence_name: Optional[str] = None,
output_directory: Optional[str] = None,
render_size: Tuple[int, int] = (512, 512),
video_size: Optional[Tuple[int, int]] = None,
split: str = "train", # train | test
n_source_views: int = 9,
n_eval_cameras: int = 40,
visdom_server="http://127.0.0.1",
visdom_port=8097,
visdom_show_preds: bool = False,
visdom_env: Optional[str] = None,
gpu_idx: int = 0,
):
# If an output directory is specified, use it; otherwise create a "vis"
# folder inside the experiment directory.
if output_directory is None:
output_directory = os.path.join(exp_dir, "vis")
if not os.path.exists(output_directory):
os.makedirs(output_directory)
# Set the random seeds
torch.manual_seed(0)
np.random.seed(0)
# Get the config from the experiment_directory,
# and overwrite relevant fields
config = _get_config_from_experiment_directory(exp_dir)
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.dataset_args.restrict_sequence_name = restrict_sequence_name
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
# Load the previously trained model
model, _, _ = init_model(config, force_load=True, load_model_only=True)
model.cuda()
model.eval()
# Setup the dataset
dataset = dataset_zoo(**config.dataset_args)[split]
# iterate over the sequences in the dataset
for sequence_name in dataset.seq_to_idx.keys():
with torch.no_grad():
render_sequence(
dataset,
sequence_name,
model,
video_path="{}/video".format(output_directory),
n_source_views=n_source_views,
visdom_show_preds=visdom_show_preds,
n_eval_cameras=n_eval_cameras,
visdom_server=visdom_server,
visdom_port=visdom_port,
viz_env=f"visualizer_{config.visdom_env}"
if visdom_env is None
else visdom_env,
video_resize=video_size,
)
def _get_config_from_experiment_directory(experiment_directory):
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
config = OmegaConf.load(cfg_file)
return config
def main(argv):
# automatically parses arguments of export_scenes
cfg = OmegaConf.create(get_default_args(export_scenes))
cfg.update(OmegaConf.from_cli())
with torch.no_grad():
export_scenes(**cfg)
if __name__ == "__main__":
main(sys.argv)
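Besides the CLI call shown in the module docstring, export_scenes can also be driven programmatically. A minimal sketch, assuming a checkpoint directory containing expconfig.yaml (the path below is a placeholder):

# minimal sketch: programmatic equivalent of the CLI call in the docstring
import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.tools.configurable import get_default_args

cfg = OmegaConf.create(get_default_args(export_scenes))
cfg.exp_dir = "./exps/checkpoint_dir"  # placeholder; must contain expconfig.yaml
cfg.render_size = [64, 64]
cfg.video_size = [256, 256]
cfg.visdom_show_preds = False
with torch.no_grad():
    export_scenes(**cfg)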
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Sequence
import torch
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
from .scene_batch_sampler import SceneBatchSampler
def dataloader_zoo(
datasets: Dict[str, ImplicitronDatasetBase],
dataset_name: str = "co3d_singlesequence",
batch_size: int = 1,
num_workers: int = 0,
dataset_len: int = 1000,
dataset_len_val: int = 1,
images_per_seq_options: Sequence[int] = (2,),
sample_consecutive_frames: bool = False,
consecutive_frames_max_gap: int = 0,
consecutive_frames_max_gap_seconds: float = 0.1,
) -> Dict[str, torch.utils.data.DataLoader]:
"""
Returns a set of dataloaders for a given set of datasets.
Args:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
dataset_name: The name of the returned dataset.
batch_size: The size of the batch of the dataloader.
num_workers: Number data-loading threads.
dataset_len: The number of batches in a training epoch.
dataset_len_val: The number of batches in a validation epoch.
images_per_seq_options: Possible numbers of images sampled per sequence.
sample_consecutive_frames: if True, will sample a contiguous interval of frames
in the sequence. It first sorts the frames by timestamps when available,
otherwise by frame numbers, finds the connected segments within the sequence
of sufficient length, then samples a random pivot element among them and
ideally uses it as a middle of the temporal window, shifting the borders
where necessary. This strategy mitigates the bias against shorter segments
and their boundaries.
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
difference in frame_number of neighbouring frames when forming connected
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
the whole sequence is considered a segment regardless of frame numbers.
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
maximum difference in frame_timestamp of neighbouring frames when forming
connected segments; if both this and consecutive_frames_max_gap are 0s,
the whole sequence is considered a segment regardless of frame timestamps.
Returns:
dataloaders: A dictionary containing the
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
"""
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
raise ValueError(f"Unsupported dataset: {dataset_name}")
dataloaders = {}
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
for dataset_set, dataset in datasets.items():
num_samples = {
"train": dataset_len,
"val": dataset_len_val,
"test": None,
}[dataset_set]
if dataset_set == "test":
batch_sampler = dataset.get_eval_batches()
else:
assert num_samples is not None
num_samples = len(dataset) if num_samples <= 0 else num_samples
batch_sampler = SceneBatchSampler(
dataset,
batch_size,
num_batches=num_samples,
images_per_seq_options=images_per_seq_options,
sample_consecutive_frames=sample_consecutive_frames,
consecutive_frames_max_gap=consecutive_frames_max_gap,
)
dataloaders[dataset_set] = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=FrameData.collate,
)
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")
return dataloaders
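A minimal usage sketch, mirroring how eval_demo.py later in this commit wires the two zoos together (the dataset root and category below are placeholders):

# minimal sketch: build the datasets, wrap them in dataloaders, pull one batch
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo

datasets = dataset_zoo(
    dataset_name="co3d_singlesequence",
    dataset_root="/path/to/co3d",  # placeholder
    category="apple",              # placeholder
    test_restrict_sequence_id=0,
)
dataloaders = dataloader_zoo(
    datasets,
    dataset_name="co3d_singlesequence",
    batch_size=10,
    images_per_seq_options=(2,),
)
train_batch = next(iter(dataloaders["train"]))  # a collated FrameData batch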
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import json
import os
from typing import Any, Dict, List, Optional, Sequence
from iopath.common.file_io import PathManager
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
from .utils import (
DATASET_TYPE_KNOWN,
DATASET_TYPE_TEST,
DATASET_TYPE_TRAIN,
DATASET_TYPE_UNKNOWN,
)
# TODO from dataset.dataset_configs import DATASET_CONFIGS
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
"default": {
"box_crop": True,
"box_crop_context": 0.3,
"image_width": 800,
"image_height": 800,
"remove_empty_masks": True,
}
}
# fmt: off
CO3D_CATEGORIES: List[str] = list(reversed([
"baseballbat", "banana", "bicycle", "microwave", "tv",
"cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
"umbrella", "wineglass", "laptop",
"hotdog", "stopsign", "frisbee", "baseballglove",
"cup", "parkingmeter", "backpack", "toyplane", "toybus",
"handbag", "chair", "keyboard", "car", "motorcycle",
"carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
"toaster", "mouse", "toytrain", "book", "toytruck",
"orange", "broccoli", "plant", "teddybear",
"suitcase", "bench", "ball", "cake",
"vase", "hydrant", "apple", "donut",
]))
# fmt: on
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
def dataset_zoo(
dataset_name: str = "co3d_singlesequence",
dataset_root: str = _CO3D_DATASET_ROOT,
category: str = "DEFAULT",
limit_to: int = -1,
limit_sequences_to: int = -1,
n_frames_per_sequence: int = -1,
test_on_train: bool = False,
load_point_clouds: bool = False,
mask_images: bool = False,
mask_depths: bool = False,
restrict_sequence_name: Sequence[str] = (),
test_restrict_sequence_id: int = -1,
assert_single_seq: bool = False,
only_test_set: bool = False,
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
path_manager: Optional[PathManager] = None,
) -> Dict[str, ImplicitronDatasetBase]:
"""
Generates the training / validation and testing dataset objects.
Args:
dataset_name: The name of the returned dataset.
dataset_root: The root folder of the dataset.
category: The object category of the dataset.
limit_to: Limit the dataset to the first #limit_to frames.
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences.
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
in each sequence.
test_on_train: Construct validation and test datasets from
the training subset.
load_point_clouds: Enable returning scene point clouds from the dataset.
mask_images: Mask the loaded images with segmentation masks.
mask_depths: Mask the loaded depths with segmentation masks.
restrict_sequence_name: Restrict the dataset sequences to the ones
present in the given list of names.
test_restrict_sequence_id: The ID of the loaded sequence.
Active for dataset_name='co3d_singlesequence'.
assert_single_seq: Assert that only frames from a single sequence
are present in all generated datasets.
only_test_set: Load only the test set.
aux_dataset_kwargs: Specifies additional arguments to the
ImplicitronDataset constructor call.
Returns:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
"""
datasets = {}
# TODO:
# - implement loading multiple categories
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
# This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping(
dataset_name,
test_on_train,
only_test_set,
)
# load the evaluation batches
task = dataset_name.split("_")[-1]
batch_indices_path = os.path.join(
dataset_root,
category,
f"eval_batches_{task}.json",
)
if not os.path.isfile(batch_indices_path):
# The batch indices file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError("Please specify a correct dataset_root folder.")
with open(batch_indices_path, "r") as f:
eval_batch_index = json.load(f)
if task == "singlesequence":
assert (
test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0
), (
"Please specify an integer id 'test_restrict_sequence_id'"
+ " of the sequence considered for 'singlesequence'"
+ " training and evaluation."
)
assert len(restrict_sequence_name) == 0, (
"For the 'singlesequence' task, the restrict_sequence_name has"
" to be unset while test_restrict_sequence_id has to be set to an"
" integer defining the order of the evaluation sequence."
)
# a sort-stable set() equivalent:
eval_batches_sequence_names = list(
{b[0][0]: None for b in eval_batch_index}.keys()
)
eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
eval_batch_index = [
b for b in eval_batch_index if b[0][0] == eval_sequence_name
]
# overwrite the restrict_sequence_name
restrict_sequence_name = [eval_sequence_name]
for dataset, subsets in set_names_mapping.items():
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
assert os.path.isfile(frame_file)
sequence_file = os.path.join(
dataset_root, category, "sequence_annotations.jgz"
)
assert os.path.isfile(sequence_file)
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
assert os.path.isfile(subset_lists_file)
# TODO: maybe directly in param list
params = {
**copy.deepcopy(aux_dataset_kwargs),
"frame_annotations_file": frame_file,
"sequence_annotations_file": sequence_file,
"subset_lists_file": subset_lists_file,
"dataset_root": dataset_root,
"limit_to": limit_to,
"limit_sequences_to": limit_sequences_to,
"n_frames_per_sequence": n_frames_per_sequence
if dataset == "train"
else -1,
"subsets": subsets,
"load_point_clouds": load_point_clouds,
"mask_images": mask_images,
"mask_depths": mask_depths,
"pick_sequence": restrict_sequence_name,
"path_manager": path_manager,
}
datasets[dataset] = ImplicitronDataset(**params)
if dataset == "test":
if len(restrict_sequence_name) > 0:
eval_batch_index = [
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
]
datasets[dataset].eval_batches = datasets[
dataset
].seq_frame_index_to_dataset_index(eval_batch_index)
if assert_single_seq:
# check there's only one sequence in all datasets
assert (
len(
{
e["frame_annotation"].sequence_name
for dset in datasets.values()
for e in dset.frame_annots
}
)
<= 1
), "Multiple sequences loaded but expected one"
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")
if test_on_train:
datasets["val"] = datasets["train"]
datasets["test"] = datasets["train"]
return datasets
def _get_co3d_set_names_mapping(
dataset_name: str,
test_on_train: bool,
only_test: bool,
) -> Dict[str, List[str]]:
"""
Returns the mapping of the common dataset subset names ("train"/"val"/"test")
to the names of the corresponding subsets in the CO3D dataset
("test_known"/"test_unseen"/"train_known"/"train_unseen").
"""
single_seq = dataset_name == "co3d_singlesequence"
if only_test:
set_names_mapping = {}
else:
set_names_mapping = {
"train": [
(DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN)
+ "_"
+ DATASET_TYPE_KNOWN
]
}
if not test_on_train:
prefixes = [DATASET_TYPE_TEST]
if not single_seq:
prefixes.append(DATASET_TYPE_TRAIN)
set_names_mapping.update(
{
dset: [
p + "_" + t
for p in prefixes
for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
]
for dset in ["val", "test"]
}
)
return set_names_mapping
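For concreteness, the mapping produced by _get_co3d_set_names_mapping for the default single-sequence setup follows directly from the branches above; a small sketch:

# minimal sketch: subset-name mapping for the single-sequence task
mapping = _get_co3d_set_names_mapping(
    "co3d_singlesequence", test_on_train=False, only_test=False
)
# expected, given the logic above:
# {"train": ["test_known"],
#  "val": ["test_known", "test_unseen"],
#  "test": ["test_known", "test_unseen"]}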
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple
import numpy as np
from torch.utils.data.sampler import Sampler
from .implicitron_dataset import ImplicitronDatasetBase
@dataclass(eq=False) # TODO: do we need this if not init from config?
class SceneBatchSampler(Sampler[List[int]]):
"""
A class for sampling training batches with a controlled composition
of sequences.
"""
dataset: ImplicitronDatasetBase
batch_size: int
num_batches: int
# the sampler first samples a random element k from this list and then
# takes k random frames per sequence
images_per_seq_options: Sequence[int]
# if True, will sample a contiguous interval of frames in the sequence
# it first finds the connected segments within the sequence of sufficient length,
# then samples a random pivot element among them and ideally uses it as a middle
# of the temporal window, shifting the borders where necessary.
# This strategy mitigates the bias against shorter segments and their boundaries.
sample_consecutive_frames: bool = False
# if a number > 0, then used to define the maximum difference in frame_number
# of neighbouring frames when forming connected segments; otherwise the whole
# sequence is considered a segment regardless of frame numbers
consecutive_frames_max_gap: int = 0
# same but for timestamps if they are available
consecutive_frames_max_gap_seconds: float = 0.1
seq_names: List[str] = field(init=False)
def __post_init__(self) -> None:
if self.batch_size <= 0:
raise ValueError(
"batch_size should be a positive integral value, "
f"but got batch_size={self.batch_size}"
)
if len(self.images_per_seq_options) < 1:
raise ValueError("n_per_seq_posibilities list cannot be empty")
self.seq_names = list(self.dataset.seq_to_idx.keys())
def __len__(self) -> int:
return self.num_batches
def __iter__(self) -> Iterator[List[int]]:
for batch_idx in range(len(self)):
batch = self._sample_batch(batch_idx)
yield batch
def _sample_batch(self, batch_idx) -> List[int]:
n_per_seq = np.random.choice(self.images_per_seq_options)
n_seqs = -(-self.batch_size // n_per_seq) # round up
chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
if self.sample_consecutive_frames:
frame_idx = []
for seq in chosen_seq:
segment_index = self._build_segment_index(
list(self.dataset.seq_to_idx[seq]), n_per_seq
)
segment, idx = segment_index[np.random.randint(len(segment_index))]
if len(segment) <= n_per_seq:
frame_idx.append(segment)
else:
start = np.clip(idx - n_per_seq // 2, 0, len(segment) - n_per_seq)
frame_idx.append(segment[start : start + n_per_seq])
else:
frame_idx = [
_capped_random_choice(
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
)
for seq in chosen_seq
]
frame_idx = np.concatenate(frame_idx)[: self.batch_size].tolist()
if len(frame_idx) < self.batch_size:
warnings.warn(
"Batch size smaller than self.batch_size!"
+ " (This is fine for experiments with a single scene and viewpooling)"
)
return frame_idx
def _build_segment_index(
self, seq_frame_indices: List[int], size: int
) -> List[Tuple[List[int], int]]:
"""
Returns a list of (segment, index) tuples, one per eligible frame, where
segment is a list of frame indices in the contiguous segment the frame
belongs to, and index is the frame's index within that segment.
Segment references are repeated but the memory is shared.
"""
if (
self.consecutive_frames_max_gap > 0
or self.consecutive_frames_max_gap_seconds > 0.0
):
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
seq_frame_indices, self.dataset
)
# TODO: use new API to access frame numbers / timestamps
segments = self._split_to_segments(sequence_timestamps)
segments = _cull_short_segments(segments, size)
if not segments:
raise AssertionError("Empty segments after culling")
else:
segments = [seq_frame_indices]
# build an index of segment for random selection of a pivot frame
segment_index = [
(segment, i) for segment in segments for i in range(len(segment))
]
return segment_index
def _split_to_segments(
self, sequence_timestamps: List[Tuple[float, int, int]]
) -> List[List[int]]:
if (
self.consecutive_frames_max_gap <= 0
and self.consecutive_frames_max_gap_seconds <= 0.0
):
raise AssertionError("This function is only needed for non-trivial max_gap")
segments = []
last_no = -self.consecutive_frames_max_gap - 1 # will trigger a new segment
last_ts = -self.consecutive_frames_max_gap_seconds - 1.0
for ts, no, idx in sequence_timestamps:
if ts <= 0.0 and no <= last_no:
raise AssertionError(
"Frames are not ordered in seq_to_idx while timestamps are not given"
)
if (
no - last_no > self.consecutive_frames_max_gap > 0
or ts - last_ts > self.consecutive_frames_max_gap_seconds > 0.0
): # new group
segments.append([idx])
else:
segments[-1].append(idx)
last_no = no
last_ts = ts
return segments
def _sort_frames_by_timestamps_then_numbers(
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
We attempt to first sort by timestamp, then by frame number.
Missing timestamps are coalesced to 0.
"""
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
return sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
lengths = [(len(segment), segment) for segment in segments]
max_len, longest_segment = max(lengths)
if max_len < min_size:
return [longest_segment]
return [segment for segment in segments if len(segment) >= min_size]
def _capped_random_choice(x, size, replace: bool = True):
"""
if replace==True, randomly chooses `size` elements from x without replacement
when len(x) >= size, and with replacement otherwise (so that `size` elements
are always returned).
if replace==False, randomly chooses min(len(x), size) elements from x
without replacement.
"""
len_x = x if isinstance(x, int) else len(x)
if replace:
return np.random.choice(x, size=size, replace=len_x < size)
else:
return np.random.choice(x, size=min(size, len_x), replace=False)
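A minimal sketch of how SceneBatchSampler plugs into a standard DataLoader, as dataloader_zoo does earlier in this commit; the dataset object is assumed to be an ImplicitronDataset built elsewhere:

# minimal sketch: drive a DataLoader with SceneBatchSampler batches
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData

batch_sampler = SceneBatchSampler(
    dataset,  # assumed ImplicitronDataset instance
    batch_size=10,
    num_batches=1000,
    images_per_seq_options=(2, 3, 4),
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=FrameData.collate,
)
for frame_batch in loader:
    pass  # frame_batch is a FrameData whose frames come from a few sequences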
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import gzip
import json
import sys
from dataclasses import MISSING, Field, dataclass
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
import numpy as np
_X = TypeVar("_X")
if sys.version_info >= (3, 8, 0):
from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):
def get_origin(cls):
return getattr(cls, "__origin__", None)
def get_args(cls):
return getattr(cls, "__args__", None)
else:
raise ImportError("This module requires Python 3.7+")
TF3 = Tuple[float, float, float]
@dataclass
class ImageAnnotation:
# path to jpg file, relative w.r.t. dataset_root
path: str
# H x W
size: Tuple[int, int] # TODO: rename size_hw?
@dataclass
class DepthAnnotation:
# path to png file, relative w.r.t. dataset_root, storing `depth / scale_adjustment`
path: str
# a factor to convert png values to actual depth: `depth = png * scale_adjustment`
scale_adjustment: float
# path to png file, relative w.r.t. dataset_root, storing binary `depth` mask
mask_path: Optional[str]
@dataclass
class MaskAnnotation:
# path to png file storing (Prob(fg | pixel) * 255)
path: str
# (soft) number of pixels in the mask; sum(Prob(fg | pixel))
mass: Optional[float] = None
@dataclass
class ViewpointAnnotation:
# In right-multiply (PyTorch3D) format. X_cam = X_world @ R + T
R: Tuple[TF3, TF3, TF3]
T: TF3
focal_length: Tuple[float, float]
principal_point: Tuple[float, float]
intrinsics_format: str = "ndc_norm_image_bounds"
# Defines the co-ordinate system where focal_length and principal_point live.
# Possible values: ndc_isotropic | ndc_norm_image_bounds (default)
# ndc_norm_image_bounds: legacy PyTorch3D NDC format, where image boundaries
# correspond to [-1, 1] x [-1, 1], and the scale along x and y may differ
# ndc_isotropic: PyTorch3D 0.5+ NDC convention where the shorter side has
# the range [-1, 1], and the longer one has the range [-s, s]; s >= 1,
# where s is the aspect ratio. The scale is same along x and y.
@dataclass
class FrameAnnotation:
"""A dataclass used to load annotations from json."""
# can be used to join with `SequenceAnnotation`
sequence_name: str
# 0-based, continuous frame number within sequence
frame_number: int
# timestamp in seconds from the video start
frame_timestamp: float
image: ImageAnnotation
depth: Optional[DepthAnnotation] = None
mask: Optional[MaskAnnotation] = None
viewpoint: Optional[ViewpointAnnotation] = None
@dataclass
class PointCloudAnnotation:
# path to ply file with points only, relative w.r.t. dataset_root
path: str
# the bigger the better
quality_score: float
n_points: Optional[int]
@dataclass
class VideoAnnotation:
# path to the original video file, relative w.r.t. dataset_root
path: str
# length of the video in seconds
length: float
@dataclass
class SequenceAnnotation:
sequence_name: str
category: str
video: Optional[VideoAnnotation] = None
point_cloud: Optional[PointCloudAnnotation] = None
# the bigger the better
viewpoint_quality_score: Optional[float] = None
def dump_dataclass(obj: Any, f: IO, binary: bool = False) -> None:
"""
Args:
f: Either a path to a file, or a file opened for writing.
obj: A @dataclass or collection hierarchy including dataclasses.
binary: Set to True if `f` is a file handle, else False.
"""
if binary:
f.write(json.dumps(_asdict_rec(obj)).encode("utf8"))
else:
json.dump(_asdict_rec(obj), f)
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
"""
Loads to a @dataclass or collection hierarchy including dataclasses
from a json recursively.
Call it like load_dataclass(f, typing.List[FrameAnnotation]).
raises KeyError if json has keys not mapping to the dataclass fields.
Args:
f: Either a path to a file, or a file opened for reading.
cls: The class of the loaded dataclass.
binary: Set to True if `f` is a file handle, else False.
"""
if binary:
asdict = json.loads(f.read().decode("utf8"))
else:
asdict = json.load(f)
if isinstance(asdict, list):
# in the list case, run a faster "vectorized" version
cls = get_args(cls)[0]
res = list(_dataclass_list_from_dict_list(asdict, cls))
else:
res = _dataclass_from_dict(asdict, cls)
return res
def _dataclass_list_from_dict_list(dlist, typeannot):
"""
Vectorised version of `_dataclass_from_dict`.
The output should be equivalent to
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
Args:
dlist: list of objects to convert.
typeannot: type of each of those objects.
Returns:
iterator or list over converted objects of the same length as `dlist`.
Raises:
ValueError: it assumes the objects have None's in consistent places across
objects, otherwise it would ignore some values. This generally holds for
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
"""
cls = get_origin(typeannot) or typeannot
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
return dlist
elif any(obj is None for obj in dlist):
# filter out Nones and recurse on the resulting list
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
idx, notnone = zip(*idx_notnone)
converted = _dataclass_list_from_dict_list(notnone, typeannot)
res = [None] * len(dlist)
for i, obj in zip(idx, converted):
res[i] = obj
return res
# otherwise, we dispatch by the type of the provided annotation to convert to
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls._field_types.values()
dlist_T = zip(*dlist)
res_T = [
_dataclass_list_from_dict_list(key_list, tp)
for key_list, tp in zip(dlist_T, types)
]
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, (list, tuple)):
# For list/tuple, call the function recursively on the lists of corresponding positions
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(dlist[0])
dlist_T = zip(*dlist)
res_T = (
_dataclass_list_from_dict_list(pos_list, tp)
for pos_list, tp in zip(dlist_T, types)
)
if issubclass(cls, tuple):
return list(zip(*res_T))
else:
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, dict):
# For dicts, call the function recursively on the concatenated keys and values
key_t, val_t = get_args(typeannot)
all_keys_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.keys()], key_t
)
all_vals_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.values()], val_t
)
indices = np.cumsum([len(obj) for obj in dlist])
assert indices[-1] == len(all_keys_res)
keys = np.split(list(all_keys_res), indices[:-1])
vals = np.split(list(all_vals_res), indices[:-1])
return [cls(zip(*k, v)) for k, v in zip(keys, vals)]
elif not dataclasses.is_dataclass(typeannot):
return dlist
# dataclass node: 2nd recursion base; call the function recursively on the lists
# of the corresponding fields
assert dataclasses.is_dataclass(cls)
fieldtypes = {
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
for f in dataclasses.fields(typeannot)
}
# NOTE the default object is shared here
key_lists = (
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
for k, (type_, default) in fieldtypes.items()
)
transposed = zip(*key_lists)
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
def _dataclass_from_dict(d, typeannot):
cls = get_origin(typeannot) or typeannot
if d is None:
return d
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
types = cls._field_types.values()
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
elif issubclass(cls, (list, tuple)):
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(d)
return cls(_dataclass_from_dict(v, tp) for v, tp in zip(d, types))
elif issubclass(cls, dict):
key_t, val_t = get_args(typeannot)
return cls(
(_dataclass_from_dict(k, key_t), _dataclass_from_dict(v, val_t))
for k, v in d.items()
)
elif not dataclasses.is_dataclass(typeannot):
return d
assert dataclasses.is_dataclass(cls)
fieldtypes = {f.name: _unwrap_type(f.type) for f in dataclasses.fields(typeannot)}
return cls(**{k: _dataclass_from_dict(v, fieldtypes[k]) for k, v in d.items()})
def _unwrap_type(tp):
# strips Optional wrapper, if any
if get_origin(tp) is Union:
args = get_args(tp)
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
# this is typing.Optional
return args[0] if args[1] is type(None) else args[1] # noqa: E721
return tp
def _get_dataclass_field_default(field: Field) -> Any:
if field.default_factory is not MISSING:
return field.default_factory()
elif field.default is not MISSING:
return field.default
else:
return None
def _asdict_rec(obj):
return dataclasses._asdict_inner(obj, dict)
def dump_dataclass_jgzip(outfile: str, obj: Any) -> None:
"""
Dumps obj to a gzipped json outfile.
Args:
obj: A @dataclass or collection hierarchy including dataclasses.
outfile: The path to the output file.
"""
with gzip.GzipFile(outfile, "wb") as f:
dump_dataclass(obj, cast(IO, f), binary=True)
def load_dataclass_jgzip(outfile, cls):
"""
Loads a dataclass from a gzipped json outfile.
Args:
outfile: The path to the loaded file.
cls: The type annotation of the loaded dataclass.
Returns:
loaded_dataclass: The loaded dataclass.
"""
with gzip.GzipFile(outfile, "rb") as f:
return load_dataclass(cast(IO, f), cls, binary=True)
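These helpers are what the dataset code uses to parse the frame_annotations.jgz and sequence_annotations.jgz files referenced by dataset_zoo above. A minimal sketch, with a placeholder path:

# minimal sketch: parse CO3D-style frame annotations into typed dataclasses
from typing import List

frame_annots = load_dataclass_jgzip(
    "/path/to/co3d/apple/frame_annotations.jgz",  # placeholder
    List[FrameAnnotation],
)
print(frame_annots[0].sequence_name, frame_annots[0].image.path)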
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import torch
DATASET_TYPE_TRAIN = "train"
DATASET_TYPE_TEST = "test"
DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"
def is_known_frame(
frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
"""
Given a list `frame_type` of frame types in a batch, return a tensor
of boolean flags expressing whether the corresponding frame is a known frame.
"""
return torch.tensor(
[ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
dtype=torch.bool,
device=device,
)
def is_train_frame(
frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
"""
Given a list `frame_type` of frame types in a batch, return a tensor
of boolean flags expressing whether the corresponding frame is a training frame.
"""
return torch.tensor(
[ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type],
dtype=torch.bool,
device=device,
)
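A small usage sketch, mirroring _get_all_source_cameras in eval_demo.py later in this commit (frame_data is assumed to be a collated FrameData batch):

# minimal sketch: keep only the cameras of the known (source) frames of a batch
is_known = is_known_frame(frame_data.frame_type)
source_cameras = frame_data.camera[torch.where(is_known)[0]]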
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, cast
import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.structures import Pointclouds
from .implicitron_dataset import FrameData, ImplicitronDataset
def get_implicitron_sequence_pointcloud(
dataset: ImplicitronDataset,
sequence_name: Optional[str] = None,
mask_points: bool = True,
max_frames: int = -1,
num_workers: int = 0,
load_dataset_point_cloud: bool = False,
) -> Tuple[Pointclouds, FrameData]:
"""
Make a point cloud by sampling random points from each frame of the dataset.
"""
if len(dataset) == 0:
raise ValueError("The dataset is empty.")
if not dataset.load_depths:
raise ValueError("The dataset has to load depths (dataset.load_depths=True).")
if mask_points and not dataset.load_masks:
raise ValueError(
"For mask_points=True, the dataset has to load masks"
+ " (dataset.load_masks=True)."
)
# setup the indices of frames loaded from the dataset db
sequence_entries = list(range(len(dataset)))
if sequence_name is not None:
sequence_entries = [
ei
for ei in sequence_entries
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
== sequence_name
]
if len(sequence_entries) == 0:
raise ValueError(
f'There are no dataset entries for sequence name "{sequence_name}".'
)
# subsample loaded frames if needed
if (max_frames > 0) and (len(sequence_entries) > max_frames):
sequence_entries = [
sequence_entries[i]
for i in torch.randperm(len(sequence_entries))[:max_frames].sort().values
]
# take only the part of the dataset corresponding to the sequence entries
sequence_dataset = torch.utils.data.Subset(dataset, sequence_entries)
# load the required part of the dataset
loader = torch.utils.data.DataLoader(
sequence_dataset,
batch_size=len(sequence_dataset),
shuffle=False,
num_workers=num_workers,
collate_fn=FrameData.collate,
)
frame_data = next(iter(loader)) # there's only one batch
# scene point cloud
if load_dataset_point_cloud:
if not dataset.load_point_clouds:
raise ValueError(
"For load_dataset_point_cloud=True, the dataset has to"
+ " load point clouds (dataset.load_point_clouds=True)."
)
point_cloud = frame_data.sequence_point_cloud
else:
point_cloud = get_rgbd_point_cloud(
frame_data.camera,
frame_data.image_rgb,
frame_data.depth_map,
(cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
if frame_data.fg_probability is not None
else None,
mask_points=mask_points,
)
return point_cloud, frame_data
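A minimal usage sketch; the sequence name below is a placeholder and the dataset is assumed to have been built with load_depths=True:

# minimal sketch: fuse the depth maps of one sequence into a colored point cloud
point_cloud, frame_data = get_implicitron_sequence_pointcloud(
    dataset,  # assumed ImplicitronDataset with depths and masks loaded
    sequence_name="apple_sequence_0",  # placeholder
    mask_points=True,
    max_frames=50,
    num_workers=8,
)
print(point_cloud.points_padded().shape)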
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import os
from typing import Optional, cast
import lpips
import torch
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
ImplicitronDatasetBase,
)
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
aggregate_nvs_results,
eval_batch,
pretty_print_nvs_metrics,
summarize_nvs_eval_results,
)
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
from tqdm import tqdm
def main() -> None:
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for multisequence/singlesequence tasks for several categories.
The evaluation is conducted on the same data as in [1] and, hence, the results
are directly comparable to the numbers reported in [1].
References:
[1] J. Reizenstein, R. Shapovalov, P. Henzler, L. Sbordone,
P. Labatut, D. Novotny:
Common Objects in 3D: Large-Scale Learning
and Evaluation of Real-life 3D Category Reconstruction
"""
task_results = {}
for task in ("singlesequence", "multisequence"):
task_results[task] = []
for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
category_result = evaluate_dbir_for_category(
category, task=task, single_sequence_id=single_sequence_id
)
print("")
print(
f"Results for task={task}; category={category};"
+ (
f" sequence={single_sequence_id}:"
if single_sequence_id is not None
else ":"
)
)
pretty_print_nvs_metrics(category_result)
print("")
task_results[task].append(category_result)
_print_aggregate_results(task, task_results)
for task in task_results:
_print_aggregate_results(task, task_results)
def evaluate_dbir_for_category(
category: str = "apple",
bg_color: float = 0.0,
task: str = "singlesequence",
single_sequence_id: Optional[int] = None,
num_workers: int = 16,
):
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for a given task, category, and sequence (in case task=='singlesequence').
Args:
category: Object category.
bg_color: Background color of the renders.
task: Evaluation task. Either singlesequence or multisequence.
single_sequence_id: The ID of the evaluation sequence for the singlesequence task.
num_workers: The number of workers for the employed dataloaders.
Returns:
category_result: A dictionary of quantitative metrics.
"""
single_sequence_id = single_sequence_id if single_sequence_id is not None else -1
torch.manual_seed(42)
if task not in ["multisequence", "singlesequence"]:
raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
datasets = dataset_zoo(
category=category,
dataset_root=os.environ["CO3D_DATASET_ROOT"],
assert_single_seq=task == "singlesequence",
dataset_name=f"co3d_{task}",
test_on_train=False,
load_point_clouds=True,
test_restrict_sequence_id=single_sequence_id,
)
dataloaders = dataloader_zoo(
datasets,
dataset_name=f"co3d_{task}",
)
test_dataset = datasets["test"]
test_dataloader = dataloaders["test"]
if task == "singlesequence":
# all_source_cameras are needed for evaluation of the
# target camera difficulty
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
all_source_cameras = _get_all_source_cameras(
test_dataset, sequence_name, num_workers=num_workers
)
else:
all_source_cameras = None
image_size = cast(ImplicitronDataset, test_dataset).image_width
if image_size is None:
raise ValueError("Image size should be set in the dataset")
# init the simple DBIR model
model = ModelDBIR(
image_size=image_size,
bg_color=bg_color,
max_points=int(1e5),
)
model.cuda()
# init the lpips model for eval
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.cuda()
per_batch_eval_results = []
print("Evaluating DBIR model ...")
for frame_data in tqdm(test_dataloader):
frame_data = dataclass_to_cuda_(frame_data)
preds = model(**dataclasses.asdict(frame_data))
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
per_batch_eval_results.append(
eval_batch(
frame_data,
nvs_prediction,
bg_color=bg_color,
lpips_model=lpips_model,
source_cameras=all_source_cameras,
)
)
category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task
)
return category_result["results"]
def _print_aggregate_results(task, task_results) -> None:
"""
Prints the aggregate metrics for a given task.
"""
aggregate_task_result = aggregate_nvs_results(task_results[task])
print("")
print(f"Aggregate results for task={task}:")
pretty_print_nvs_metrics(aggregate_task_result)
print("")
def _get_all_source_cameras(
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
):
"""
Loads all training cameras of a given sequence.
The set of all seen cameras is needed for evaluating the viewpoint difficulty
for the singlescene evaluation.
Args:
dataset: Co3D dataset object.
sequence_name: The name of the sequence.
num_workers: The number of workers for the utilized dataloader.
"""
# load all source cameras of the sequence
seq_idx = dataset.seq_to_idx[sequence_name]
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
(all_frame_data,) = torch.utils.data.DataLoader(
dataset_for_loader,
shuffle=False,
batch_size=len(dataset_for_loader),
num_workers=num_workers,
collate_fn=FrameData.collate,
)
is_known = is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
return source_cameras
if __name__ == "__main__":
main()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle
class ImplicitFunctionBase(ABC, ReplaceableBase):
def __init__(self):
super().__init__()
@abstractmethod
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
raise NotImplementedError()
@staticmethod
def allows_multiple_passes() -> bool:
"""
Returns True if this implicit function allows
multiple passes.
"""
return False
@staticmethod
def requires_pooling_without_aggregation() -> bool:
"""
Returns True if this implicit function needs
pooling without aggregation.
"""
return False
def on_bind_args(self) -> None:
"""
Called when the custom args are fixed in the main model forward pass.
"""
pass
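A minimal sketch of a concrete subclass; the constant-output behaviour and the torch.nn.Module mix-in are illustrative assumptions, and real implementations such as SRNImplicitFunction are registered with the implicitron config system instead:

# minimal sketch: a trivial implicit function returning constant densities and colors
import torch


class ConstantImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        # one density and one RGB color per ray point
        shape = ray_bundle.lengths.shape
        device = ray_bundle.lengths.device
        rays_densities = torch.ones(*shape, 1, device=device)
        rays_colors = 0.5 * torch.ones(*shape, 3, device=device)
        return rays_densities, rays_colors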