Commit 3b8a33e9 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

store original declared types in Configurable

Summary: Aid reflection by adding the original declared types of replaced members of a configurable as values in _processed_members.

Reviewed By: davnov134

Differential Revision: D35358422

fbshipit-source-id: 80ef3266144c51c1c2105f349e0dd3464e230429
parent 199309fc
...@@ -11,7 +11,7 @@ import itertools ...@@ -11,7 +11,7 @@ import itertools
import warnings import warnings
from collections import Counter, defaultdict from collections import Counter, defaultdict
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch3d.common.datatypes import get_args, get_origin from pytorch3d.common.datatypes import get_args, get_origin
...@@ -635,7 +635,9 @@ def expand_args_fields( ...@@ -635,7 +635,9 @@ def expand_args_fields(
- _known_implementations: Dict[str, Type] containing the classes which - _known_implementations: Dict[str, Type] containing the classes which
have been found from the registry. have been found from the registry.
(used only to raise a warning if it one has been overwritten) (used only to raise a warning if it one has been overwritten)
- _processed_members: a Set[str] of all the members which have been transformed. - _processed_members: a Dict[str, Any] of all the members which have been
transformed, with values giving the types they were declared to have.
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
Args: Args:
some_class: the class to be processed some_class: the class to be processed
...@@ -660,7 +662,7 @@ def expand_args_fields( ...@@ -660,7 +662,7 @@ def expand_args_fields(
# unused. # unused.
known_implementations: Dict[str, Type] = {} known_implementations: Dict[str, Type] = {}
# Names of members which have been processed. # Names of members which have been processed.
processed_members: Set[str] = set() processed_members: Dict[str, Any] = {}
# For all bases except ReplaceableBase and Configurable and object, # For all bases except ReplaceableBase and Configurable and object,
# we need to process them before our own processing. This is # we need to process them before our own processing. This is
...@@ -691,6 +693,7 @@ def expand_args_fields( ...@@ -691,6 +693,7 @@ def expand_args_fields(
to_process.append((name, underlying_type, process_type)) to_process.append((name, underlying_type, process_type))
for name, underlying_type, process_type in to_process: for name, underlying_type, process_type in to_process:
processed_members[name] = some_class.__annotations__[name]
_process_member( _process_member(
name=name, name=name,
type_=underlying_type, type_=underlying_type,
...@@ -700,7 +703,6 @@ def expand_args_fields( ...@@ -700,7 +703,6 @@ def expand_args_fields(
_do_not_process=_do_not_process, _do_not_process=_do_not_process,
known_implementations=known_implementations, known_implementations=known_implementations,
) )
processed_members.add(name)
for key, count in Counter(creation_functions).items(): for key, count in Counter(creation_functions).items():
if count > 1: if count > 1:
......
...@@ -255,6 +255,8 @@ class TestConfig(unittest.TestCase): ...@@ -255,6 +255,8 @@ class TestConfig(unittest.TestCase):
container_args = get_default_args(Container) container_args = get_default_args(Container)
container = Container(**container_args) container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange) self.assertIsInstance(container.fruit, Orange)
self.assertEqual(Container._processed_members, {"fruit": Fruit})
self.assertEqual(container._processed_members, {"fruit": Fruit})
container_defaulted = Container() container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4 container_defaulted.fruit_Pear_args.n_pips += 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment