Commit 813e941d authored by Emilien Garreau's avatar Emilien Garreau Committed by Facebook GitHub Bot
Browse files

Add the OverfitModel

Summary:
Introduces the OverfitModel for NeRF-style training with overfitting to one scene.
It is a specific case of GenericModel. It has been disentangle to ease usage.

## General modification

1. Modularize a minimum GenericModel to introduce OverfitModel
2. Introduce OverfitModel and ensure through unit testing that it behaves like GenericModel.

## Modularization

The following methods have been extracted from GenericModel to allow modularity with ManyViewModel:
- get_objective is now a call to weighted_sum_losses
- log_loss_weights
- prepare_inputs

The generic methods have been moved to an utils.py file.

Simplify the code to introduce OverfitModel.

Private methods like chunk_generator are now public and can now be used by ManyViewModel.

Reviewed By: shapovalov

Differential Revision: D43771992

fbshipit-source-id: 6102aeb21c7fdd56aa2ff9cd1dd23fd9fbf26315
parent 7d8b029a
...@@ -248,7 +248,7 @@ The main object for this trainer loop is `Experiment`. It has four top-level rep ...@@ -248,7 +248,7 @@ The main object for this trainer loop is `Experiment`. It has four top-level rep
* `data_source`: This is a `DataSourceBase` which defaults to `ImplicitronDataSource`. * `data_source`: This is a `DataSourceBase` which defaults to `ImplicitronDataSource`.
It constructs the data sets and dataloaders. It constructs the data sets and dataloaders.
* `model_factory`: This is a `ModelFactoryBase` which defaults to `ImplicitronModelFactory`. * `model_factory`: This is a `ModelFactoryBase` which defaults to `ImplicitronModelFactory`.
It constructs the model, which is usually an instance of implicitron's main `GenericModel` class, and can load its weights from a checkpoint. It constructs the model, which is usually an instance of `OverfitModel` (for NeRF-style training with overfitting to one scene) or `GenericModel` (that is able to generalize to multiple scenes by NeRFormer-style conditioning on other scene views), and can load its weights from a checkpoint.
* `optimizer_factory`: This is an `OptimizerFactoryBase` which defaults to `ImplicitronOptimizerFactory`. * `optimizer_factory`: This is an `OptimizerFactoryBase` which defaults to `ImplicitronOptimizerFactory`.
It constructs the optimizer and can load its weights from a checkpoint. It constructs the optimizer and can load its weights from a checkpoint.
* `training_loop`: This is a `TrainingLoopBase` which defaults to `ImplicitronTrainingLoop` and defines the main training loop. * `training_loop`: This is a `TrainingLoopBase` which defaults to `ImplicitronTrainingLoop` and defines the main training loop.
...@@ -292,6 +292,43 @@ model_GenericModel_args: GenericModel ...@@ -292,6 +292,43 @@ model_GenericModel_args: GenericModel
╘== ReductionFeatureAggregator ╘== ReductionFeatureAggregator
``` ```
Here is the class structure of OverfitModel:
```
model_OverfitModel_args: OverfitModel
└-- raysampler_*_args: RaySampler
╘== AdaptiveRaysampler
╘== NearFarRaysampler
└-- renderer_*_args: BaseRenderer
╘== MultiPassEmissionAbsorptionRenderer
╘== LSTMRenderer
╘== SignedDistanceFunctionRenderer
└-- ray_tracer_args: RayTracing
└-- ray_normal_coloring_network_args: RayNormalColoringNetwork
└-- implicit_function_*_args: ImplicitFunctionBase
╘== NeuralRadianceFieldImplicitFunction
╘== SRNImplicitFunction
└-- raymarch_function_args: SRNRaymarchFunction
└-- pixel_generator_args: SRNPixelGenerator
╘== SRNHyperNetImplicitFunction
└-- hypernet_args: SRNRaymarchHyperNet
└-- pixel_generator_args: SRNPixelGenerator
╘== IdrFeatureField
└-- coarse_implicit_function_*_args: ImplicitFunctionBase
╘== NeuralRadianceFieldImplicitFunction
╘== SRNImplicitFunction
└-- raymarch_function_args: SRNRaymarchFunction
└-- pixel_generator_args: SRNPixelGenerator
╘== SRNHyperNetImplicitFunction
└-- hypernet_args: SRNRaymarchHyperNet
└-- pixel_generator_args: SRNPixelGenerator
╘== IdrFeatureField
```
OverfitModel has been introduced to create a simple class to disantagle Nerfs which the overfit pattern
from the GenericModel.
Please look at the annotations of the respective classes or functions for the lists of hyperparameters. Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
`tests/experiment.yaml` shows every possible option if you have no user-defined classes. `tests/experiment.yaml` shows every possible option if you have no user-defined classes.
......
defaults:
- default_config
- _self_
exp_dir: ./data/exps/overfit_base/
training_loop_ImplicitronTrainingLoop_args:
visdom_port: 8097
visualize_interval: 0
max_epochs: 1000
data_source_ImplicitronDataSource_args:
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
dataset_length_train: 1000
dataset_length_val: 1
num_workers: 8
dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
dataset_JsonIndexDataset_args:
load_point_clouds: false
mask_depths: false
mask_images: false
model_factory_ImplicitronModelFactory_args:
model_class_type: "OverfitModel"
model_OverfitModel_args:
loss_weights:
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 1.0
loss_autodecoder_norm: 0.01
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
output_rasterized_mc: false
chunk_size_grid: 102400
render_image_height: 400
render_image_width: 400
share_implicit_function_across_passes: false
implicit_function_class_type: "NeuralRadianceFieldImplicitFunction"
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_xyz: 256
n_hidden_neurons_dir: 128
n_layers_xyz: 8
append_xyz:
- 5
coarse_implicit_function_class_type: "NeuralRadianceFieldImplicitFunction"
coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_xyz: 256
n_hidden_neurons_dir: 128
n_layers_xyz: 8
append_xyz:
- 5
raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 1024
scene_extent: 8.0
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 64
n_pts_per_ray_fine_evaluation: 64
append_coarse_samples_to_fine: true
density_noise_std_train: 1.0
optimizer_factory_ImplicitronOptimizerFactory_args:
breed: Adam
weight_decay: 0.0
lr_policy: MultiStepLR
multistep_lr_milestones: []
lr: 0.0005
gamma: 0.1
momentum: 0.9
betas:
- 0.9
- 0.999
defaults:
- overfit_base
- _self_
data_source_ImplicitronDataSource_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
dataset_length_train: 1000
dataset_length_val: 1
num_workers: 8
dataset_map_provider_JsonIndexDatasetMapProvider_args:
assert_single_seq: true
n_frames_per_sequence: -1
test_restrict_sequence_id: 0
test_on_train: false
model_factory_ImplicitronModelFactory_args:
model_class_type: "OverfitModel"
model_OverfitModel_args:
render_image_height: 800
render_image_width: 800
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_eikonal
- loss_prev_stage_rgb_psnr
- loss_mask_bce
- loss_prev_stage_mask_bce
- loss_rgb_mse
- loss_prev_stage_rgb_mse
- loss_depth_abs
- loss_depth_abs_fg
- loss_kl
- loss_mask_neg_iou
- objective
- epoch
- sec/it
optimizer_factory_ImplicitronOptimizerFactory_args:
lr: 0.0005
multistep_lr_milestones:
- 200
- 300
training_loop_ImplicitronTrainingLoop_args:
max_epochs: 400
defaults:
- overfit_singleseq_base
- _self_
exp_dir: "./data/overfit_nerf_blender_repro/${oc.env:BLENDER_SINGLESEQ_CLASS}"
data_source_ImplicitronDataSource_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
dataset_length_train: 100
dataset_map_provider_class_type: BlenderDatasetMapProvider
dataset_map_provider_BlenderDatasetMapProvider_args:
base_dir: ${oc.env:BLENDER_DATASET_ROOT}/${oc.env:BLENDER_SINGLESEQ_CLASS}
n_known_frames_for_test: null
object_name: ${oc.env:BLENDER_SINGLESEQ_CLASS}
path_manager_factory_class_type: PathManagerFactory
path_manager_factory_PathManagerFactory_args:
silence_logs: true
model_factory_ImplicitronModelFactory_args:
model_class_type: "OverfitModel"
model_OverfitModel_args:
mask_images: false
raysampler_class_type: AdaptiveRaySampler
raysampler_AdaptiveRaySampler_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 4096
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 2.0
scene_center:
- 0.0
- 0.0
- 0.0
renderer_MultiPassEmissionAbsorptionRenderer_args:
density_noise_std_train: 0.0
n_pts_per_ray_fine_training: 128
n_pts_per_ray_fine_evaluation: 128
raymarcher_EmissionAbsorptionRaymarcher_args:
blend_output: false
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.00
optimizer_factory_ImplicitronOptimizerFactory_args:
exponential_lr_step_size: 3001
lr_policy: LinearExponential
linear_exponential_lr_milestone: 200
training_loop_ImplicitronTrainingLoop_args:
max_epochs: 6000
metric_print_interval: 10
store_checkpoints_purge: 3
test_when_finished: true
validation_interval: 100
...@@ -59,7 +59,7 @@ from pytorch3d.implicitron.dataset.data_source import ( ...@@ -59,7 +59,7 @@ from pytorch3d.implicitron.dataset.data_source import (
DataSourceBase, DataSourceBase,
ImplicitronDataSource, ImplicitronDataSource,
) )
from pytorch3d.implicitron.models.generic_model import ImplicitronModelBase from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.renderer.multipass_ea import ( from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer, MultiPassEmissionAbsorptionRenderer,
......
...@@ -561,6 +561,623 @@ model_factory_ImplicitronModelFactory_args: ...@@ -561,6 +561,623 @@ model_factory_ImplicitronModelFactory_args:
use_xavier_init: true use_xavier_init: true
view_metrics_ViewMetrics_args: {} view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {} regularization_metrics_RegularizationMetrics_args: {}
model_OverfitModel_args:
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
mask_images: true
mask_depths: true
render_image_width: 400
render_image_height: 400
mask_threshold: 0.5
output_rasterized_mc: false
bg_color:
- 0.0
- 0.0
- 0.0
chunk_size_grid: 4096
render_features_dimensions: 3
tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
global_encoder_class_type: null
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: MultiPassEmissionAbsorptionRenderer
share_implicit_function_across_passes: false
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
coarse_implicit_function_class_type: null
view_metrics_class_type: ViewMetrics
regularization_metrics_class_type: RegularizationMetrics
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
global_encoder_HarmonicTimeEncoder_args:
n_harmonic_functions: 10
append_input: true
time_divisor: 1.0
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
n_instances: 1
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0
scene_center:
- 0.0
- 0.0
- 0.0
raysampler_NearFarRaySampler_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
min_depth: 0.1
max_depth: 8.0
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
init_depth: 17.0
init_depth_noise_std: 0.0005
hidden_size: 16
n_feature_channels: 256
bg_color: null
verbose: false
renderer_MultiPassEmissionAbsorptionRenderer_args:
raymarcher_class_type: EmissionAbsorptionRaymarcher
n_pts_per_ray_fine_training: 64
n_pts_per_ray_fine_evaluation: 64
stratified_sampling_coarse_training: true
stratified_sampling_coarse_evaluation: false
append_coarse_samples_to_fine: true
density_noise_std_train: 0.0
return_weights: false
raymarcher_CumsumRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
replicate_last_interval: false
background_opacity: 0.0
density_relu: true
blend_output: false
raymarcher_EmissionAbsorptionRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
replicate_last_interval: false
background_opacity: 10000000000.0
density_relu: true
blend_output: false
renderer_SignedDistanceFunctionRenderer_args:
ray_normal_coloring_network_args:
feature_vector_size: 3
mode: idr
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
weight_norm: true
n_harmonic_functions_dir: 0
pooled_feature_dim: 0
bg_color:
- 0.0
soft_mask_alpha: 50.0
ray_tracer_args:
sdf_threshold: 5.0e-05
line_search_step: 0.5
line_step_iters: 1
sphere_tracing_iters: 10
n_steps: 100
n_secant_steps: 8
implicit_function_IdrFeatureField_args:
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
implicit_function_NeRFormerImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
append_xyz:
- 1
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
append_xyz:
- 5
implicit_function_SRNHyperNetImplicitFunction_args:
latent_dim_hypernet: 0
hypernet_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
n_hidden_units_hypernet: 256
n_layers_hypernet: 1
in_features: 3
out_features: 256
xyz_in_camera_coords: false
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
implicit_function_SRNImplicitFunction_args:
raymarch_function_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
in_features: 3
out_features: 256
xyz_in_camera_coords: false
raymarch_function: null
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
implicit_function_VoxelGridImplicitFunction_args:
harmonic_embedder_xyz_density_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
harmonic_embedder_xyz_color_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
harmonic_embedder_dir_color_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
decoder_density_class_type: MLPDecoder
decoder_color_class_type: MLPDecoder
use_multiple_streams: true
xyz_ray_dir_in_camera_coords: false
scaffold_calculating_epochs: []
scaffold_resolution:
- 128
- 128
- 128
scaffold_empty_space_threshold: 0.001
scaffold_occupancy_chunk_size: -1
scaffold_max_pool_kernel_size: 3
scaffold_filter_points: true
volume_cropping_epochs: []
voxel_grid_density_args:
voxel_grid_class_type: FullResolutionVoxelGrid
extents:
- 2.0
- 2.0
- 2.0
translation:
- 0.0
- 0.0
- 0.0
init_std: 0.1
init_mean: 0.0
hold_voxel_grid_as_parameters: true
param_groups: {}
voxel_grid_CPFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: 24
basis_matrix: true
voxel_grid_FullResolutionVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
voxel_grid_VMFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: null
distribution_of_components: null
basis_matrix: true
voxel_grid_color_args:
voxel_grid_class_type: FullResolutionVoxelGrid
extents:
- 2.0
- 2.0
- 2.0
translation:
- 0.0
- 0.0
- 0.0
init_std: 0.1
init_mean: 0.0
hold_voxel_grid_as_parameters: true
param_groups: {}
voxel_grid_CPFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: 24
basis_matrix: true
voxel_grid_FullResolutionVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
voxel_grid_VMFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: null
distribution_of_components: null
basis_matrix: true
decoder_density_ElementwiseDecoder_args:
scale: 1.0
shift: 0.0
operation: IDENTITY
decoder_density_MLPDecoder_args:
param_groups: {}
network_args:
n_layers: 8
output_dim: 256
skip_dim: 39
hidden_dim: 256
input_skips:
- 5
skip_affine_trans: false
last_layer_bias_init: null
last_activation: RELU
use_xavier_init: true
decoder_color_ElementwiseDecoder_args:
scale: 1.0
shift: 0.0
operation: IDENTITY
decoder_color_MLPDecoder_args:
param_groups: {}
network_args:
n_layers: 8
output_dim: 256
skip_dim: 39
hidden_dim: 256
input_skips:
- 5
skip_affine_trans: false
last_layer_bias_init: null
last_activation: RELU
use_xavier_init: true
coarse_implicit_function_IdrFeatureField_args:
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
coarse_implicit_function_NeRFormerImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
append_xyz:
- 1
coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
append_xyz:
- 5
coarse_implicit_function_SRNHyperNetImplicitFunction_args:
latent_dim_hypernet: 0
hypernet_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
n_hidden_units_hypernet: 256
n_layers_hypernet: 1
in_features: 3
out_features: 256
xyz_in_camera_coords: false
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
coarse_implicit_function_SRNImplicitFunction_args:
raymarch_function_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
in_features: 3
out_features: 256
xyz_in_camera_coords: false
raymarch_function: null
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
coarse_implicit_function_VoxelGridImplicitFunction_args:
harmonic_embedder_xyz_density_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
harmonic_embedder_xyz_color_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
harmonic_embedder_dir_color_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
decoder_density_class_type: MLPDecoder
decoder_color_class_type: MLPDecoder
use_multiple_streams: true
xyz_ray_dir_in_camera_coords: false
scaffold_calculating_epochs: []
scaffold_resolution:
- 128
- 128
- 128
scaffold_empty_space_threshold: 0.001
scaffold_occupancy_chunk_size: -1
scaffold_max_pool_kernel_size: 3
scaffold_filter_points: true
volume_cropping_epochs: []
voxel_grid_density_args:
voxel_grid_class_type: FullResolutionVoxelGrid
extents:
- 2.0
- 2.0
- 2.0
translation:
- 0.0
- 0.0
- 0.0
init_std: 0.1
init_mean: 0.0
hold_voxel_grid_as_parameters: true
param_groups: {}
voxel_grid_CPFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: 24
basis_matrix: true
voxel_grid_FullResolutionVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
voxel_grid_VMFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: null
distribution_of_components: null
basis_matrix: true
voxel_grid_color_args:
voxel_grid_class_type: FullResolutionVoxelGrid
extents:
- 2.0
- 2.0
- 2.0
translation:
- 0.0
- 0.0
- 0.0
init_std: 0.1
init_mean: 0.0
hold_voxel_grid_as_parameters: true
param_groups: {}
voxel_grid_CPFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: 24
basis_matrix: true
voxel_grid_FullResolutionVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
voxel_grid_VMFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: null
distribution_of_components: null
basis_matrix: true
decoder_density_ElementwiseDecoder_args:
scale: 1.0
shift: 0.0
operation: IDENTITY
decoder_density_MLPDecoder_args:
param_groups: {}
network_args:
n_layers: 8
output_dim: 256
skip_dim: 39
hidden_dim: 256
input_skips:
- 5
skip_affine_trans: false
last_layer_bias_init: null
last_activation: RELU
use_xavier_init: true
decoder_color_ElementwiseDecoder_args:
scale: 1.0
shift: 0.0
operation: IDENTITY
decoder_color_MLPDecoder_args:
param_groups: {}
network_args:
n_layers: 8
output_dim: 256
skip_dim: 39
hidden_dim: 256
input_skips:
- 5
skip_affine_trans: false
last_layer_bias_init: null
last_activation: RELU
use_xavier_init: true
view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {}
optimizer_factory_ImplicitronOptimizerFactory_args: optimizer_factory_ImplicitronOptimizerFactory_args:
betas: betas:
- 0.9 - 0.9
......
...@@ -141,7 +141,11 @@ class TestExperiment(unittest.TestCase): ...@@ -141,7 +141,11 @@ class TestExperiment(unittest.TestCase):
# Check that all the pre-prepared configs are valid. # Check that all the pre-prepared configs are valid.
config_files = [] config_files = []
for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"): for pattern in (
"repro_singleseq*.yaml",
"repro_multiseq*.yaml",
"overfit_singleseq*.yaml",
):
config_files.extend( config_files.extend(
[ [
f f
......
...@@ -3,3 +3,8 @@ ...@@ -3,3 +3,8 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# Allows to register the models
# see: pytorch3d.implicitron.tools.config.registry:register
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.overfit_model import OverfitModel
...@@ -8,11 +8,11 @@ from dataclasses import dataclass, field ...@@ -8,11 +8,11 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import ReplaceableBase from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from .renderer.base import EvaluationMode
@dataclass @dataclass
class ImplicitronRender: class ImplicitronRender:
......
...@@ -9,14 +9,11 @@ ...@@ -9,14 +9,11 @@
# which are part of implicitron. They ensure that the registry is prepopulated. # which are part of implicitron. They ensure that the registry is prepopulated.
import logging import logging
import warnings
from dataclasses import field from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch import torch
import tqdm
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.base_model import ( from pytorch3d.implicitron.models.base_model import (
ImplicitronModelBase, ImplicitronModelBase,
...@@ -33,11 +30,9 @@ from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( ...@@ -33,11 +30,9 @@ from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
) )
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa
NeRFormerImplicitFunction, NeRFormerImplicitFunction,
NeuralRadianceFieldImplicitFunction,
) )
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa
SRNHyperNetImplicitFunction, SRNHyperNetImplicitFunction,
SRNImplicitFunction,
) )
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa
VoxelGridImplicitFunction, VoxelGridImplicitFunction,
...@@ -63,8 +58,16 @@ from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase ...@@ -63,8 +58,16 @@ from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa
SignedDistanceFunctionRenderer, SignedDistanceFunctionRenderer,
) )
from pytorch3d.implicitron.models.utils import (
apply_chunked,
chunk_generator,
log_loss_weights,
preprocess_input,
weighted_sum_losses,
)
from pytorch3d.implicitron.models.view_pooler.view_pooler import ViewPooler from pytorch3d.implicitron.models.view_pooler.view_pooler import ViewPooler
from pytorch3d.implicitron.tools import image_utils, vis_utils from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
expand_args_fields, expand_args_fields,
registry, registry,
...@@ -72,7 +75,6 @@ from pytorch3d.implicitron.tools.config import ( ...@@ -72,7 +75,6 @@ from pytorch3d.implicitron.tools.config import (
) )
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import utils as rend_utils from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
...@@ -323,7 +325,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -323,7 +325,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
self._implicit_functions = self._construct_implicit_functions() self._implicit_functions = self._construct_implicit_functions()
self.log_loss_weights() log_loss_weights(self.loss_weights, logger)
def forward( def forward(
self, self,
...@@ -367,8 +369,14 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -367,8 +369,14 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
preds: A dictionary containing all outputs of the forward pass including the preds: A dictionary containing all outputs of the forward pass including the
rendered images, depths, masks, losses and other metrics. rendered images, depths, masks, losses and other metrics.
""" """
image_rgb, fg_probability, depth_map = self._preprocess_input( image_rgb, fg_probability, depth_map = preprocess_input(
image_rgb, fg_probability, depth_map image_rgb,
fg_probability,
depth_map,
self.mask_images,
self.mask_depths,
self.mask_threshold,
self.bg_color,
) )
# Obtain the batch size from the camera as this is the only required input. # Obtain the batch size from the camera as this is the only required input.
...@@ -453,12 +461,12 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -453,12 +461,12 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
for func in self._implicit_functions: for func in self._implicit_functions:
func.bind_args(**custom_args) func.bind_args(**custom_args)
chunked_renderer_inputs = {} inputs_to_be_chunked = {}
if fg_probability is not None and self.renderer.requires_object_mask(): if fg_probability is not None and self.renderer.requires_object_mask():
sampled_fb_prob = rend_utils.ndc_grid_sample( sampled_fb_prob = rend_utils.ndc_grid_sample(
fg_probability[:n_targets], ray_bundle.xys, mode="nearest" fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
) )
chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5 inputs_to_be_chunked["object_mask"] = sampled_fb_prob > 0.5
# (5)-(6) Implicit function evaluation and Rendering # (5)-(6) Implicit function evaluation and Rendering
rendered = self._render( rendered = self._render(
...@@ -466,7 +474,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -466,7 +474,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
sampling_mode=sampling_mode, sampling_mode=sampling_mode,
evaluation_mode=evaluation_mode, evaluation_mode=evaluation_mode,
implicit_functions=self._implicit_functions, implicit_functions=self._implicit_functions,
chunked_inputs=chunked_renderer_inputs, inputs_to_be_chunked=inputs_to_be_chunked,
) )
# Unbind the custom arguments to prevent pytorch from storing # Unbind the custom arguments to prevent pytorch from storing
...@@ -530,30 +538,18 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -530,30 +538,18 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
raise AssertionError("Unreachable state") raise AssertionError("Unreachable state")
# (7) Compute losses # (7) Compute losses
# finally get the optimization objective using self.loss_weights
objective = self._get_objective(preds) objective = self._get_objective(preds)
if objective is not None: if objective is not None:
preds["objective"] = objective preds["objective"] = objective
return preds return preds
def _get_objective(self, preds) -> Optional[torch.Tensor]: def _get_objective(self, preds: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]:
""" """
A helper function to compute the overall loss as the dot product A helper function to compute the overall loss as the dot product
of individual loss functions with the corresponding weights. of individual loss functions with the corresponding weights.
""" """
losses_weighted = [ return weighted_sum_losses(preds, self.loss_weights)
preds[k] * float(w)
for k, w in self.loss_weights.items()
if (k in preds and w != 0.0)
]
if len(losses_weighted) == 0:
warnings.warn("No main objective found.")
return None
loss = sum(losses_weighted)
assert torch.is_tensor(loss)
# pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
return loss
def visualize( def visualize(
self, self,
...@@ -585,7 +581,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -585,7 +581,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
self, self,
*, *,
ray_bundle: ImplicitronRayBundle, ray_bundle: ImplicitronRayBundle,
chunked_inputs: Dict[str, torch.Tensor], inputs_to_be_chunked: Dict[str, torch.Tensor],
sampling_mode: RenderSamplingMode, sampling_mode: RenderSamplingMode,
**kwargs, **kwargs,
) -> RendererOutput: ) -> RendererOutput:
...@@ -593,7 +589,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -593,7 +589,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
Args: Args:
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays. sampled rendering rays.
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g. inputs_to_be_chunked: A collection of tensor of shape `(B, _, H, W)`. E.g.
SignedDistanceFunctionRenderer requires "object_mask", shape SignedDistanceFunctionRenderer requires "object_mask", shape
(B, 1, H, W), the silhouette of the object in the image. When (B, 1, H, W), the silhouette of the object in the image. When
chunking, they are passed to the renderer as shape chunking, they are passed to the renderer as shape
...@@ -605,30 +601,27 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -605,30 +601,27 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
An instance of RendererOutput An instance of RendererOutput
""" """
if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0: if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
return _apply_chunked( return apply_chunked(
self.renderer, self.renderer,
_chunk_generator( chunk_generator(
self.chunk_size_grid, self.chunk_size_grid,
ray_bundle, ray_bundle,
chunked_inputs, inputs_to_be_chunked,
self.tqdm_trigger_threshold, self.tqdm_trigger_threshold,
**kwargs, **kwargs,
), ),
lambda batch: _tensor_collator(batch, ray_bundle.lengths.shape[:-1]), lambda batch: torch.cat(batch, dim=1).reshape(
*ray_bundle.lengths.shape[:-1], -1
),
) )
else: else:
# pyre-fixme[29]: `BaseRenderer` is not a function. # pyre-fixme[29]: `BaseRenderer` is not a function.
return self.renderer( return self.renderer(
ray_bundle=ray_bundle, ray_bundle=ray_bundle,
**chunked_inputs, **inputs_to_be_chunked,
**kwargs, **kwargs,
) )
def _get_global_encoder_encoding_dim(self) -> int:
if self.global_encoder is None:
return 0
return self.global_encoder.get_encoding_dim()
def _get_viewpooled_feature_dim(self) -> int: def _get_viewpooled_feature_dim(self) -> int:
if self.view_pooler is None: if self.view_pooler is None:
return 0 return 0
...@@ -720,30 +713,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -720,30 +713,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
function(s) are initialized. function(s) are initialized.
""" """
extra_args = {} extra_args = {}
global_encoder_dim = (
0 if self.global_encoder is None else self.global_encoder.get_encoding_dim()
)
viewpooled_feature_dim = self._get_viewpooled_feature_dim()
if self.implicit_function_class_type in ( if self.implicit_function_class_type in (
"NeuralRadianceFieldImplicitFunction", "NeuralRadianceFieldImplicitFunction",
"NeRFormerImplicitFunction", "NeRFormerImplicitFunction",
): ):
extra_args["latent_dim"] = ( extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim
self._get_viewpooled_feature_dim()
+ self._get_global_encoder_encoding_dim()
)
extra_args["color_dim"] = self.render_features_dimensions extra_args["color_dim"] = self.render_features_dimensions
if self.implicit_function_class_type == "IdrFeatureField": if self.implicit_function_class_type == "IdrFeatureField":
extra_args["feature_vector_size"] = self.render_features_dimensions extra_args["feature_vector_size"] = self.render_features_dimensions
extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim() extra_args["encoding_dim"] = global_encoder_dim
if self.implicit_function_class_type == "SRNImplicitFunction": if self.implicit_function_class_type == "SRNImplicitFunction":
extra_args["latent_dim"] = ( extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim
self._get_viewpooled_feature_dim()
+ self._get_global_encoder_encoding_dim()
)
# srn_hypernet preprocessing # srn_hypernet preprocessing
if self.implicit_function_class_type == "SRNHyperNetImplicitFunction": if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
extra_args["latent_dim"] = self._get_viewpooled_feature_dim() extra_args["latent_dim"] = viewpooled_feature_dim
extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim() extra_args["latent_dim_hypernet"] = global_encoder_dim
# check that for srn, srn_hypernet, idr we have self.num_passes=1 # check that for srn, srn_hypernet, idr we have self.num_passes=1
implicit_function_type = registry.get( implicit_function_type = registry.get(
...@@ -770,147 +762,3 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -770,147 +762,3 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
for _ in range(self.num_passes) for _ in range(self.num_passes)
] ]
return torch.nn.ModuleList(implicit_functions_list) return torch.nn.ModuleList(implicit_functions_list)
def log_loss_weights(self) -> None:
"""
Print a table of the loss weights.
"""
loss_weights_message = (
"-------\nloss_weights:\n"
+ "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items())
+ "-------"
)
logger.info(loss_weights_message)
def _preprocess_input(
self,
image_rgb: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Helper function to preprocess the input images and optional depth maps
to apply masking if required.
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
corresponding to the source viewpoints from which features will be extracted
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
of foreground masks with values in [0, 1].
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
Returns:
Modified image_rgb, fg_mask, depth_map
"""
if image_rgb is not None and image_rgb.ndim == 3:
# The FrameData object is used for both frames and batches of frames,
# and a user might get this error if those were confused.
# Perhaps a user has a FrameData `fd` representing a single frame and
# wrote something like `model(**fd)` instead of
# `model(**fd.collate([fd]))`.
raise ValueError(
"Model received unbatched inputs. "
+ "Perhaps they came from a FrameData which had not been collated."
)
fg_mask = fg_probability
if fg_mask is not None and self.mask_threshold > 0.0:
# threshold masks
warnings.warn("Thresholding masks!")
fg_mask = (fg_mask >= self.mask_threshold).type_as(fg_mask)
if self.mask_images and fg_mask is not None and image_rgb is not None:
# mask the image
warnings.warn("Masking images!")
image_rgb = image_utils.mask_background(
image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(self.bg_color)
)
if self.mask_depths and fg_mask is not None and depth_map is not None:
# mask the depths
assert (
self.mask_threshold > 0.0
), "Depths should be masked only with thresholded masks"
warnings.warn("Masking depths!")
depth_map = depth_map * fg_mask
return image_rgb, fg_mask, depth_map
def _apply_chunked(func, chunk_generator, tensor_collator):
"""
Helper function to apply a function on a sequence of
chunked inputs yielded by a generator and collate
the result.
"""
processed_chunks = [
func(*chunk_args, **chunk_kwargs)
for chunk_args, chunk_kwargs in chunk_generator
]
return cat_dataclass(processed_chunks, tensor_collator)
def _tensor_collator(batch, new_dims) -> torch.Tensor:
"""
Helper function to reshape the batch to the desired shape
"""
return torch.cat(batch, dim=1).reshape(*new_dims, -1)
def _chunk_generator(
chunk_size: int,
ray_bundle: ImplicitronRayBundle,
chunked_inputs: Dict[str, torch.Tensor],
tqdm_trigger_threshold: int,
*args,
**kwargs,
):
"""
Helper function which yields chunks of rays from the
input ray_bundle, to be used when the number of rays is
large and will not fit in memory for rendering.
"""
(
batch_size,
*spatial_dim,
n_pts_per_ray,
) = ray_bundle.lengths.shape # B x ... x n_pts_per_ray
if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
raise ValueError(
f"chunk_size_grid ({chunk_size}) should be divisible "
f"by n_pts_per_ray ({n_pts_per_ray})"
)
n_rays = prod(spatial_dim)
# special handling for raytracing-based methods
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
chunk_size_in_rays = -(-n_rays // n_chunks)
iter = range(0, n_rays, chunk_size_in_rays)
if len(iter) >= tqdm_trigger_threshold:
iter = tqdm.tqdm(iter)
def _safe_slice(
tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
) -> Any:
return tensor[start_idx:end_idx] if tensor is not None else None
for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = ImplicitronRayBundle(
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx
],
lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
:, start_idx:end_idx
],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
)
extra_args = kwargs.copy()
for k, v in chunked_inputs.items():
extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
yield [ray_bundle_chunk, *args], extra_args
This diff is collapsed.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.
import warnings
from logging import Logger
from typing import Any, Dict, Optional, Tuple
import torch
import tqdm
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools import image_utils
from pytorch3d.implicitron.tools.utils import cat_dataclass
def preprocess_input(
image_rgb: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
mask_images: bool,
mask_depths: bool,
mask_threshold: float,
bg_color: Tuple[float, float, float],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Helper function to preprocess the input images and optional depth maps
to apply masking if required.
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
corresponding to the source viewpoints from which features will be extracted
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
of foreground masks with values in [0, 1].
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
mask_images: Whether or not to mask the RGB image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
mask_depths: Whether or not to mask the depth image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
mask_threshold: If greater than 0.0, the foreground mask is
thresholded by this value before being applied to the RGB/Depth images
bg_color: RGB values for setting the background color of input image
if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
way to determine the background color of its output, unrelated to this.
Returns:
Modified image_rgb, fg_mask, depth_map
"""
if image_rgb is not None and image_rgb.ndim == 3:
# The FrameData object is used for both frames and batches of frames,
# and a user might get this error if those were confused.
# Perhaps a user has a FrameData `fd` representing a single frame and
# wrote something like `model(**fd)` instead of
# `model(**fd.collate([fd]))`.
raise ValueError(
"Model received unbatched inputs. "
+ "Perhaps they came from a FrameData which had not been collated."
)
fg_mask = fg_probability
if fg_mask is not None and mask_threshold > 0.0:
# threshold masks
warnings.warn("Thresholding masks!")
fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask)
if mask_images and fg_mask is not None and image_rgb is not None:
# mask the image
warnings.warn("Masking images!")
image_rgb = image_utils.mask_background(
image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color)
)
if mask_depths and fg_mask is not None and depth_map is not None:
# mask the depths
assert (
mask_threshold > 0.0
), "Depths should be masked only with thresholded masks"
warnings.warn("Masking depths!")
depth_map = depth_map * fg_mask
return image_rgb, fg_mask, depth_map
def log_loss_weights(loss_weights: Dict[str, float], logger: Logger) -> None:
"""
Print a table of the loss weights.
"""
loss_weights_message = (
"-------\nloss_weights:\n"
+ "\n".join(f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items())
+ "-------"
)
logger.info(loss_weights_message)
def weighted_sum_losses(
preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float]
) -> Optional[torch.Tensor]:
"""
A helper function to compute the overall loss as the dot product
of individual loss functions with the corresponding weights.
"""
losses_weighted = [
preds[k] * float(w)
for k, w in loss_weights.items()
if (k in preds and w != 0.0)
]
if len(losses_weighted) == 0:
warnings.warn("No main objective found.")
return None
loss = sum(losses_weighted)
assert torch.is_tensor(loss)
# pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
return loss
def apply_chunked(func, chunk_generator, tensor_collator):
"""
Helper function to apply a function on a sequence of
chunked inputs yielded by a generator and collate
the result.
"""
processed_chunks = [
func(*chunk_args, **chunk_kwargs)
for chunk_args, chunk_kwargs in chunk_generator
]
return cat_dataclass(processed_chunks, tensor_collator)
def chunk_generator(
chunk_size: int,
ray_bundle: ImplicitronRayBundle,
chunked_inputs: Dict[str, torch.Tensor],
tqdm_trigger_threshold: int,
*args,
**kwargs,
):
"""
Helper function which yields chunks of rays from the
input ray_bundle, to be used when the number of rays is
large and will not fit in memory for rendering.
"""
(
batch_size,
*spatial_dim,
n_pts_per_ray,
) = ray_bundle.lengths.shape # B x ... x n_pts_per_ray
if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
raise ValueError(
f"chunk_size_grid ({chunk_size}) should be divisible "
f"by n_pts_per_ray ({n_pts_per_ray})"
)
n_rays = prod(spatial_dim)
# special handling for raytracing-based methods
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
chunk_size_in_rays = -(-n_rays // n_chunks)
iter = range(0, n_rays, chunk_size_in_rays)
if len(iter) >= tqdm_trigger_threshold:
iter = tqdm.tqdm(iter)
def _safe_slice(
tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
) -> Any:
return tensor[start_idx:end_idx] if tensor is not None else None
for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = ImplicitronRayBundle(
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx
],
lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
:, start_idx:end_idx
],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
)
extra_args = kwargs.copy()
for k, v in chunked_inputs.items():
extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
yield [ray_bundle_chunk, *args], extra_args
...@@ -7,8 +7,9 @@ ...@@ -7,8 +7,9 @@
import math import math
from typing import Optional, Tuple from typing import Optional, Tuple
import pytorch3d
import torch import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.ops import packed_to_padded from pytorch3d.ops import packed_to_padded
from pytorch3d.renderer import PerspectiveCameras from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.structures import Pointclouds from pytorch3d.structures import Pointclouds
...@@ -18,7 +19,7 @@ from .point_cloud_utils import render_point_cloud_pytorch3d ...@@ -18,7 +19,7 @@ from .point_cloud_utils import render_point_cloud_pytorch3d
@torch.no_grad() @torch.no_grad()
def rasterize_sparse_ray_bundle( def rasterize_sparse_ray_bundle(
ray_bundle: ImplicitronRayBundle, ray_bundle: "pytorch3d.implicitron.models.renderer.base.ImplicitronRayBundle",
features: torch.Tensor, features: torch.Tensor,
image_size_hw: Tuple[int, int], image_size_hw: Tuple[int, int],
depth: torch.Tensor, depth: torch.Tensor,
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Any, Dict
from unittest.mock import patch
import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
DEVICE = torch.device("cuda:0")
def _generate_fake_inputs(N: int, H: int, W: int) -> Dict[str, Any]:
R, T = look_at_view_transform(azim=torch.rand(N) * 360)
return {
"camera": PerspectiveCameras(R=R, T=T, device=DEVICE),
"fg_probability": torch.randint(
high=2, size=(N, 1, H, W), device=DEVICE
).float(),
"depth_map": torch.rand((N, 1, H, W), device=DEVICE) + 0.1,
"mask_crop": torch.randint(high=2, size=(N, 1, H, W), device=DEVICE).float(),
"sequence_name": ["sequence"] * N,
"image_rgb": torch.rand((N, 1, H, W), device=DEVICE),
}
def mock_safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
"""Return non deterministic indexes to mock safe_multinomial
Args:
input: tensor of shape [B, n] containing non-negative values;
rows are interpreted as unnormalized event probabilities
in categorical distributions.
num_samples: number of samples to take.
Returns:
Tensor of shape [B, num_samples]
"""
batch_size = input.shape[0]
return torch.arange(num_samples).repeat(batch_size, 1).to(DEVICE)
class TestOverfitModel(unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
def test_overfit_model_vs_generic_model_with_batch_size_one(self):
"""In this test we compare OverfitModel to GenericModel behavior.
We use a Nerf setup (2 rendering passes).
OverfitModel is a specific case of GenericModel. Hence, with the same inputs,
they should provide the exact same results.
"""
expand_args_fields(OverfitModel)
expand_args_fields(GenericModel)
batch_size, image_height, image_width = 1, 80, 80
assert batch_size == 1
overfit_model = OverfitModel(
render_image_height=image_height,
render_image_width=image_width,
coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
# To avoid randomization to compare the outputs of our model
# we deactivate the stratified_point_sampling_training
raysampler_AdaptiveRaySampler_args={
"stratified_point_sampling_training": False
},
global_encoder_class_type="SequenceAutodecoder",
global_encoder_SequenceAutodecoder_args={
"autodecoder_args": {
"n_instances": 1000,
"init_scale": 1.0,
"encoding_dim": 64,
}
},
)
generic_model = GenericModel(
render_image_height=image_height,
render_image_width=image_width,
n_train_target_views=batch_size,
num_passes=2,
# To avoid randomization to compare the outputs of our model
# we deactivate the stratified_point_sampling_training
raysampler_AdaptiveRaySampler_args={
"stratified_point_sampling_training": False
},
global_encoder_class_type="SequenceAutodecoder",
global_encoder_SequenceAutodecoder_args={
"autodecoder_args": {
"n_instances": 1000,
"init_scale": 1.0,
"encoding_dim": 64,
}
},
)
# Check if they do share the number of parameters
num_params_mvm = sum(p.numel() for p in overfit_model.parameters())
num_params_gm = sum(p.numel() for p in generic_model.parameters())
self.assertEqual(num_params_mvm, num_params_gm)
# Adapt the mapping from generic model to overfit model
mapping_om_from_gm = {
key.replace("_implicit_functions.0._fn", "implicit_function").replace(
"_implicit_functions.1._fn", "coarse_implicit_function"
): val
for key, val in generic_model.state_dict().items()
}
# Copy parameters from generic_model to overfit_model
overfit_model.load_state_dict(mapping_om_from_gm)
overfit_model.to(DEVICE)
generic_model.to(DEVICE)
inputs_ = _generate_fake_inputs(batch_size, image_height, image_width)
# training forward pass
overfit_model.train()
generic_model.train()
with patch(
"pytorch3d.renderer.implicit.raysampling._safe_multinomial",
side_effect=mock_safe_multinomial,
):
train_preds_om = overfit_model(
**inputs_,
evaluation_mode=EvaluationMode.TRAINING,
)
train_preds_gm = generic_model(
**inputs_,
evaluation_mode=EvaluationMode.TRAINING,
)
self.assertTrue(len(train_preds_om) == len(train_preds_gm))
self.assertTrue(train_preds_om["objective"].isfinite().item())
# We avoid all the randomization and the weights are the same
# The objective should be the same
self.assertTrue(
torch.allclose(train_preds_om["objective"], train_preds_gm["objective"])
)
# Test if the evaluation works
overfit_model.eval()
generic_model.eval()
with torch.no_grad():
eval_preds_om = overfit_model(
**inputs_,
evaluation_mode=EvaluationMode.EVALUATION,
)
eval_preds_gm = generic_model(
**inputs_,
evaluation_mode=EvaluationMode.EVALUATION,
)
self.assertEqual(
eval_preds_om["images_render"].shape,
(batch_size, 3, image_height, image_width),
)
self.assertTrue(
torch.allclose(eval_preds_om["objective"], eval_preds_gm["objective"])
)
self.assertTrue(
torch.allclose(
eval_preds_om["images_render"], eval_preds_gm["images_render"]
)
)
def test_overfit_model_check_share_weights(self):
model = OverfitModel(share_implicit_function_across_passes=True)
for p1, p2 in zip(
model.implicit_function.parameters(),
model.coarse_implicit_function.parameters(),
):
self.assertEqual(id(p1), id(p2))
model.to(DEVICE)
inputs_ = _generate_fake_inputs(2, 80, 80)
model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
def test_overfit_model_check_no_share_weights(self):
model = OverfitModel(
share_implicit_function_across_passes=False,
coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args={
"transformer_dim_down_factor": 1.0,
"n_hidden_neurons_xyz": 256,
"n_layers_xyz": 8,
"append_xyz": (5,),
},
)
for p1, p2 in zip(
model.implicit_function.parameters(),
model.coarse_implicit_function.parameters(),
):
self.assertNotEqual(id(p1), id(p2))
model.to(DEVICE)
inputs_ = _generate_fake_inputs(2, 80, 80)
model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
def test_overfit_model_coarse_implicit_function_is_none(self):
model = OverfitModel(
share_implicit_function_across_passes=False,
coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args=None,
)
self.assertIsNone(model.coarse_implicit_function)
model.to(DEVICE)
inputs_ = _generate_fake_inputs(2, 80, 80)
model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_losses
class TestUtils(unittest.TestCase):
def test_prepare_inputs_wrong_num_dim(self):
img = torch.randn(3, 3, 3)
with self.assertRaises(ValueError) as context:
img, fg_prob, depth_map = preprocess_input(
img, None, None, True, True, 0.5, (0.0, 0.0, 0.0)
)
self.assertEqual(
"Model received unbatched inputs. "
+ "Perhaps they came from a FrameData which had not been collated.",
context.exception,
)
def test_prepare_inputs_mask_image_true(self):
batch, channels, height, width = 2, 3, 10, 10
img = torch.ones(batch, channels, height, width)
# Create a mask on the lower triangular matrix
fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3
out_img, out_fg_prob, out_depth_map = preprocess_input(
img, fg_prob, None, True, False, 0.3, (0.0, 0.0, 0.0)
)
self.assertTrue(torch.equal(out_img, torch.tril(img)))
self.assertTrue(torch.equal(out_fg_prob, fg_prob >= 0.3))
self.assertIsNone(out_depth_map)
def test_prepare_inputs_mask_depth_true(self):
batch, channels, height, width = 2, 3, 10, 10
img = torch.ones(batch, channels, height, width)
depth_map = torch.randn(batch, channels, height, width)
# Create a mask on the lower triangular matrix
fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3
out_img, out_fg_prob, out_depth_map = preprocess_input(
img, fg_prob, depth_map, False, True, 0.3, (0.0, 0.0, 0.0)
)
self.assertTrue(torch.equal(out_img, img))
self.assertTrue(torch.equal(out_fg_prob, fg_prob >= 0.3))
self.assertTrue(torch.equal(out_depth_map, torch.tril(depth_map)))
def test_weighted_sum_losses(self):
preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
weights = {"a": 2.0, "b": 0.0}
loss = weighted_sum_losses(preds, weights)
self.assertEqual(loss, 4.0)
def test_weighted_sum_losses_raise_warning(self):
preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
weights = {"c": 2.0, "d": 2.0}
self.assertIsNone(weighted_sum_losses(preds, weights))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment