Commit 90ab219d authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

clarify expand_args_fields

Summary: Fix doc and add a call to expand_args_fields for each implicit function.

Reviewed By: shapovalov

Differential Revision: D35929811

fbshipit-source-id: 8c3cfa56b8d8908fd2165614960e3d34b54717bb
parent 9e57b994
...@@ -15,6 +15,7 @@ import torch ...@@ -15,6 +15,7 @@ import torch
import tqdm import tqdm
from pytorch3d.implicitron.tools import image_utils, vis_utils from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
expand_args_fields,
registry, registry,
run_auto_creation, run_auto_creation,
) )
...@@ -677,6 +678,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ...@@ -677,6 +678,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
implicit_function_type = registry.get( implicit_function_type = registry.get(
ImplicitFunctionBase, self.implicit_function_class_type ImplicitFunctionBase, self.implicit_function_class_type
) )
expand_args_fields(implicit_function_type)
if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes(): if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes():
raise ValueError( raise ValueError(
self.implicit_function_class_type self.implicit_function_class_type
......
...@@ -606,12 +606,14 @@ def expand_args_fields( ...@@ -606,12 +606,14 @@ def expand_args_fields(
""" """
This expands a class which inherits Configurable or ReplaceableBase classes, This expands a class which inherits Configurable or ReplaceableBase classes,
including dataclass processing. some_class is modified in place by this function. including dataclass processing. some_class is modified in place by this function.
If expand_args_fields(some_class) has already been called, subsequent calls do
nothing and return some_class unmodified.
For classes of type ReplaceableBase, you can add some_class to the registry before For classes of type ReplaceableBase, you can add some_class to the registry before
or after calling this function. But potential inner classes need to be registered or after calling this function. But potential inner classes need to be registered
before this function is run on the outer class. before this function is run on the outer class.
The transformations this function makes, before the concluding The transformations this function makes, before the concluding
dataclasses.dataclass, are as follows. if X is a base class with registered dataclasses.dataclass, are as follows. If X is a base class with registered
subclasses Y and Z, replace a class member subclasses Y and Z, replace a class member
x: X x: X
...@@ -626,7 +628,9 @@ def expand_args_fields( ...@@ -626,7 +628,9 @@ def expand_args_fields(
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y)) x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self): def create_x(self):
self.x = registry.get(X, self.x_class_type)( x_type = registry.get(X, self.x_class_type)
expand_args_fields(x_type)
self.x = x_type(
**self.getattr(f"x_{self.x_class_type}_args) **self.getattr(f"x_{self.x_class_type}_args)
) )
x_class_type: str = "UNDEFAULTED" x_class_type: str = "UNDEFAULTED"
...@@ -651,7 +655,9 @@ def expand_args_fields( ...@@ -651,7 +655,9 @@ def expand_args_fields(
self.x = None self.x = None
return return
self.x = registry.get(X, self.x_class_type)( x_type = registry.get(X, self.x_class_type)
expand_args_fields(x_type)
self.x = x_type(
**self.getattr(f"x_{self.x_class_type}_args) **self.getattr(f"x_{self.x_class_type}_args)
) )
x_class_type: Optional[str] = "UNDEFAULTED" x_class_type: Optional[str] = "UNDEFAULTED"
...@@ -670,6 +676,7 @@ def expand_args_fields( ...@@ -670,6 +676,7 @@ def expand_args_fields(
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
def create_x(self): def create_x(self):
expand_args_fields(X)
self.x = X(self.x_args) self.x = X(self.x_args)
Similarly, replace, Similarly, replace,
...@@ -687,6 +694,7 @@ def expand_args_fields( ...@@ -687,6 +694,7 @@ def expand_args_fields(
x_enabled: bool = False x_enabled: bool = False
def create_x(self): def create_x(self):
if self.x_enabled: if self.x_enabled:
expand_args_fields(X)
self.x = X(self.x_args) self.x = X(self.x_args)
else: else:
self.x = None self.x = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment