Commit 9c1887a8 authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

ParamsDict update.

PiperOrigin-RevId: 307639416
parent 5e539a3d
......@@ -125,6 +125,25 @@ class ParamsDict(object):
"""Accesses through built-in dictionary get method."""
return self.__dict__.get(key, value)
def __delattr__(self, k):
"""Deletes the key and removes its values.
Args:
k: the key string.
Raises:
AttributeError: if k is reserverd or not defined in the ParamsDict.
ValueError: if the ParamsDict instance has been locked.
"""
if k in ParamsDict.RESERVED_ATTR:
raise AttributeError('The key `{}` is reserved. No change is allowes. '
.format(k))
if k not in self.__dict__.keys():
raise AttributeError('The key `{}` does not exist. '.format(k))
if self._locked:
raise ValueError('The ParamsDict has been locked. No change is allowed.')
del self.__dict__[k]
def override(self, override_params, is_strict=True):
"""Override the ParamsDict with a set of given params.
......@@ -286,7 +305,6 @@ def read_yaml_to_params_dict(file_path):
def save_params_dict_to_yaml(params, file_path):
"""Saves the input ParamsDict to a YAML file."""
with tf.io.gfile.GFile(file_path, 'w') as f:
def _my_list_rep(dumper, data):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return dumper.represent_sequence(
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests for official.modeling.hyperparams.params_dict.py."""
"""Tests for params_dict.py."""
import os
......@@ -45,12 +45,14 @@ class ParamsDictTest(tf.test.TestCase):
self.assertEqual(params.b, 2)
def test_lock(self):
params = params_dict.ParamsDict({'a': 1, 'b': 2})
params = params_dict.ParamsDict({'a': 1, 'b': 2, 'c': 3})
params.lock()
with self.assertRaises(ValueError):
params.a = 10
with self.assertRaises(ValueError):
params.override({'b': 20})
with self.assertRaises(ValueError):
del params.c
def test_setattr(self):
params = params_dict.ParamsDict()
......@@ -69,6 +71,20 @@ class ParamsDictTest(tf.test.TestCase):
self.assertEqual(params.b, 2)
self.assertEqual(params.c, None)
def test_delattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None, 'd': {'d1': 1, 'd2': 10}},
is_strict=False)
del params.c
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
with self.assertRaises(AttributeError):
_ = params.c
del params.d
with self.assertRaises(AttributeError):
_ = params.d.d1
def test_contains(self):
params = params_dict.ParamsDict()
params.override(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment