Commit 47d06c89 authored by David Novotny's avatar David Novotny Committed by Facebook GitHub Bot
Browse files

ViewPooler class

Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator.

Reviewed By: shapovalov

Differential Revision: D35852367

fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
parent bef959c7
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Union
import torch
from pytorch3d.implicitron.tools.config import Configurable, run_auto_creation
from pytorch3d.renderer.cameras import CamerasBase
from .feature_aggregator import FeatureAggregatorBase
from .view_sampler import ViewSampler
# pyre-ignore: 13
class ViewPooler(Configurable, torch.nn.Module):
    """
    Samples image-based features at the 2D projections of a set of 3D points
    and then aggregates the resulting per-point, per-view features.

    Members:
        view_sampler: A `ViewSampler` instance used to sample image-based
            features at the 2D projections of the 3D points.
        feature_aggregator_class_type: Name of the feature aggregator class
            to instantiate from the global registry.
        feature_aggregator: A class inheriting from `FeatureAggregatorBase`
            that reduces the per-view feature samples (and sampling masks)
            produced by `view_sampler` to per-point features.
    """

    view_sampler: ViewSampler
    feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
    feature_aggregator: FeatureAggregatorBase

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)

    def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
        """
        Returns the final dimensionality of the output aggregated features.

        Args:
            feats: Either a `dict` of sampled features `{f_i: t_i}` matching
                the `feats_sampled` argument of `feature_aggregator.forward`,
                or an `int` equal to the sum of the dimensionalities of the
                individual `t_i` tensors.

        Returns:
            aggregated_feature_dim: The final dimensionality of the output
                aggregated features.
        """
        return self.feature_aggregator.get_aggregated_feature_dim(feats)

    def has_aggregation(self):
        """
        Reports whether `feature_aggregator` collapses the output `reduce_dim`
        dimension to 1.

        Returns:
            has_aggregation: `True` if `reduce_dim==1`, else `False`.
        """
        return self.feature_aggregator.has_aggregation()

    def forward(
        self,
        *,  # force kw args
        pts: torch.Tensor,
        seq_id_pts: Union[List[int], List[str], torch.LongTensor],
        camera: CamerasBase,
        seq_id_camera: Union[List[int], List[str], torch.LongTensor],
        feats: Dict[str, torch.Tensor],
        masks: Optional[torch.Tensor],
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Projects each batch of 3D points to the corresponding input cameras,
        samples features at the resulting 2D locations in the source images,
        and aggregates the sampled features per point.

        Args:
            pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
            seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of
                the scenes from which `pts` were extracted, or a list of
                string names.
            camera: 'n_cameras' cameras, each corresponding to a batch element
                of `feats`.
            seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids
                of the scenes corresponding to cameras in `camera`, or a list
                of string names.
            feats: a dict of tensors of per-image features `{feat_i: T_i}`.
                Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
            masks: `[n_cameras x 1 x H x W]`, define valid image regions
                for sampling `feats`.

        Returns:
            feats_aggregated: If `feature_aggregator.concatenate_output==True`,
                a tensor of shape
                `(pts_batch, reduce_dim, n_pts, sum(dim_1, ... dim_N))`
                containing the aggregated features. `reduce_dim` depends on
                the specific feature aggregator implementation and typically
                equals 1 or `n_cameras`.
                If `feature_aggregator.concatenate_output==False`, the
                aggregator does not concatenate the aggregated features and
                instead returns a dictionary of per-feature aggregations
                `{f_i: t_i_aggregated}`, where each `t_i_aggregated` is of
                shape `(pts_batch, reduce_dim, n_pts, aggr_dim_i)`.
        """
        # (1) Sample the per-view feature maps (and validity masks) at the
        # 2D projections of the points.
        feats_at_pts, masks_at_pts = self.view_sampler(
            pts=pts,
            seq_id_pts=seq_id_pts,
            camera=camera,
            seq_id_camera=seq_id_camera,
            feats=feats,
            masks=masks,
        )
        # (2) Reduce the multi-view samples to per-point features.
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        return self.feature_aggregator(  # noqa: E731
            feats_at_pts,
            masks_at_pts,
            pts=pts,
            camera=camera,
        )  # TODO: do we need to pass a callback rather than compute here?
......@@ -8,7 +8,6 @@ bg_color:
- 0.0
- 0.0
- 0.0
view_pool: false
num_passes: 1
chunk_size_grid: 4096
render_features_dimensions: 3
......@@ -17,7 +16,8 @@ n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
renderer_class_type: LSTMRenderer
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
image_feature_extractor_enabled: true
view_pooler_enabled: true
implicit_function_class_type: IdrFeatureField
loss_weights:
loss_rgb_mse: 1.0
......@@ -91,15 +91,17 @@ image_feature_extractor_args:
add_images: true
global_average_pool: false
feature_rescale: 1.0
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
view_pooler_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3
......
......@@ -20,9 +20,8 @@ from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import (
AngleWeightedIdentityFeatureAggregator,
AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
......@@ -32,7 +31,10 @@ from pytorch3d.implicitron.tools.config import (
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
from .common_resources import provide_lpips_vgg
else:
from common_resources import provide_lpips_vgg # noqa
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
......@@ -46,28 +48,33 @@ class TestGenericModel(unittest.TestCase):
self.maxDiff = None
def test_create_gm(self):
provide_lpips_vgg()
args = get_default_args(GenericModel)
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
)
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertFalse(hasattr(gm, "image_feature_extractor"))
self.assertIsNone(gm.view_pooler)
self.assertIsNone(gm.image_feature_extractor)
def test_create_gm_overrides(self):
def _test_create_gm_overrides(self):
provide_lpips_vgg()
args = get_default_args(GenericModel)
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
args.view_pooler_enabled = True
args.image_feature_extractor_enabled = True
args.view_pooler_args.feature_aggregator_class_type = (
"AngleWeightedIdentityFeatureAggregator"
)
args.implicit_function_class_type = "IdrFeatureField"
args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
gm.view_pooler.feature_aggregator,
AngleWeightedIdentityFeatureAggregator,
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
......
......@@ -56,7 +56,12 @@ class TestGenericModel(unittest.TestCase):
cfg = _load_model_config_from_yaml(str(config_file))
model = GenericModel(**cfg)
model.to(device)
self._one_model_test(model, device, eval_test=True)
self._one_model_test(
model,
device,
eval_test=True,
bw_test=True,
)
def _one_model_test(
self,
......@@ -64,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
device,
n_train_cameras: int = 5,
eval_test: bool = True,
bw_test: bool = True,
):
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
......@@ -86,8 +92,12 @@ class TestGenericModel(unittest.TestCase):
**random_args,
evaluation_mode=EvaluationMode.TRAINING,
)
self.assertGreater(train_preds["objective"].item(), 0)
train_preds["objective"].backward()
self.assertTrue(
train_preds["objective"].isfinite().item()
) # check finiteness of the objective
if bw_test:
train_preds["objective"].backward()
if eval_test:
model.eval()
......
......@@ -9,7 +9,7 @@ import unittest
import pytorch3d as pt3d
import torch
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
from pytorch3d.implicitron.models.view_pooler.view_sampler import ViewSampler
from pytorch3d.implicitron.tools.config import expand_args_fields
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment