Commit 24260130 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot

_allow_untyped for get_default_args

Summary:
When X is a plain callable, the DictConfig returned by get_default_args(X), along with its DictConfig and ListConfig members, holds references to a temporary dataclass and is therefore unpicklable. Avoid this in a few cases by adding an _allow_untyped option.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1144
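For context, a minimal sketch of the symptom and of the new escape hatch (the function f and variable names below are illustrative, and the snippet assumes a pytorch3d build that contains this change):

    import pickle

    from pytorch3d.implicitron.tools.config import get_default_args

    def f(a: int = 1, b: str = "3"):
        pass

    # Previously, get_default_args(f) returned a DictConfig backed by a dataclass
    # built on the fly via dataclasses.make_dataclass; pickle cannot re-import that
    # temporary class, so pickling the config failed.
    args = get_default_args(f, _allow_untyped=True)  # plain, untyped DictConfig
    restored = pickle.loads(pickle.dumps(args))      # round-trips cleanly now
    assert restored.a == 1 and restored.b == "3"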

Reviewed By: shapovalov

Differential Revision: D35258561

fbshipit-source-id: e52186825f52accee9a899e466967a4ff71b3d25
parent a54ad2b9
@@ -67,8 +67,8 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
 from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
 from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
-    ImplicitronDataset,
     FrameData,
+    ImplicitronDataset,
 )
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
@@ -80,6 +80,7 @@ from pytorch3d.implicitron.tools.config import (
 from pytorch3d.implicitron.tools.stats import Stats
 from pytorch3d.renderer.cameras import CamerasBase
 logger = logging.getLogger(__name__)
 if version.parse(hydra.__version__) < version.Version("1.1"):
@@ -662,7 +663,9 @@ def _seed_all_random_engines(seed: int):
 @dataclass(eq=False)
 class ExperimentConfig:
     generic_model_args: DictConfig = get_default_args_field(GenericModel)
-    solver_args: DictConfig = get_default_args_field(init_optimizer)
+    solver_args: DictConfig = get_default_args_field(
+        init_optimizer, _allow_untyped=True
+    )
     dataset_args: DictConfig = get_default_args_field(dataset_zoo)
     dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
     architecture: str = "generic"
@@ -57,6 +57,7 @@ def dataloader_zoo(
             `"dataset_subset_name": torch_dataloader_object` key, value pairs.
     """
+    images_per_seq_options = tuple(images_per_seq_options)
     if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
@@ -100,6 +100,8 @@ def dataset_zoo(
         datasets: A dictionary containing the
             `"dataset_subset_name": torch_dataset_object` key, value pairs.
     """
+    restrict_sequence_name = tuple(restrict_sequence_name)
+    aux_dataset_kwargs = dict(aux_dataset_kwargs)
     datasets = {}
@@ -20,9 +20,11 @@ from .rgb_net import RayNormalColoringNetwork
 @registry.register
 class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
     render_features_dimensions: int = 3
-    ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
+    ray_tracer_args: DictConfig = get_default_args_field(
+        RayTracing, _allow_untyped=True
+    )
     ray_normal_coloring_network_args: DictConfig = get_default_args_field(
-        RayNormalColoringNetwork
+        RayNormalColoringNetwork, _allow_untyped=True
     )
     bg_color: Tuple[float, ...] = (0.0,)
     soft_mask_alpha: float = 50.0
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 import dataclasses
 import inspect
 import itertools
@@ -412,7 +413,9 @@ def _is_configurable_class(C) -> bool:
     return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
-def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
+def get_default_args(
+    C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
+) -> DictConfig:
     """
     Get the DictConfig of args to call C - which might be a type or a function.
@@ -423,6 +426,14 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
     Args:
         C: the class or function to be processed
+        _allow_untyped: (internal use) If True, do not try to make the
+            output typed when it is not a Configurable or
+            ReplaceableBase. This avoids problems (due to local
+            dataclasses being remembered inside the returned
+            DictConfig and any of its DictConfig and ListConfig
+            members) when pickling the output, but will break
+            conversions of yaml strings to/from any enum members
+            of C.
         _do_not_process: (internal use) When this function is called from
             expand_args_fields, we specify any class currently being
             processed, to make sure we don't try to process a class
@@ -462,6 +473,7 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
     # regular class or function
     field_annotations = []
+    kwargs = {}
     for pname, defval in _params_iter(C):
         default = defval.default
         if default == inspect.Parameter.empty:
@@ -476,6 +488,8 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
         _, annotation = _resolve_optional(defval.annotation)
+        kwargs[pname] = copy.deepcopy(default)
+
         if isinstance(default, set):  # force OmegaConf to convert it to ListConfig
             default = tuple(default)
@@ -489,6 +503,9 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
             field_ = dataclasses.field(default=default)
         field_annotations.append((pname, defval.annotation, field_))
+    if _allow_untyped:
+        return DictConfig(kwargs)
+
     # make a temp dataclass and generate a structured config from it.
     return OmegaConf.structured(
         dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
@@ -696,7 +713,9 @@ def expand_args_fields(
     return some_class
-def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
+def get_default_args_field(
+    C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
+):
     """
     Get a dataclass field which defaults to get_default_args(...)
@@ -708,7 +727,9 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
     """
     def create():
-        return get_default_args(C, _do_not_process=_do_not_process)
+        return get_default_args(
+            C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
+        )
     return dataclasses.field(default_factory=create)
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import pickle
 import textwrap
 import unittest
 from dataclasses import dataclass, field, is_dataclass
@@ -581,6 +582,15 @@ class TestConfig(unittest.TestCase):
         remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
         self.assertEqual(remerged.a, A.B1)
+    def test_pickle(self):
+        def f(a: int = 1, b: str = "3"):
+            pass
+
+        args = get_default_args(f, _allow_untyped=True)
+        args2 = pickle.loads(pickle.dumps(args))
+        self.assertEqual(args2.a, 1)
+        self.assertEqual(args2.b, "3")
+
     def test_remove_unused_components(self):
         struct = get_default_args(MainTest)
         struct.n_ids = 32