Commit e9e6d17c authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support constant argument in restriction checking in params_dict.

PiperOrigin-RevId: 310431197
parent b39ce6bb
......@@ -42,6 +42,10 @@ _PARAM_RE = re.compile(r"""
\[[^\]]*\])) # list of values
($|,\s*)""", re.VERBOSE)
# pylint: disable=anomalous-backslash-in-string
_CONST_VALUE_RE = re.compile('(\d.*|-\d.*|None)')
# pylint: enable=anomalous-backslash-in-string
class ParamsDict(object):
"""A hyperparameter container class."""
......@@ -239,6 +243,15 @@ class ParamsDict(object):
ValueError: if the restriction defined in the string is not supported.
"""
def _get_kv(dotted_string, params_dict):
"""Get keys and values indicated by dotted_string."""
if _CONST_VALUE_RE.match(dotted_string) is not None:
const_str = dotted_string
if const_str == 'None':
constant = None
else:
constant = float(const_str)
return None, constant
else:
tokenized_params = dotted_string.split('.')
v = params_dict
for t in tokenized_params:
......
......@@ -155,6 +155,14 @@ class ParamsDictTest(tf.test.TestCase):
with self.assertRaises(KeyError):
params.validate()
# Valid restrictions with constant.
params = params_dict.ParamsDict(
{'a': None, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
params.validate()
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 4, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
class ParamsDictIOTest(tf.test.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment