Commit 5046b86b authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 371389966
parent 6b695ca6
......@@ -41,6 +41,26 @@ _PARAM_RE = re.compile(
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# Yaml loader with an implicit resolver to parse float decimal and exponential
# format. The regular experission parse the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$''', re.X),
list('-+0123456789.'))
class ParamsDict(object):
"""A hyperparameter container class."""
......@@ -309,10 +329,10 @@ class ParamsDict(object):
raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path):
def read_yaml_to_params_dict(file_path: str):
"""Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f, Loader=yaml.SafeLoader)
params_dict = yaml.load(f, Loader=LOADER)
return ParamsDict(params_dict)
......@@ -433,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError:
pass
params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=yaml.FullLoader)
params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
if isinstance(params_dict, dict):
params.override(params_dict, is_strict)
else:
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
params.override(yaml.load(f, Loader=yaml.FullLoader), is_strict)
params.override(yaml.load(f, Loader=LOADER), is_strict)
else:
raise ValueError('Unknown input type to parse.')
return params
......@@ -321,6 +321,14 @@ class ParamsDictIOTest(tf.test.TestCase):
self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi, world', params.d.d1.d2)
self.assertEqual('gs://test', params.e)
# Test different float formats
override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
self.assertEqual(-1e-3, params.b.b2)
self.assertEqual(0.001, params.d.d1.d2)
self.assertEqual(1e3, params.e)
self.assertEqual(-1.5e-3, params.a)
def test_override_params_dict_using_yaml_file(self):
params = params_dict.ParamsDict({
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment