Commit 76b8a67a authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 340481326
parent fd67e121
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import copy import copy
import functools import functools
from typing import Any, List, Mapping, Optional, Type from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses import dataclasses
import tensorflow as tf import tensorflow as tf
...@@ -155,13 +156,30 @@ class Config(params_dict.ParamsDict): ...@@ -155,13 +156,30 @@ class Config(params_dict.ParamsDict):
RuntimeError RuntimeError
""" """
subconfig_type = self._get_subconfig_type(k) subconfig_type = self._get_subconfig_type(k)
if isinstance(v, dict):
def is_null(k):
if k not in self.__dict__ or not self.__dict__[k]: if k not in self.__dict__ or not self.__dict__[k]:
return True
return False
if isinstance(v, dict):
if is_null(k):
# If the key not exist or the value is None, a new Config-family object # If the key not exist or the value is None, a new Config-family object
# sould be created for the key. # sould be created for the key.
self.__dict__[k] = subconfig_type(v) self.__dict__[k] = subconfig_type(v)
else: else:
self.__dict__[k].override(v) self.__dict__[k].override(v)
elif not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and all(
[not isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
if len(self.__dict__[k]) == len(v):
for i in range(len(v)):
self.__dict__[k][i].override(v[i])
elif not all([isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
logging.warning(
"The list/tuple don't match the value dictionaries provided. Thus, "
'the list/tuple is determined by the type annotation and '
'values provided. This is error-prone.')
self.__dict__[k] = self._import_config(v, subconfig_type)
else: else:
self.__dict__[k] = self._import_config(v, subconfig_type) self.__dict__[k] = self._import_config(v, subconfig_type)
......
...@@ -44,6 +44,17 @@ class DumpConfig3(DumpConfig2): ...@@ -44,6 +44,17 @@ class DumpConfig3(DumpConfig2):
g: Tuple[DumpConfig1, ...] = (DumpConfig1(),) g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
@dataclasses.dataclass
class DumpConfig4(DumpConfig2):
x: int = 3
@dataclasses.dataclass
class DummyConfig5(base_config.Config):
y: Tuple[DumpConfig2, ...] = (DumpConfig2(), DumpConfig4())
z: Tuple[str] = ('a',)
class BaseConfigTest(parameterized.TestCase, tf.test.TestCase): class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
def assertHasSameTypes(self, c, d, msg=''): def assertHasSameTypes(self, c, d, msg=''):
...@@ -162,10 +173,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase): ...@@ -162,10 +173,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
params.override({'c': {'c3': 30}}, is_strict=True) params.override({'c': {'c3': 30}}, is_strict=True)
config = base_config.Config({'key': [{'a': 42}]}) config = base_config.Config({'key': [{'a': 42}]})
with self.assertRaisesRegex(KeyError, "The key 'b' does not exist"):
config.override({'key': [{'b': 43}]}) config.override({'key': [{'b': 43}]})
self.assertEqual(config.key[0].b, 43)
with self.assertRaisesRegex(AttributeError, 'The key `a` does not exist'):
_ = config.key[0].a
@parameterized.parameters( @parameterized.parameters(
(lambda x: x, 'Unknown type'), (lambda x: x, 'Unknown type'),
...@@ -314,6 +323,30 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase): ...@@ -314,6 +323,30 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
config = DumpConfig2(restrictions=restrictions) config = DumpConfig2(restrictions=restrictions)
config.validate() config.validate()
def test_nested_tuple(self):
config = DummyConfig5()
config.override({
'y': [{
'c': 4,
'd': 'new text 3',
'e': {
'a': 2
}
}, {
'c': 0,
'd': 'new text 3',
'e': {
'a': 2
}
}],
'z': ['a', 'b', 'c'],
})
self.assertEqual(config.y[0].c, 4)
self.assertEqual(config.y[1].c, 0)
self.assertIsInstance(config.y[0], DumpConfig2)
self.assertIsInstance(config.y[1], DumpConfig4)
self.assertSameElements(config.z, ['a', 'b', 'c'])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment