Commit 9540c290 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

Make Module.__init__ automatic

Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else.

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
parent 97f8f9bf
...@@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module): ...@@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
random_sampling: bool random_sampling: bool
add_input_samples: bool = True add_input_samples: bool = True
def __post_init__(self) -> None:
super().__init__()
def forward( def forward(
self, self,
input_ray_bundle: ImplicitronRayBundle, input_ray_bundle: ImplicitronRayBundle,
......
...@@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase): ...@@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
Base class for ray samplers. Base class for ray samplers.
""" """
def __init__(self):
super().__init__()
def forward( def forward(
self, self,
cameras: CamerasBase, cameras: CamerasBase,
...@@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): ...@@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
stratified_point_sampling_evaluation: bool = False stratified_point_sampling_evaluation: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__()
if (self.n_rays_per_image_sampled_from_mask is not None) and ( if (self.n_rays_per_image_sampled_from_mask is not None) and (
self.n_rays_total_training is not None self.n_rays_total_training is not None
): ):
......
...@@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module): ...@@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
n_steps: int = 100 n_steps: int = 100
n_secant_steps: int = 8 n_secant_steps: int = 8
def __post_init__(self):
super().__init__()
def forward( def forward(
self, self,
sdf: Callable[[torch.Tensor], torch.Tensor], sdf: Callable[[torch.Tensor], torch.Tensor],
......
...@@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase): ...@@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
and marching along them in order to generate a feature render. and marching along them in order to generate a feature render.
""" """
def __init__(self):
super().__init__()
def forward( def forward(
self, self,
rays_densities: torch.Tensor, rays_densities: torch.Tensor,
...@@ -98,8 +95,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module): ...@@ -98,8 +95,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
surface_thickness: Denotes the overlap between the absorption surface_thickness: Denotes the overlap between the absorption
function and the density function. function and the density function.
""" """
super().__init__()
bg_color = torch.tensor(self.bg_color) bg_color = torch.tensor(self.bg_color)
if bg_color.ndim != 1: if bg_color.ndim != 1:
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor") raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
......
...@@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign ...@@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
def __post_init__( def __post_init__(
self, self,
): ):
super().__init__()
render_features_dimensions = self.render_features_dimensions render_features_dimensions = self.render_features_dimensions
if len(self.bg_color) not in [1, render_features_dimensions]: if len(self.bg_color) not in [1, render_features_dimensions]:
raise ValueError( raise ValueError(
......
...@@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): ...@@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
the outputs. the outputs.
""" """
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim( def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
): ):
...@@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): ...@@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
ReductionFunction.STD, ReductionFunction.STD,
) )
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim( def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
): ):
...@@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator ...@@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
weight_by_ray_angle_gamma: float = 1.0 weight_by_ray_angle_gamma: float = 1.0
min_ray_angle_weight: float = 0.1 min_ray_angle_weight: float = 0.1
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim( def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
): ):
...@@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB ...@@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
weight_by_ray_angle_gamma: float = 1.0 weight_by_ray_angle_gamma: float = 1.0
min_ray_angle_weight: float = 0.1 min_ray_angle_weight: float = 0.1
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim( def get_aggregated_feature_dim(
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
): ):
......
...@@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module): ...@@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
feature_aggregator: FeatureAggregatorBase feature_aggregator: FeatureAggregatorBase
def __post_init__(self): def __post_init__(self):
super().__init__()
run_auto_creation(self) run_auto_creation(self)
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
......
...@@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module): ...@@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
masked_sampling: bool = False masked_sampling: bool = False
sampling_mode: str = "bilinear" sampling_mode: str = "bilinear"
def __post_init__(self):
super().__init__()
def forward( def forward(
self, self,
*, # force kw args *, # force kw args
......
...@@ -184,6 +184,7 @@ ENABLED_SUFFIX: str = "_enabled" ...@@ -184,6 +184,7 @@ ENABLED_SUFFIX: str = "_enabled"
CREATE_PREFIX: str = "create_" CREATE_PREFIX: str = "create_"
IMPL_SUFFIX: str = "_impl" IMPL_SUFFIX: str = "_impl"
TWEAK_SUFFIX: str = "_tweak_args" TWEAK_SUFFIX: str = "_tweak_args"
_DATACLASS_INIT: str = "__dataclass_own_init__"
class ReplaceableBase: class ReplaceableBase:
...@@ -834,6 +835,9 @@ def expand_args_fields( ...@@ -834,6 +835,9 @@ def expand_args_fields(
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args). the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
In addition, if the class inherits torch.nn.Module, the generated __init__ will
call torch.nn.Module's __init__ before doing anything else.
Note that although the *_args members are intended to have type DictConfig, they Note that although the *_args members are intended to have type DictConfig, they
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
in place of a dict, but not vice-versa. Allowing dict lets a class user specify in place of a dict, but not vice-versa. Allowing dict lets a class user specify
...@@ -912,9 +916,40 @@ def expand_args_fields( ...@@ -912,9 +916,40 @@ def expand_args_fields(
some_class._known_implementations = known_implementations some_class._known_implementations = known_implementations
dataclasses.dataclass(eq=False)(some_class) dataclasses.dataclass(eq=False)(some_class)
_fixup_class_init(some_class)
return some_class return some_class
def _fixup_class_init(some_class) -> None:
"""
In-place modification of the some_class class which happens
after dataclass processing.
If the dataclass some_class inherits torch.nn.Module, then
makes torch.nn.Module's __init__ be called before anything else
on instantiation of some_class.
This is a bit like attr's __pre_init__.
"""
assert _is_actually_dataclass(some_class)
try:
import torch
except ModuleNotFoundError:
return
if not issubclass(some_class, torch.nn.Module):
return
def init(self, *args, **kwargs) -> None:
torch.nn.Module.__init__(self)
getattr(self, _DATACLASS_INIT)(*args, **kwargs)
assert not hasattr(some_class, _DATACLASS_INIT)
setattr(some_class, _DATACLASS_INIT, some_class.__init__)
some_class.__init__ = init
def get_default_args_field( def get_default_args_field(
C, C,
*, *,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment