Commit 9540c290 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot

Make Module.__init__ automatic

Summary: If a configurable class inherits torch.nn.Module, automatically call `torch.nn.Module.__init__` on its instances before doing anything else (in particular, before `__post_init__` runs), so such classes no longer need explicit `__init__`/`super().__init__()` stubs.

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
parent 97f8f9bf
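
This commit removes the explicit `__init__`/`super().__init__()` stubs from configurable modules across Implicitron, since the config machinery now calls `torch.nn.Module.__init__` automatically. As a rough usage sketch (the `MyEncoder` class below is hypothetical and not part of this commit; it assumes the existing `Configurable` and `expand_args_fields` entry points in `pytorch3d.implicitron.tools.config`), a module can now register submodules directly in `__post_init__`:

```python
import torch

from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields


class MyEncoder(Configurable, torch.nn.Module):
    # dataclass-style field managed by the Implicitron config system
    hidden_dim: int = 32

    def __post_init__(self):
        # torch.nn.Module.__init__ has already run automatically, so no
        # explicit __init__ / super().__init__() stub is needed before
        # registering parameters or submodules here.
        self.linear = torch.nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


expand_args_fields(MyEncoder)  # process the configurable class before instantiation
model = MyEncoder(hidden_dim=64)
out = model(torch.randn(2, 64))  # -> shape (2, 64)
```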
@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
class XRayRenderer(BaseRenderer, torch.nn.Module):
n_pts_per_ray: int = 64
# if there are other base classes, make sure to call `super().__init__()` explicitly
def __post_init__(self):
super().__init__()
# custom initialization
def forward(
@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
raise ValueError("Image size should be set in the dataset")
# init the simple DBIR model
model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden
model = ModelDBIR(
render_image_width=image_size,
render_image_height=image_size,
bg_color=bg_color,
@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
# the training loop.
log_vars: List[str] = field(default_factory=lambda: ["objective"])
def __init__(self) -> None:
super().__init__()
def forward(
self,
*, # force keyword-only arguments
@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
Base class for an extractor of a set of features from images.
"""
def __init__(self):
super().__init__()
def get_feat_dims(self) -> int:
"""
Returns:
@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
feature_rescale: float = 1.0
def __post_init__(self):
super().__init__()
if self.normalize_image:
# register buffers needed to normalize the image
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
)
def __post_init__(self):
super().__init__()
if self.view_pooler_enabled:
if self.image_feature_extractor_class_type is None:
raise ValueError(
@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
ignore_input: bool = False
def __post_init__(self):
super().__init__()
if self.n_instances <= 0:
raise ValueError(f"Invalid n_instances {self.n_instances}")
@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
(`SequenceAutodecoder`).
"""
def __init__(self) -> None:
super().__init__()
def get_encoding_dim(self):
"""
Returns the dimensionality of the returned encoding.
@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
autodecoder: Autodecoder
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def get_encoding_dim(self):
@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
time_divisor: float = 1.0
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
n_harmonic_functions=self.n_harmonic_functions,
append_input=self.append_input,
@@ -14,9 +14,6 @@ from pytorch3d.renderer.cameras import CamerasBase
class ImplicitFunctionBase(ABC, ReplaceableBase):
def __init__(self):
super().__init__()
@abstractmethod
def forward(
self,
@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
space and transforms it into the required quantity (for example density and color).
"""
def __post_init__(self):
super().__init__()
def forward(
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
) -> torch.Tensor:
@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
operation: DecoderActivation = DecoderActivation.IDENTITY
def __post_init__(self):
super().__post_init__()
if self.operation not in [
DecoderActivation.RELU,
DecoderActivation.SOFTPLUS,
@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
use_xavier_init: bool = True
def __post_init__(self):
super().__init__()
try:
last_activation = {
DecoderActivation.RELU: torch.nn.ReLU(True),
@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
network: MLPWithInputSkips
def __post_init__(self):
super().__post_init__()
run_auto_creation(self)
def forward(
@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
encoding_dim: int = 0
def __post_init__(self):
super().__init__()
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
self.embed_fn = None
@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
"""
def __post_init__(self):
super().__init__()
# The harmonic embedding layer converts input 3D coordinates
# to a representation that is more suitable for
# processing with a deep neural network.
@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
raymarch_function: Any = None
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True
)
@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
ray_dir_in_camera_coords: bool = False
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True
)
@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
xyz_in_camera_coords: bool = False
def __post_init__(self):
super().__init__()
raymarch_input_embedding_dim = (
HarmonicEmbedding.get_output_dim_static(
self.in_features,
@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
pixel_generator: SRNPixelGenerator
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def create_raymarch_function(self) -> None:
@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
pixel_generator: SRNPixelGenerator
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def create_hypernet(self) -> None:
@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
)
def __post_init__(self):
super().__init__()
if 0 not in self.resolution_changes:
raise ValueError("There has to be key `0` in `resolution_changes`.")
@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
param_groups: Dict[str, str] = field(default_factory=lambda: {})
def __post_init__(self):
super().__init__()
run_auto_creation(self)
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
shapes = self.voxel_grid.get_shapes(epoch=0)
@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
volume_cropping_epochs: Tuple[int, ...] = ()
def __post_init__(self) -> None:
super().__init__()
run_auto_creation(self)
# pyre-ignore[16]
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
depend on the model's parameters.
"""
def __post_init__(self) -> None:
super().__init__()
def forward(
self, model: Any, keys_prefix: str = "loss_", **kwargs
) -> Dict[str, Any]:
@@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
`forward()` method produces losses and other metrics.
"""
def __post_init__(self) -> None:
super().__init__()
def forward(
self,
raymarched: RendererOutput,
@@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
max_points: int = -1
def __post_init__(self):
super().__init__()
def forward(
self,
*, # force keyword-only arguments
@@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
Base class for all Renderer implementations.
"""
def __init__(self) -> None:
super().__init__()
def requires_object_mask(self) -> bool:
"""
Whether `forward` needs the object_mask.
@@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
verbose: bool = False
def __post_init__(self):
super().__init__()
self._lstm = torch.nn.LSTMCell(
input_size=self.n_feature_channels,
hidden_size=self.hidden_size,
@@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
return_weights: bool = False
def __post_init__(self):
super().__init__()
self._refiners = {
EvaluationMode.TRAINING: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_training,