Commit e9e6d17c authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support constant argument in restriction checking in params_dict.

PiperOrigin-RevId: 310431197
parent b39ce6bb
...@@ -42,6 +42,10 @@ _PARAM_RE = re.compile(r""" ...@@ -42,6 +42,10 @@ _PARAM_RE = re.compile(r"""
\[[^\]]*\])) # list of values \[[^\]]*\])) # list of values
($|,\s*)""", re.VERBOSE) ($|,\s*)""", re.VERBOSE)
# pylint: disable=anomalous-backslash-in-string
_CONST_VALUE_RE = re.compile('(\d.*|-\d.*|None)')
# pylint: enable=anomalous-backslash-in-string
class ParamsDict(object): class ParamsDict(object):
"""A hyperparameter container class.""" """A hyperparameter container class."""
...@@ -239,11 +243,20 @@ class ParamsDict(object): ...@@ -239,11 +243,20 @@ class ParamsDict(object):
ValueError: if the restriction defined in the string is not supported. ValueError: if the restriction defined in the string is not supported.
""" """
def _get_kv(dotted_string, params_dict): def _get_kv(dotted_string, params_dict):
tokenized_params = dotted_string.split('.') """Get keys and values indicated by dotted_string."""
v = params_dict if _CONST_VALUE_RE.match(dotted_string) is not None:
for t in tokenized_params: const_str = dotted_string
v = v[t] if const_str == 'None':
return tokenized_params[-1], v constant = None
else:
constant = float(const_str)
return None, constant
else:
tokenized_params = dotted_string.split('.')
v = params_dict
for t in tokenized_params:
v = v[t]
return tokenized_params[-1], v
def _get_kvs(tokens, params_dict): def _get_kvs(tokens, params_dict):
if len(tokens) != 2: if len(tokens) != 2:
......
...@@ -155,6 +155,14 @@ class ParamsDictTest(tf.test.TestCase): ...@@ -155,6 +155,14 @@ class ParamsDictTest(tf.test.TestCase):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
params.validate() params.validate()
# Valid restrictions with constant.
params = params_dict.ParamsDict(
{'a': None, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
params.validate()
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 4, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
class ParamsDictIOTest(tf.test.TestCase): class ParamsDictIOTest(tf.test.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment