Commit 9540c290 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot

Make Module.__init__ automatic

Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else.
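
For illustration, a minimal sketch of the resulting usage (the class `MyDecoder` is hypothetical and not part of this diff; it assumes the class has been processed by `expand_args_fields` from pytorch3d.implicitron.tools.config):

import torch
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields

class MyDecoder(Configurable, torch.nn.Module):
    # Configurable fields become dataclass fields.
    hidden_dim: int = 64

    def __post_init__(self) -> None:
        # No explicit super().__init__() (i.e. torch.nn.Module.__init__) is
        # needed any more: the generated __init__ has already called it,
        # so registering submodules here just works.
        self.linear = torch.nn.Linear(self.hidden_dim, self.hidden_dim)

expand_args_fields(MyDecoder)
decoder = MyDecoder()  # hidden_dim defaults to 64; decoder.linear is registered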

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
parent 97f8f9bf
@@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
     random_sampling: bool
     add_input_samples: bool = True
 
-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         input_ray_bundle: ImplicitronRayBundle,
......
@@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
     Base class for ray samplers.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         cameras: CamerasBase,
@@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
     stratified_point_sampling_evaluation: bool = False
 
     def __post_init__(self):
-        super().__init__()
-
         if (self.n_rays_per_image_sampled_from_mask is not None) and (
             self.n_rays_total_training is not None
         ):
......
@@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
     n_steps: int = 100
     n_secant_steps: int = 8
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         sdf: Callable[[torch.Tensor], torch.Tensor],
......
@@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
     and marching along them in order to generate a feature render.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         rays_densities: torch.Tensor,
@@ -98,8 +95,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
             surface_thickness: Denotes the overlap between the absorption
                 function and the density function.
         """
-        super().__init__()
-
         bg_color = torch.tensor(self.bg_color)
         if bg_color.ndim != 1:
             raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
......
@@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ign
     def __post_init__(
         self,
     ):
-        super().__init__()
         render_features_dimensions = self.render_features_dimensions
         if len(self.bg_color) not in [1, render_features_dimensions]:
             raise ValueError(
......
@@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
     the outputs.
     """
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
         ReductionFunction.STD,
     )
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
......
@@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
     feature_aggregator: FeatureAggregatorBase
 
    def __post_init__(self):
-        super().__init__()
        run_auto_creation(self)
 
    def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
......
@@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
     masked_sampling: bool = False
     sampling_mode: str = "bilinear"
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         *,  # force kw args
......
@@ -184,6 +184,7 @@ ENABLED_SUFFIX: str = "_enabled"
 CREATE_PREFIX: str = "create_"
 IMPL_SUFFIX: str = "_impl"
 TWEAK_SUFFIX: str = "_tweak_args"
+_DATACLASS_INIT: str = "__dataclass_own_init__"
 
 
 class ReplaceableBase:
@@ -834,6 +835,9 @@ def expand_args_fields(
     then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
     the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
 
+    In addition, if the class inherits torch.nn.Module, the generated __init__ will
+    call torch.nn.Module's __init__ before doing anything else.
+
     Note that although the *_args members are intended to have type DictConfig, they
     are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
     in place of a dict, but not vice-versa. Allowing dict lets a class user specify
@@ -912,9 +916,40 @@ def expand_args_fields(
     some_class._known_implementations = known_implementations
 
     dataclasses.dataclass(eq=False)(some_class)
+    _fixup_class_init(some_class)
     return some_class
 
 
+def _fixup_class_init(some_class) -> None:
+    """
+    In-place modification of the some_class class which happens
+    after dataclass processing.
+
+    If the dataclass some_class inherits torch.nn.Module, then
+    makes torch.nn.Module's __init__ be called before anything else
+    on instantiation of some_class.
+
+    This is a bit like attr's __pre_init__.
+    """
+    assert _is_actually_dataclass(some_class)
+
+    try:
+        import torch
+    except ModuleNotFoundError:
+        return
+    if not issubclass(some_class, torch.nn.Module):
+        return
+
+    def init(self, *args, **kwargs) -> None:
+        torch.nn.Module.__init__(self)
+        getattr(self, _DATACLASS_INIT)(*args, **kwargs)
+
+    assert not hasattr(some_class, _DATACLASS_INIT)
+    setattr(some_class, _DATACLASS_INIT, some_class.__init__)
+    some_class.__init__ = init
+
+
 def get_default_args_field(
     C,
     *,
......
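
As a self-contained illustration of the mechanism (names such as `_wrap_module_init` and `TinyNet` are hypothetical and not part of this diff), the same pre-init wrapping can be sketched on a plain dataclass:

import dataclasses
import torch

_DATACLASS_INIT = "__dataclass_own_init__"

def _wrap_module_init(cls) -> None:
    # Stash the dataclass-generated __init__ under a separate name and
    # install a wrapper that runs torch.nn.Module.__init__ first,
    # mirroring what _fixup_class_init does above.
    def init(self, *args, **kwargs) -> None:
        torch.nn.Module.__init__(self)
        getattr(self, _DATACLASS_INIT)(*args, **kwargs)

    setattr(cls, _DATACLASS_INIT, cls.__init__)
    cls.__init__ = init

@dataclasses.dataclass(eq=False)
class TinyNet(torch.nn.Module):
    width: int = 8

    def __post_init__(self) -> None:
        # Safe: nn.Module.__init__ has already run, so assigning a
        # submodule here registers it correctly.
        self.layer = torch.nn.Linear(self.width, self.width)

_wrap_module_init(TinyNet)
net = TinyNet(width=4)
assert isinstance(net.layer, torch.nn.Linear)

The eq=False mirrors the dataclasses.dataclass(eq=False) call in the config code: a dataclass-generated __eq__ would otherwise set __hash__ to None, and nn.Module instances need to stay hashable.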