"tests/vscode:/vscode.git/clone" did not exist on "59d2f7ac2385f20105513cdc76010f996f731af0"
Commit 5046b86b authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 371389966
parent 6b695ca6
...@@ -41,6 +41,26 @@ _PARAM_RE = re.compile( ...@@ -41,6 +41,26 @@ _PARAM_RE = re.compile(
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)') _CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# Yaml loader with an implicit resolver to parse float decimal and exponential
# format. The regular experission parse the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$''', re.X),
list('-+0123456789.'))
class ParamsDict(object): class ParamsDict(object):
"""A hyperparameter container class.""" """A hyperparameter container class."""
...@@ -309,10 +329,10 @@ class ParamsDict(object): ...@@ -309,10 +329,10 @@ class ParamsDict(object):
raise ValueError('Unsupported relation in restriction.') raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path): def read_yaml_to_params_dict(file_path: str):
"""Reads a YAML file to a ParamsDict.""" """Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f: with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f, Loader=yaml.SafeLoader) params_dict = yaml.load(f, Loader=LOADER)
return ParamsDict(params_dict) return ParamsDict(params_dict)
...@@ -433,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict): ...@@ -433,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str(dict_or_string_or_yaml_file)) nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError: except ValueError:
pass pass
params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=yaml.FullLoader) params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
if isinstance(params_dict, dict): if isinstance(params_dict, dict):
params.override(params_dict, is_strict) params.override(params_dict, is_strict)
else: else:
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f: with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
params.override(yaml.load(f, Loader=yaml.FullLoader), is_strict) params.override(yaml.load(f, Loader=LOADER), is_strict)
else: else:
raise ValueError('Unknown input type to parse.') raise ValueError('Unknown input type to parse.')
return params return params
...@@ -321,6 +321,14 @@ class ParamsDictIOTest(tf.test.TestCase): ...@@ -321,6 +321,14 @@ class ParamsDictIOTest(tf.test.TestCase):
self.assertEqual([3, 4], params.b.b2) self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi, world', params.d.d1.d2) self.assertEqual('hi, world', params.d.d1.d2)
self.assertEqual('gs://test', params.e) self.assertEqual('gs://test', params.e)
# Test different float formats
override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
self.assertEqual(-1e-3, params.b.b2)
self.assertEqual(0.001, params.d.d1.d2)
self.assertEqual(1e3, params.e)
self.assertEqual(-1.5e-3, params.a)
def test_override_params_dict_using_yaml_file(self): def test_override_params_dict_using_yaml_file(self):
params = params_dict.ParamsDict({ params = params_dict.ParamsDict({
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment