Commit 1b6e6dba authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 340481326
parent 065da0fc
......@@ -17,6 +17,7 @@
import copy
import functools
from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses
import tensorflow as tf
......@@ -155,13 +156,30 @@ class Config(params_dict.ParamsDict):
RuntimeError
"""
subconfig_type = self._get_subconfig_type(k)
if isinstance(v, dict):
def is_null(k):
if k not in self.__dict__ or not self.__dict__[k]:
return True
return False
if isinstance(v, dict):
if is_null(k):
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
self.__dict__[k] = subconfig_type(v)
else:
self.__dict__[k].override(v)
elif not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and all(
[not isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
if len(self.__dict__[k]) == len(v):
for i in range(len(v)):
self.__dict__[k][i].override(v[i])
elif not all([isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
logging.warning(
"The list/tuple don't match the value dictionaries provided. Thus, "
'the list/tuple is determined by the type annotation and '
'values provided. This is error-prone.')
self.__dict__[k] = self._import_config(v, subconfig_type)
else:
self.__dict__[k] = self._import_config(v, subconfig_type)
......
......@@ -44,6 +44,17 @@ class DumpConfig3(DumpConfig2):
g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
@dataclasses.dataclass
class DumpConfig4(DumpConfig2):
x: int = 3
@dataclasses.dataclass
class DummyConfig5(base_config.Config):
y: Tuple[DumpConfig2, ...] = (DumpConfig2(), DumpConfig4())
z: Tuple[str] = ('a',)
class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
def assertHasSameTypes(self, c, d, msg=''):
......@@ -162,10 +173,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
params.override({'c': {'c3': 30}}, is_strict=True)
config = base_config.Config({'key': [{'a': 42}]})
config.override({'key': [{'b': 43}]})
self.assertEqual(config.key[0].b, 43)
with self.assertRaisesRegex(AttributeError, 'The key `a` does not exist'):
_ = config.key[0].a
with self.assertRaisesRegex(KeyError, "The key 'b' does not exist"):
config.override({'key': [{'b': 43}]})
@parameterized.parameters(
(lambda x: x, 'Unknown type'),
......@@ -314,6 +323,30 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
config = DumpConfig2(restrictions=restrictions)
config.validate()
def test_nested_tuple(self):
config = DummyConfig5()
config.override({
'y': [{
'c': 4,
'd': 'new text 3',
'e': {
'a': 2
}
}, {
'c': 0,
'd': 'new text 3',
'e': {
'a': 2
}
}],
'z': ['a', 'b', 'c'],
})
self.assertEqual(config.y[0].c, 4)
self.assertEqual(config.y[1].c, 0)
self.assertIsInstance(config.y[0], DumpConfig2)
self.assertIsInstance(config.y[1], DumpConfig4)
self.assertSameElements(config.z, ['a', 'b', 'c'])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment