Commit 9540c290 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

Make Module.__init__ automatic

Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else.

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
parent 97f8f9bf
...@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry ...@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
class XRayRenderer(BaseRenderer, torch.nn.Module): class XRayRenderer(BaseRenderer, torch.nn.Module):
n_pts_per_ray: int = 64 n_pts_per_ray: int = 64
# if there are other base classes, make sure to call `super().__init__()` explicitly
def __post_init__(self): def __post_init__(self):
super().__init__()
# custom initialization # custom initialization
def forward( def forward(
......
...@@ -130,7 +130,7 @@ def evaluate_dbir_for_category( ...@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
raise ValueError("Image size should be set in the dataset") raise ValueError("Image size should be set in the dataset")
# init the simple DBIR model # init the simple DBIR model
model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden model = ModelDBIR(
render_image_width=image_size, render_image_width=image_size,
render_image_height=image_size, render_image_height=image_size,
bg_color=bg_color, bg_color=bg_color,
......
...@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module): ...@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
# the training loop. # the training loop.
log_vars: List[str] = field(default_factory=lambda: ["objective"]) log_vars: List[str] = field(default_factory=lambda: ["objective"])
def __init__(self) -> None:
super().__init__()
def forward( def forward(
self, self,
*, # force keyword-only arguments *, # force keyword-only arguments
......
...@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module): ...@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
Base class for an extractor of a set of features from images. Base class for an extractor of a set of features from images.
""" """
def __init__(self):
super().__init__()
def get_feat_dims(self) -> int: def get_feat_dims(self) -> int:
""" """
Returns: Returns:
......
...@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase): ...@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
feature_rescale: float = 1.0 feature_rescale: float = 1.0
def __post_init__(self): def __post_init__(self):
super().__init__()
if self.normalize_image: if self.normalize_image:
# register buffers needed to normalize the image # register buffers needed to normalize the image
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
......
...@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
) )
def __post_init__(self): def __post_init__(self):
super().__init__()
if self.view_pooler_enabled: if self.view_pooler_enabled:
if self.image_feature_extractor_class_type is None: if self.image_feature_extractor_class_type is None:
raise ValueError( raise ValueError(
......
...@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module): ...@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
ignore_input: bool = False ignore_input: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__()
if self.n_instances <= 0: if self.n_instances <= 0:
raise ValueError(f"Invalid n_instances {self.n_instances}") raise ValueError(f"Invalid n_instances {self.n_instances}")
......
...@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase): ...@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
(`SequenceAutodecoder`). (`SequenceAutodecoder`).
""" """
def __init__(self) -> None:
super().__init__()
def get_encoding_dim(self): def get_encoding_dim(self):
""" """
Returns the dimensionality of the returned encoding. Returns the dimensionality of the returned encoding.
...@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1 ...@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
autodecoder: Autodecoder autodecoder: Autodecoder
def __post_init__(self): def __post_init__(self):
super().__init__()
run_auto_creation(self) run_auto_creation(self)
def get_encoding_dim(self): def get_encoding_dim(self):
...@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module): ...@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
time_divisor: float = 1.0 time_divisor: float = 1.0
def __post_init__(self): def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding( self._harmonic_embedding = HarmonicEmbedding(
n_harmonic_functions=self.n_harmonic_functions, n_harmonic_functions=self.n_harmonic_functions,
append_input=self.append_input, append_input=self.append_input,
......
...@@ -14,9 +14,6 @@ from pytorch3d.renderer.cameras import CamerasBase ...@@ -14,9 +14,6 @@ from pytorch3d.renderer.cameras import CamerasBase
class ImplicitFunctionBase(ABC, ReplaceableBase): class ImplicitFunctionBase(ABC, ReplaceableBase):
def __init__(self):
super().__init__()
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
......
...@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module): ...@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
space and transforms it into the required quantity (for example density and color). space and transforms it into the required quantity (for example density and color).
""" """
def __post_init__(self):
super().__init__()
def forward( def forward(
self, features: torch.Tensor, z: Optional[torch.Tensor] = None self, features: torch.Tensor, z: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase): ...@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
operation: DecoderActivation = DecoderActivation.IDENTITY operation: DecoderActivation = DecoderActivation.IDENTITY
def __post_init__(self): def __post_init__(self):
super().__post_init__()
if self.operation not in [ if self.operation not in [
DecoderActivation.RELU, DecoderActivation.RELU,
DecoderActivation.SOFTPLUS, DecoderActivation.SOFTPLUS,
...@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module): ...@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
use_xavier_init: bool = True use_xavier_init: bool = True
def __post_init__(self): def __post_init__(self):
super().__init__()
try: try:
last_activation = { last_activation = {
DecoderActivation.RELU: torch.nn.ReLU(True), DecoderActivation.RELU: torch.nn.ReLU(True),
...@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase): ...@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
network: MLPWithInputSkips network: MLPWithInputSkips
def __post_init__(self): def __post_init__(self):
super().__post_init__()
run_auto_creation(self) run_auto_creation(self)
def forward( def forward(
......
...@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): ...@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
encoding_dim: int = 0 encoding_dim: int = 0
def __post_init__(self): def __post_init__(self):
super().__init__()
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size] dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
self.embed_fn = None self.embed_fn = None
......
...@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): ...@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
""" """
def __post_init__(self): def __post_init__(self):
super().__init__()
# The harmonic embedding layer converts input 3D coordinates # The harmonic embedding layer converts input 3D coordinates
# to a representation that is more suitable for # to a representation that is more suitable for
# processing with a deep neural network. # processing with a deep neural network.
......
...@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module): ...@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
raymarch_function: Any = None raymarch_function: Any = None
def __post_init__(self): def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding( self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True self.n_harmonic_functions, append_input=True
) )
...@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module): ...@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
ray_dir_in_camera_coords: bool = False ray_dir_in_camera_coords: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding( self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True self.n_harmonic_functions, append_input=True
) )
...@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module): ...@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
xyz_in_camera_coords: bool = False xyz_in_camera_coords: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__()
raymarch_input_embedding_dim = ( raymarch_input_embedding_dim = (
HarmonicEmbedding.get_output_dim_static( HarmonicEmbedding.get_output_dim_static(
self.in_features, self.in_features,
...@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module): ...@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
pixel_generator: SRNPixelGenerator pixel_generator: SRNPixelGenerator
def __post_init__(self): def __post_init__(self):
super().__init__()
run_auto_creation(self) run_auto_creation(self)
def create_raymarch_function(self) -> None: def create_raymarch_function(self) -> None:
...@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module): ...@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
pixel_generator: SRNPixelGenerator pixel_generator: SRNPixelGenerator
def __post_init__(self): def __post_init__(self):
super().__init__()
run_auto_creation(self) run_auto_creation(self)
def create_hypernet(self) -> None: def create_hypernet(self) -> None:
......
...@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): ...@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
) )
def __post_init__(self): def __post_init__(self):
super().__init__()
if 0 not in self.resolution_changes: if 0 not in self.resolution_changes:
raise ValueError("There has to be key `0` in `resolution_changes`.") raise ValueError("There has to be key `0` in `resolution_changes`.")
...@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module): ...@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
param_groups: Dict[str, str] = field(default_factory=lambda: {}) param_groups: Dict[str, str] = field(default_factory=lambda: {})
def __post_init__(self): def __post_init__(self):
super().__init__()
run_auto_creation(self) run_auto_creation(self)
n_grids = 1 # Voxel grid objects are batched. We need only a single grid. n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
shapes = self.voxel_grid.get_shapes(epoch=0) shapes = self.voxel_grid.get_shapes(epoch=0)
......
...@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): ...@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
volume_cropping_epochs: Tuple[int, ...] = () volume_cropping_epochs: Tuple[int, ...] = ()
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__init__()
run_auto_creation(self) run_auto_creation(self)
# pyre-ignore[16] # pyre-ignore[16]
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold() self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
......
...@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module): ...@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
depend on the model's parameters. depend on the model's parameters.
""" """
def __post_init__(self) -> None:
super().__init__()
def forward( def forward(
self, model: Any, keys_prefix: str = "loss_", **kwargs self, model: Any, keys_prefix: str = "loss_", **kwargs
) -> Dict[str, Any]: ) -> Dict[str, Any]:
...@@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module): ...@@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
`forward()` method produces losses and other metrics. `forward()` method produces losses and other metrics.
""" """
def __post_init__(self) -> None:
super().__init__()
def forward( def forward(
self, self,
raymarched: RendererOutput, raymarched: RendererOutput,
......
...@@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase): ...@@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0) bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
max_points: int = -1 max_points: int = -1
def __post_init__(self):
super().__init__()
def forward( def forward(
self, self,
*, # force keyword-only arguments *, # force keyword-only arguments
......
...@@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase): ...@@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
Base class for all Renderer implementations. Base class for all Renderer implementations.
""" """
def __init__(self) -> None:
super().__init__()
def requires_object_mask(self) -> bool: def requires_object_mask(self) -> bool:
""" """
Whether `forward` needs the object_mask. Whether `forward` needs the object_mask.
......
...@@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): ...@@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
verbose: bool = False verbose: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__()
self._lstm = torch.nn.LSTMCell( self._lstm = torch.nn.LSTMCell(
input_size=self.n_feature_channels, input_size=self.n_feature_channels,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
......
...@@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13 ...@@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
return_weights: bool = False return_weights: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__()
self._refiners = { self._refiners = {
EvaluationMode.TRAINING: RayPointRefiner( EvaluationMode.TRAINING: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_training, n_pts_per_ray=self.n_pts_per_ray_fine_training,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment