Commit 47d06c89 authored by David Novotny's avatar David Novotny Committed by Facebook GitHub Bot
Browse files

ViewPooler class

Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator.

Reviewed By: shapovalov

Differential Revision: D35852367

fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
parent bef959c7
......@@ -63,8 +63,9 @@ generic_model_args:
n_pts_per_ray_fine_evaluation: 64
append_coarse_samples_to_fine: true
density_noise_std_train: 1.0
view_sampler_args:
masked_sampling: false
view_pooler_args:
view_sampler_args:
masked_sampling: false
image_feature_extractor_args:
stages:
- 1
......
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
add_images: true
add_masks: true
......
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
add_images: true
add_masks: true
......
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
stages:
- 1
......@@ -11,6 +12,7 @@ generic_model_args:
name: resnet34
normalize_image: true
pretrained: true
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
reduction_functions:
- AVG
view_pooler_args:
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
reduction_functions:
- AVG
......@@ -11,7 +11,6 @@ generic_model_args:
num_passes: 1
output_rasterized_mc: true
sampling_mode_training: mask_sample
view_pool: false
sequence_autodecoder_args:
n_instances: 20000
init_scale: 1.0
......
......@@ -3,7 +3,7 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
view_pooler_enabled: false
sequence_autodecoder_args:
n_instances: 20000
encoding_dim: 256
......@@ -5,6 +5,6 @@ defaults:
clip_grad: 1.0
generic_model_args:
chunk_size_grid: 16000
view_pool: true
view_pooler_enabled: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850
......@@ -4,7 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
......@@ -13,4 +12,6 @@ generic_model_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: IdentityFeatureAggregator
view_pooler_enabled: true
view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator
......@@ -4,7 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
......@@ -13,4 +12,6 @@ generic_model_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
view_pooler_enabled: true
view_pooler_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
......@@ -3,7 +3,7 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
view_pooler_enabled: false
n_train_target_views: -1
num_passes: 1
loss_weights:
......
......@@ -4,7 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 32000
view_pool: true
num_passes: 1
n_train_target_views: -1
loss_weights:
......@@ -25,6 +24,7 @@ generic_model_args:
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
view_pooler_enabled: true
solver_args:
breed: adam
lr: 5.0e-05
......@@ -9,7 +9,7 @@ generic_model_args:
loss_eikonal: 0.1
chunk_size_grid: 65536
num_passes: 1
view_pool: false
view_pooler_enabled: false
implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6
bias: 0.6
......
......@@ -4,6 +4,6 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
view_pooler_enabled: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850
......@@ -4,7 +4,7 @@ defaults:
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
view_pooler_enabled: true
implicit_function_class_type: NeRFormerImplicitFunction
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
......@@ -13,4 +13,5 @@ generic_model_args:
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
feature_aggregator_class_type: IdentityFeatureAggregator
view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator
......@@ -4,7 +4,7 @@ defaults:
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: false
view_pooler_enabled: false
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
......
......@@ -5,7 +5,7 @@ defaults:
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: true
view_pooler_enabled: true
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
......
......@@ -49,8 +49,7 @@ from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySampler
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooling.feature_aggregation import FeatureAggregatorBase
from .view_pooling.view_sampling import ViewSampler
from .view_pooler.view_pooler import ViewPooler
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
......@@ -167,16 +166,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s).
image_feature_extractor_enabled: If `True`, constructs and enables
the `image_feature_extractor` object.
image_feature_extractor: A module for extracting features from an input image.
view_sampler: An instance of ViewSampler which is used for sampling of
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
view_pooler: An instance of ViewPooler which is used for sampling of
image-based features at the 2D projections of a set
of 3D points.
feature_aggregator_class_type: The name of the feature aggregator class which
is available in the global registry.
feature_aggregator: A feature aggregator class which inherits from
FeatureAggregatorBase. Typically, the aggregated features and their
masks are output by a `ViewSampler` which samples feature tensors extracted
from a set of source images. FeatureAggregator executes step (4) above.
of 3D points and aggregating the sampled features.
implicit_function_class_type: The type of implicit function to use which
is available in the global registry.
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
......@@ -195,7 +191,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
mask_threshold: float = 0.5
output_rasterized_mc: bool = False
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
view_pool: bool = False
num_passes: int = 1
chunk_size_grid: int = 4096
render_features_dimensions: int = 3
......@@ -215,13 +210,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
renderer: BaseRenderer
# ---- view sampling settings - used if view_pool=True
# (This is only created if view_pool is False)
image_feature_extractor: ResNetFeatureExtractor
view_sampler: ViewSampler
# ---- ---- view sampling feature aggregator settings
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
feature_aggregator: FeatureAggregatorBase
# ---- image feature extractor settings
image_feature_extractor_enabled: bool = False
image_feature_extractor: Optional[ResNetFeatureExtractor]
# ---- view pooler settings
view_pooler_enabled: bool = False
view_pooler: Optional[ViewPooler]
# ---- implicit function settings
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
......@@ -356,32 +350,34 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# custom_args hold additional arguments to the implicit function.
custom_args = {}
if self.view_pool:
if self.image_feature_extractor_enabled:
# (2) Extract features for the image
img_feats = self.image_feature_extractor( # pyre-fixme[29]
image_rgb, fg_probability
)
if self.view_pooler_enabled:
if sequence_name is None:
raise ValueError("sequence_name must be provided for view pooling")
# (2) Extract features for the image
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
# (3) Sample features and masks at the ray points
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731
pts=pts,
seq_id_pts=sequence_name[:n_targets],
camera=camera,
seq_id_camera=sequence_name,
feats=img_feats,
masks=mask_crop,
) # returns feats_sampled, masks_sampled
# (4) Aggregate features from multiple views
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731
*curried_view_sampler(pts=pts),
pts=pts,
camera=camera,
) # TODO: do we need to pass a callback rather than compute here?
# precomputing will be faster for 2 passes
# -> but this is important for non-nerf
custom_args["fun_viewpool"] = curried_view_pool
if not self.image_feature_extractor_enabled:
raise ValueError(
"image_feature_extractor has to be enabled for for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
# (3-4) Sample features and masks at the ray points.
# Aggregate features from multiple views.
def curried_viewpooler(pts):
return self.view_pooler(
pts=pts,
seq_id_pts=sequence_name[:n_targets],
camera=camera,
seq_id_camera=sequence_name,
feats=img_feats,
masks=mask_crop,
)
custom_args["fun_viewpool"] = curried_viewpooler
global_code = None
if self.sequence_autodecoder.n_instances > 0:
......@@ -562,10 +558,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
def _get_viewpooled_feature_dim(self):
return (
self.feature_aggregator.get_aggregated_feature_dim(
self.view_pooler.get_aggregated_feature_dim(
self.image_feature_extractor.get_feat_dims()
)
if self.view_pool
if self.view_pooler_enabled
else 0
)
......@@ -583,15 +579,20 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
"object_bounding_sphere"
] = self.raysampler_args["scene_extent"]
def create_image_feature_extractor(self):
def create_view_pooler(self):
"""
Custom creation function called by run_auto_creation so that the
image_feature_extractor is not created if it is not needed.
Custom creation function called by run_auto_creation checking
that image_feature_extractor is enabled when view_pooler is enabled.
"""
if self.view_pool:
self.image_feature_extractor = ResNetFeatureExtractor(
**self.image_feature_extractor_args
)
if self.view_pooler_enabled:
if not self.image_feature_extractor_enabled:
raise ValueError(
"image_feature_extractor has to be enabled for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
self.view_pooler = ViewPooler(**self.view_pooler_args)
else:
self.view_pooler = None
def create_implicit_function(self) -> None:
"""
......@@ -652,10 +653,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
)
if implicit_function_type.requires_pooling_without_aggregation():
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions")
if not self.view_pool or has_aggregation:
if self.view_pooler_enabled and self.view_pooler.has_aggregation():
raise ValueError(
"Chosen implicit function requires view pooling without aggregation."
"The chosen implicit function requires view pooling without aggregation."
)
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
config = getattr(self, config_name, None)
......
......@@ -141,11 +141,12 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self, size_dict: bool = False):
if size_dict:
return copy.deepcopy(self._feat_dim)
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.values)[[Na...
return sum(self._feat_dim.values())
def get_feat_dims(self) -> int:
    """
    Return the total dimensionality of the extracted features.

    Sums the per-stage feature dimensions stored in `self._feat_dim`.
    `sum()` over an empty values view already yields 0, so no separate
    empty-dict branch is required.

    Returns:
        The sum of the individual stage feature dimensions (0 if no
        stages are configured).
    """
    # pyre-fixme[29]: `_feat_dim` values are registered dynamically.
    return sum(self._feat_dim.values())
def forward(
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
......
......@@ -10,7 +10,7 @@ from typing import Dict, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
from pytorch3d.implicitron.models.view_pooler.view_sampler import (
cameras_points_cartesian_product,
)
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
......@@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase):
"""
raise NotImplementedError()
@abstractmethod
def get_aggregated_feature_dim(
    self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
    """
    Returns the final dimensionality of the output aggregated features.

    Args:
        feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}`
            corresponding to the `feats_sampled` argument of `forward`,
            or an `int` representing the sum of dimensionalities of each `t_i`.

    Returns:
        aggregated_feature_dim: The final dimensionality of the output
            aggregated features.
    """
    raise NotImplementedError()
def has_aggregation(self) -> bool:
    """
    Tells whether this aggregator collapses the output `reduce_dim`
    dimension to 1.

    Returns:
        has_aggregation: `True` when the aggregator defines
            `reduction_functions` (i.e. `reduce_dim==1`), else `False`.
    """
    aggregates = hasattr(self, "reduction_functions")
    return aggregates
@registry.register
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
......@@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def get_aggregated_feature_dim(
    self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
    """
    Returns the output feature dimensionality of the identity aggregator:
    no reduction functions are applied, so the dim is computed with an
    empty reduction list.
    """
    no_reductions = []
    return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, no_reductions)
def forward(
self,
......@@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def get_aggregated_feature_dim(
    self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
    """
    Returns the output feature dimensionality after applying this
    aggregator's configured reduction functions.
    """
    reductions = self.reduction_functions
    return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, reductions)
def forward(
self,
......@@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def get_aggregated_feature_dim(
    self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
    """
    Returns the output feature dimensionality after applying this
    aggregator's configured (angle-weighted) reduction functions.
    """
    reductions = self.reduction_functions
    return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, reductions)
def forward(
self,
......@@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def get_aggregated_feature_dim(
    self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
    """
    Returns the output feature dimensionality for the angle-weighted
    identity aggregator: features are only weighted, never reduced, so
    the dim is computed with an empty reduction list.
    """
    no_reductions = []
    return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, no_reductions)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment