"src/vscode:/vscode.git/clone" did not exist on "cfa80974ddbf9a88d5bd7b6db322e11c876feef8"
Commit 47d06c89 authored by David Novotny's avatar David Novotny Committed by Facebook GitHub Bot
Browse files

ViewPooler class

Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator.

Reviewed By: shapovalov

Differential Revision: D35852367

fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
parent bef959c7
...@@ -63,8 +63,9 @@ generic_model_args: ...@@ -63,8 +63,9 @@ generic_model_args:
n_pts_per_ray_fine_evaluation: 64 n_pts_per_ray_fine_evaluation: 64
append_coarse_samples_to_fine: true append_coarse_samples_to_fine: true
density_noise_std_train: 1.0 density_noise_std_train: 1.0
view_sampler_args: view_pooler_args:
masked_sampling: false view_sampler_args:
masked_sampling: false
image_feature_extractor_args: image_feature_extractor_args:
stages: stages:
- 1 - 1
......
generic_model_args: generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args: image_feature_extractor_args:
add_images: true add_images: true
add_masks: true add_masks: true
......
generic_model_args: generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args: image_feature_extractor_args:
add_images: true add_images: true
add_masks: true add_masks: true
......
generic_model_args: generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args: image_feature_extractor_args:
stages: stages:
- 1 - 1
...@@ -11,6 +12,7 @@ generic_model_args: ...@@ -11,6 +12,7 @@ generic_model_args:
name: resnet34 name: resnet34
normalize_image: true normalize_image: true
pretrained: true pretrained: true
feature_aggregator_AngleWeightedReductionFeatureAggregator_args: view_pooler_args:
reduction_functions: feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
- AVG reduction_functions:
- AVG
...@@ -11,7 +11,6 @@ generic_model_args: ...@@ -11,7 +11,6 @@ generic_model_args:
num_passes: 1 num_passes: 1
output_rasterized_mc: true output_rasterized_mc: true
sampling_mode_training: mask_sample sampling_mode_training: mask_sample
view_pool: false
sequence_autodecoder_args: sequence_autodecoder_args:
n_instances: 20000 n_instances: 20000
init_scale: 1.0 init_scale: 1.0
......
...@@ -3,7 +3,7 @@ defaults: ...@@ -3,7 +3,7 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pool: false view_pooler_enabled: false
sequence_autodecoder_args: sequence_autodecoder_args:
n_instances: 20000 n_instances: 20000
encoding_dim: 256 encoding_dim: 256
...@@ -5,6 +5,6 @@ defaults: ...@@ -5,6 +5,6 @@ defaults:
clip_grad: 1.0 clip_grad: 1.0
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pool: true view_pooler_enabled: true
raysampler_args: raysampler_args:
n_rays_per_image_sampled_from_mask: 850 n_rays_per_image_sampled_from_mask: 850
...@@ -4,7 +4,6 @@ defaults: ...@@ -4,7 +4,6 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pool: true
raysampler_args: raysampler_args:
n_rays_per_image_sampled_from_mask: 800 n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32 n_pts_per_ray_training: 32
...@@ -13,4 +12,6 @@ generic_model_args: ...@@ -13,4 +12,6 @@ generic_model_args:
n_pts_per_ray_fine_training: 16 n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16 n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: IdentityFeatureAggregator view_pooler_enabled: true
view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator
...@@ -4,7 +4,6 @@ defaults: ...@@ -4,7 +4,6 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pool: true
raysampler_args: raysampler_args:
n_rays_per_image_sampled_from_mask: 800 n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32 n_pts_per_ray_training: 32
...@@ -13,4 +12,6 @@ generic_model_args: ...@@ -13,4 +12,6 @@ generic_model_args:
n_pts_per_ray_fine_training: 16 n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16 n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator view_pooler_enabled: true
view_pooler_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
...@@ -3,7 +3,7 @@ defaults: ...@@ -3,7 +3,7 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pool: false view_pooler_enabled: false
n_train_target_views: -1 n_train_target_views: -1
num_passes: 1 num_passes: 1
loss_weights: loss_weights:
......
...@@ -4,7 +4,6 @@ defaults: ...@@ -4,7 +4,6 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 32000 chunk_size_grid: 32000
view_pool: true
num_passes: 1 num_passes: 1
n_train_target_views: -1 n_train_target_views: -1
loss_weights: loss_weights:
...@@ -25,6 +24,7 @@ generic_model_args: ...@@ -25,6 +24,7 @@ generic_model_args:
stratified_point_sampling_evaluation: false stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction implicit_function_class_type: SRNImplicitFunction
view_pooler_enabled: true
solver_args: solver_args:
breed: adam breed: adam
lr: 5.0e-05 lr: 5.0e-05
...@@ -9,7 +9,7 @@ generic_model_args: ...@@ -9,7 +9,7 @@ generic_model_args:
loss_eikonal: 0.1 loss_eikonal: 0.1
chunk_size_grid: 65536 chunk_size_grid: 65536
num_passes: 1 num_passes: 1
view_pool: false view_pooler_enabled: false
implicit_function_IdrFeatureField_args: implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6 n_harmonic_functions_xyz: 6
bias: 0.6 bias: 0.6
......
...@@ -4,6 +4,6 @@ defaults: ...@@ -4,6 +4,6 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pool: true view_pooler_enabled: true
raysampler_args: raysampler_args:
n_rays_per_image_sampled_from_mask: 850 n_rays_per_image_sampled_from_mask: 850
...@@ -4,7 +4,7 @@ defaults: ...@@ -4,7 +4,7 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pool: true view_pooler_enabled: true
implicit_function_class_type: NeRFormerImplicitFunction implicit_function_class_type: NeRFormerImplicitFunction
raysampler_args: raysampler_args:
n_rays_per_image_sampled_from_mask: 800 n_rays_per_image_sampled_from_mask: 800
...@@ -13,4 +13,5 @@ generic_model_args: ...@@ -13,4 +13,5 @@ generic_model_args:
renderer_MultiPassEmissionAbsorptionRenderer_args: renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16 n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16 n_pts_per_ray_fine_evaluation: 16
feature_aggregator_class_type: IdentityFeatureAggregator view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator
...@@ -4,7 +4,7 @@ defaults: ...@@ -4,7 +4,7 @@ defaults:
generic_model_args: generic_model_args:
num_passes: 1 num_passes: 1
chunk_size_grid: 32000 chunk_size_grid: 32000
view_pool: false view_pooler_enabled: false
loss_weights: loss_weights:
loss_rgb_mse: 200.0 loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0 loss_prev_stage_rgb_mse: 0.0
......
...@@ -5,7 +5,7 @@ defaults: ...@@ -5,7 +5,7 @@ defaults:
generic_model_args: generic_model_args:
num_passes: 1 num_passes: 1
chunk_size_grid: 32000 chunk_size_grid: 32000
view_pool: true view_pooler_enabled: true
loss_weights: loss_weights:
loss_rgb_mse: 200.0 loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0 loss_prev_stage_rgb_mse: 0.0
......
...@@ -49,8 +49,7 @@ from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa ...@@ -49,8 +49,7 @@ from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySampler from .renderer.ray_sampler import RaySampler
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooling.feature_aggregation import FeatureAggregatorBase from .view_pooler.view_pooler import ViewPooler
from .view_pooling.view_sampling import ViewSampler
STD_LOG_VARS = ["objective", "epoch", "sec/it"] STD_LOG_VARS = ["objective", "epoch", "sec/it"]
...@@ -167,16 +166,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -167,16 +166,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
registry. registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s). generate the images from the target view(s).
image_feature_extractor_enabled: If `True`, constructs and enables
the `image_feature_extractor` object.
image_feature_extractor: A module for extrating features from an input image. image_feature_extractor: A module for extrating features from an input image.
view_sampler: An instance of ViewSampler which is used for sampling of view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
view_pooler: An instance of ViewPooler which is used for sampling of
image-based features at the 2D projections of a set image-based features at the 2D projections of a set
of 3D points. of 3D points and aggregating the sampled features.
feature_aggregator_class_type: The name of the feature aggregator class which
is available in the global registry.
feature_aggregator: A feature aggregator class which inherits from
FeatureAggregatorBase. Typically, the aggregated features and their
masks are output by a `ViewSampler` which samples feature tensors extracted
from a set of source images. FeatureAggregator executes step (4) above.
implicit_function_class_type: The type of implicit function to use which implicit_function_class_type: The type of implicit function to use which
is available in the global registry. is available in the global registry.
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
...@@ -195,7 +191,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -195,7 +191,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
mask_threshold: float = 0.5 mask_threshold: float = 0.5
output_rasterized_mc: bool = False output_rasterized_mc: bool = False
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0) bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
view_pool: bool = False
num_passes: int = 1 num_passes: int = 1
chunk_size_grid: int = 4096 chunk_size_grid: int = 4096
render_features_dimensions: int = 3 render_features_dimensions: int = 3
...@@ -215,13 +210,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -215,13 +210,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer" renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
renderer: BaseRenderer renderer: BaseRenderer
# ---- view sampling settings - used if view_pool=True # ---- image feature extractor settings
# (This is only created if view_pool is False) image_feature_extractor_enabled: bool = False
image_feature_extractor: ResNetFeatureExtractor image_feature_extractor: Optional[ResNetFeatureExtractor]
view_sampler: ViewSampler # ---- view pooler settings
# ---- ---- view sampling feature aggregator settings view_pooler_enabled: bool = False
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator" view_pooler: Optional[ViewPooler]
feature_aggregator: FeatureAggregatorBase
# ---- implicit function settings # ---- implicit function settings
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction" implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
...@@ -356,32 +350,34 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -356,32 +350,34 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# custom_args hold additional arguments to the implicit function. # custom_args hold additional arguments to the implicit function.
custom_args = {} custom_args = {}
if self.view_pool: if self.image_feature_extractor_enabled:
# (2) Extract features for the image
img_feats = self.image_feature_extractor( # pyre-fixme[29]
image_rgb, fg_probability
)
if self.view_pooler_enabled:
if sequence_name is None: if sequence_name is None:
raise ValueError("sequence_name must be provided for view pooling") raise ValueError("sequence_name must be provided for view pooling")
# (2) Extract features for the image if not self.image_feature_extractor_enabled:
img_feats = self.image_feature_extractor(image_rgb, fg_probability) raise ValueError(
"image_feature_extractor has to be enabled for for view pooling"
# (3) Sample features and masks at the ray points + " (I.e. set self.image_feature_extractor_enabled=True)."
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731 )
pts=pts,
seq_id_pts=sequence_name[:n_targets], # (3-4) Sample features and masks at the ray points.
camera=camera, # Aggregate features from multiple views.
seq_id_camera=sequence_name, def curried_viewpooler(pts):
feats=img_feats, return self.view_pooler(
masks=mask_crop, pts=pts,
) # returns feats_sampled, masks_sampled seq_id_pts=sequence_name[:n_targets],
camera=camera,
# (4) Aggregate features from multiple views seq_id_camera=sequence_name,
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. feats=img_feats,
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731 masks=mask_crop,
*curried_view_sampler(pts=pts), )
pts=pts,
camera=camera, custom_args["fun_viewpool"] = curried_viewpooler
) # TODO: do we need to pass a callback rather than compute here?
# precomputing will be faster for 2 passes
# -> but this is important for non-nerf
custom_args["fun_viewpool"] = curried_view_pool
global_code = None global_code = None
if self.sequence_autodecoder.n_instances > 0: if self.sequence_autodecoder.n_instances > 0:
...@@ -562,10 +558,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -562,10 +558,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
def _get_viewpooled_feature_dim(self): def _get_viewpooled_feature_dim(self):
return ( return (
self.feature_aggregator.get_aggregated_feature_dim( self.view_pooler.get_aggregated_feature_dim(
self.image_feature_extractor.get_feat_dims() self.image_feature_extractor.get_feat_dims()
) )
if self.view_pool if self.view_pooler_enabled
else 0 else 0
) )
...@@ -583,15 +579,20 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -583,15 +579,20 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
"object_bounding_sphere" "object_bounding_sphere"
] = self.raysampler_args["scene_extent"] ] = self.raysampler_args["scene_extent"]
def create_image_feature_extractor(self): def create_view_pooler(self):
""" """
Custom creation function called by run_auto_creation so that the Custom creation function called by run_auto_creation checking
image_feature_extractor is not created if it is not be needed. that image_feature_extractor is enabled when view_pooler is enabled.
""" """
if self.view_pool: if self.view_pooler_enabled:
self.image_feature_extractor = ResNetFeatureExtractor( if not self.image_feature_extractor_enabled:
**self.image_feature_extractor_args raise ValueError(
) "image_feature_extractor has to be enabled for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
self.view_pooler = ViewPooler(**self.view_pooler_args)
else:
self.view_pooler = None
def create_implicit_function(self) -> None: def create_implicit_function(self) -> None:
""" """
...@@ -652,10 +653,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -652,10 +653,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
) )
if implicit_function_type.requires_pooling_without_aggregation(): if implicit_function_type.requires_pooling_without_aggregation():
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions") if self.view_pooler_enabled and self.view_pooler.has_aggregation():
if not self.view_pool or has_aggregation:
raise ValueError( raise ValueError(
"Chosen implicit function requires view pooling without aggregation." "The chosen implicit function requires view pooling without aggregation."
) )
config_name = f"implicit_function_{self.implicit_function_class_type}_args" config_name = f"implicit_function_{self.implicit_function_class_type}_args"
config = getattr(self, config_name, None) config = getattr(self, config_name, None)
......
...@@ -141,11 +141,12 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module): ...@@ -141,11 +141,12 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor: def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
return (img - self._resnet_mean) / self._resnet_std return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self, size_dict: bool = False): def get_feat_dims(self) -> int:
if size_dict: return (
return copy.deepcopy(self._feat_dim) sum(self._feat_dim.values()) # pyre-fixme[29]
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.values)[[Na... if len(self._feat_dim) > 0 # pyre-fixme[6]
return sum(self._feat_dim.values()) else 0
)
def forward( def forward(
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
......
...@@ -10,7 +10,7 @@ from typing import Dict, Optional, Sequence, Union ...@@ -10,7 +10,7 @@ from typing import Dict, Optional, Sequence, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.implicitron.models.view_pooling.view_sampling import ( from pytorch3d.implicitron.models.view_pooler.view_sampler import (
cameras_points_cartesian_product, cameras_points_cartesian_product,
) )
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
...@@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase): ...@@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase):
""" """
raise NotImplementedError() raise NotImplementedError()
@abstractmethod
def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
"""
Returns the final dimensionality of the output aggregated features.
Args:
feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}` corresponding
to the `feats_sampled` argument of `forward`,
or an `int` representing the sum of dimensionalities of each `t_i`.
Returns:
aggregated_feature_dim: The final dimensionality of the output
aggregated features.
"""
raise NotImplementedError()
def has_aggregation(self) -> bool:
"""
Specifies whether the aggregator reduces the output `reduce_dim` dimension to 1.
Returns:
has_aggregation: `True` if `reduce_dim==1`, else `False`.
"""
return hasattr(self, "reduction_functions")
@registry.register @registry.register
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
...@@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): ...@@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
def __post_init__(self): def __post_init__(self):
super().__init__() super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): def get_aggregated_feature_dim(
return _get_reduction_aggregator_feature_dim(feats, []) self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
def forward( def forward(
self, self,
...@@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): ...@@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
def __post_init__(self): def __post_init__(self):
super().__init__() super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): def get_aggregated_feature_dim(
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions) self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(
feats_or_feats_dim, self.reduction_functions
)
def forward( def forward(
self, self,
...@@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator ...@@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
def __post_init__(self): def __post_init__(self):
super().__init__() super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): def get_aggregated_feature_dim(
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions) self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(
feats_or_feats_dim, self.reduction_functions
)
def forward( def forward(
self, self,
...@@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB ...@@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
def __post_init__(self): def __post_init__(self):
super().__init__() super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): def get_aggregated_feature_dim(
return _get_reduction_aggregator_feature_dim(feats, []) self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
):
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment