Unverified Commit 8625efd8 authored by Hongkun Yu's avatar Hongkun Yu Committed by GitHub
Browse files

Merged commit includes the following changes: (#7393)

262004398  by taylorrobie<taylorrobie@google.com>:

    Internal change

PiperOrigin-RevId: 262004398
parent 4d93d894
......@@ -110,7 +110,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
**kwargs):
super(AdamWeightDecay, self).__init__(
learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self._set_hyper('weight_decay_rate', weight_decay_rate)
self.weight_decay_rate = weight_decay_rate
self._exclude_from_weight_decay = exclude_from_weight_decay
@classmethod
......@@ -120,12 +120,18 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
return super(AdamWeightDecay, cls).from_config(
config, custom_objects=custom_objects)
def _decay_weights_op(self, var, learning_rate):
def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
apply_state)
apply_state['weight_decay_rate'] = tf.constant(
self.weight_decay_rate, name='adam_weight_decay_rate')
def _decay_weights_op(self, var, learning_rate, apply_state):
do_decay = self._do_use_weight_decay(var.name)
if do_decay:
return var.assign_sub(
learning_rate * var *
self._get_hyper('weight_decay_rate'),
apply_state['weight_decay_rate'],
use_locking=self._use_locking)
return tf.no_op()
......@@ -149,26 +155,29 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
def _resource_apply_dense(self, grad, var, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_dense(
grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_sparse(
grad, var, indices, **kwargs)
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
config.update({
'weight_decay_rate':
self._serialize_hyperparameter('weight_decay_rate'),
'weight_decay_rate': self.weight_decay_rate,
})
return config
def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.weight_decay_rate == 0:
return False
if self._exclude_from_weight_decay:
for r in self._exclude_from_weight_decay:
if re.search(r, param_name) is not None:
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A parameter dictionary class which supports the nest structure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import re
import six
import tensorflow as tf
import yaml
# regex pattern that matches on key-value pairs in a comma-separated
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE = re.compile(r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
|
\"(.*?)\" # double quote
|
[^,\[]* # single value
|
\[[^\]]*\])) # list of values
($|,\s*)""", re.VERBOSE)
class ParamsDict(object):
"""A hyperparameter container class."""
RESERVED_ATTR = ['_locked', '_restrictions']
def __init__(self, default_params=None, restrictions=None):
"""Instantiate a ParamsDict.
Instantiate a ParamsDict given a set of default parameters and a list of
restrictions. Upon initialization, it validates itself by checking all the
defined restrictions, and raise error if it finds inconsistency.
Args:
default_params: a Python dict or another ParamsDict object including the
default parameters to initialize.
restrictions: a list of strings, which define a list of restrictions to
ensure the consistency of different parameters internally. Each
restriction string is defined as a binary relation with a set of
operators, including {'==', '!=', '<', '<=', '>', '>='}.
"""
self._locked = False
self._restrictions = []
if restrictions:
self._restrictions = restrictions
if default_params is None:
default_params = {}
self.override(default_params, is_strict=False)
self.validate()
def _set(self, k, v):
if isinstance(v, dict):
self.__dict__[k] = ParamsDict(v)
else:
self.__dict__[k] = copy.deepcopy(v)
def __setattr__(self, k, v):
"""Sets the value of the existing key.
Note that this does not allow directly defining a new key. Use the
`override` method with `is_strict=False` instead.
Args:
k: the key string.
v: the value to be used to set the key `k`.
Raises:
KeyError: if k is not defined in the ParamsDict.
"""
if k not in ParamsDict.RESERVED_ATTR:
if k not in self.__dict__.keys():
raise KeyError('The key `%{}` does not exist. '
'To extend the existing keys, use '
'`override` with `is_strict` = True.'.format(k))
if self._locked:
raise ValueError('The ParamsDict has been locked. '
'No change is allowed.')
self._set(k, v)
def __getattr__(self, k):
"""Gets the value of the existing key.
Args:
k: the key string.
Returns:
the value of the key.
Raises:
KeyError: if k is not defined in the ParamsDict.
"""
if k not in self.__dict__.keys():
raise KeyError('The key `{}` does not exist. '.format(k))
return self.__dict__[k]
def override(self, override_params, is_strict=True):
"""Override the ParamsDict with a set of given params.
Args:
override_params: a dict or a ParamsDict specifying the parameters to
be overridden.
is_strict: a boolean specifying whether override is strict or not. If
True, keys in `override_params` must be present in the ParamsDict.
If False, keys in `override_params` can be different from what is
currently defined in the ParamsDict. In this case, the ParamsDict will
be extended to include the new keys.
"""
if self._locked:
raise ValueError('The ParamsDict has been locked. No change is allowed.')
if isinstance(override_params, ParamsDict):
override_params = override_params.as_dict()
self._override(override_params, is_strict) # pylint: disable=protected-access
def _override(self, override_dict, is_strict=True):
"""The implementation of `override`."""
for k, v in six.iteritems(override_dict):
if k in ParamsDict.RESERVED_ATTR:
raise KeyError('The key `%{}` is internally reserved. '
'Can not be overridden.')
if k not in self.__dict__.keys():
if is_strict:
raise KeyError('The key `{}` does not exist. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'.format(k))
else:
self._set(k, v)
else:
if isinstance(v, dict):
self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
elif isinstance(v, ParamsDict):
self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
else:
self.__dict__[k] = copy.deepcopy(v)
def lock(self):
"""Makes the ParamsDict immutable."""
self._locked = True
def as_dict(self):
"""Returns a dict representation of ParamsDict.
For the nested ParamsDict, a nested dict will be returned.
"""
params_dict = {}
for k, v in six.iteritems(self.__dict__):
if k not in ParamsDict.RESERVED_ATTR:
if isinstance(v, ParamsDict):
params_dict[k] = v.as_dict()
else:
params_dict[k] = copy.deepcopy(v)
return params_dict
def validate(self):
"""Validate the parameters consistency based on the restrictions.
This method validates the internal consistency using the pre-defined list of
restrictions. A restriction is defined as a string which specfiies a binary
operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
'>='}. Note that the meaning of these operators are consistent with the
underlying Python immplementation. Users should make sure the define
restrictions on their type make sense.
For example, for a ParamsDict like the following
```
a:
a1: 1
a2: 2
b:
bb:
bb1: 10
bb2: 20
ccc:
a1: 1
a3: 3
```
one can define two restrictions like this
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
What it enforces are:
- a.a1 = 1 == b.ccc.a1 = 2
- a.a2 = 2 <= b.bb.bb2 = 20
Raises:
KeyError: if any of the following happens
(1) any of parameters in any of restrictions is not defined in
ParamsDict,
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
"""
def _get_kv(dotted_string, params_dict):
tokenized_params = dotted_string.split('.')
v = params_dict
for t in tokenized_params:
v = v[t]
return tokenized_params[-1], v
def _get_kvs(tokens, params_dict):
if len(tokens) != 2:
raise ValueError('Only support binary relation in restriction.')
stripped_tokens = [t.strip() for t in tokens]
left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
return left_k, left_v, right_k, right_v
params_dict = self.as_dict()
for restriction in self._restrictions:
if '==' in restriction:
tokens = restriction.split('==')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v != right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '!=' in restriction:
tokens = restriction.split('!=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v == right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '<' in restriction:
tokens = restriction.split('<')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v >= right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '<=' in restriction:
tokens = restriction.split('<=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v > right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '>' in restriction:
tokens = restriction.split('>')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v <= right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '>=' in restriction:
tokens = restriction.split('>=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v < right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
else:
raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path):
"""Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f)
return ParamsDict(params_dict)
def save_params_dict_to_yaml(params, file_path):
"""Saves the input ParamsDict to a YAML file."""
with tf.io.gfile.GFile(file_path, 'w') as f:
def _my_list_rep(dumper, data):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return dumper.represent_sequence(
u'tag:yaml.org,2002:seq', data, flow_style=True)
yaml.add_representer(list, _my_list_rep)
yaml.dump(params.as_dict(), f, default_flow_style=False)
def nested_csv_str_to_json_str(csv_str):
"""Converts a nested (using '.') comma-separated k=v string to a JSON string.
Converts a comma-separated string of key/value pairs that supports
nesting of keys to a JSON string. Nesting is implemented using
'.' between levels for a given key.
Spacing between commas and = is supported (e.g. there is no difference between
"a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).
Note that this will only support values supported by CSV, meaning
values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
supported. Strings are supported as well, e.g. "a='hello'".
An example conversion would be:
"a=1, b=2, c.a=2, c.b=3, d.a.a=5"
to
"{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"
Args:
csv_str: the comma separated string.
Returns:
the converted JSON string.
Raises:
ValueError: If csv_str is not in a comma separated string or
if the string is formatted incorrectly.
"""
if not csv_str:
return ''
formatted_entries = []
nested_map = collections.defaultdict(list)
pos = 0
while pos < len(csv_str):
m = _PARAM_RE.match(csv_str, pos)
if not m:
raise ValueError('Malformed hyperparameter value while parsing '
'CSV string: %s' % csv_str[pos:])
pos = m.end()
# Parse the values.
m_dict = m.groupdict()
name = m_dict['name']
v = m_dict['val']
# If a GCS path (e.g. gs://...) is provided, wrap this in quotes
# as yaml.load would otherwise throw an exception
if re.match(r'(?=[^\"\'])(?=[gs://])', v):
v = '\'{}\''.format(v)
name_nested = name.split('.')
if len(name_nested) > 1:
grouping = name_nested[0]
value = '.'.join(name_nested[1:]) + '=' + v
nested_map[grouping].append(value)
else:
formatted_entries.append('%s : %s' % (name, v))
for grouping, value in nested_map.items():
value = ','.join(value)
value = nested_csv_str_to_json_str(value)
formatted_entries.append('%s : %s' % (grouping, value))
return '{' + ', '.join(formatted_entries) + '}'
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
"""Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.
The logic of the function is outlined below:
1. Test that the input is a dict. If not, proceed to 2.
2. Tests that the input is a string. If not, raise unknown ValueError
2.1. Test if the string is in a CSV format. If so, parse.
If not, proceed to 2.2.
2.2. Try loading the string as a YAML/JSON. If successful, parse to
dict and use it to override. If not, proceed to 2.3.
2.3. Try using the string as a file path and load the YAML file.
Args:
params: a ParamsDict object to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
path to a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
Returns:
params: the overridden ParamsDict object.
Raises:
ValueError: if failed to override the parameters.
"""
if not dict_or_string_or_yaml_file:
return params
if isinstance(dict_or_string_or_yaml_file, dict):
params.override(dict_or_string_or_yaml_file, is_strict)
elif isinstance(dict_or_string_or_yaml_file, six.string_types):
try:
dict_or_string_or_yaml_file = (
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError:
pass
params_dict = yaml.load(dict_or_string_or_yaml_file)
if isinstance(params_dict, dict):
params.override(params_dict, is_strict)
else:
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
params.override(yaml.load(f), is_strict)
else:
raise ValueError('Unknown input type to parse.')
return params
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.modeling.hyperparams.params_dict.py."""
import os
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
class ParamsDictTest(tf.test.TestCase):
def test_init_from_an_empty_dict(self):
params = params_dict.ParamsDict()
with self.assertRaises(KeyError):
_ = params.a
with self.assertRaises(KeyError):
params.a = 'aa'
def test_init_from_a_dict(self):
params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
def test_init_from_a_param_dict(self):
params_init = params_dict.ParamsDict({'a': 'aa', 'b': 2})
params = params_dict.ParamsDict(params_init)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
def test_lock(self):
params = params_dict.ParamsDict({'a': 1, 'b': 2})
params.lock()
with self.assertRaises(ValueError):
params.a = 10
with self.assertRaises(ValueError):
params.override({'b': 20})
def test_setattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
params.c = 'ccc'
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c, 'ccc')
def test_getattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c, None)
def test_override_is_strict_true(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}})
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c1, 'ccc')
with self.assertRaises(KeyError):
params.override({'d': 'ddd'}, is_strict=True)
with self.assertRaises(KeyError):
params.override({'c': {'c3': 30}}, is_strict=True)
def test_override_is_strict_false(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c3, 3000)
params.override({'d': 'ddd'}, is_strict=False)
self.assertEqual(params.d, 'ddd')
params.override({'c': {'c4': 4444}}, is_strict=False)
self.assertEqual(params.c.c4, 4444)
def test_as_dict(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
params_d = params.as_dict()
self.assertEqual(params_d['a'], 'aa')
self.assertEqual(params_d['b'], 2)
self.assertEqual(params_d['c']['c1'], 10)
self.assertEqual(params_d['c']['c2'], 20)
def test_validate(self):
# Raise error due to the unknown parameter.
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 1, 'b': {'a': 11}}, ['a == c'])
# OK to check equality of two nested dicts.
params = params_dict.ParamsDict(
{'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}, ['b == c'])
# Raise error due to inconsistency
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 1, 'c': {'a': 10}}, ['a == c.a'])
# Valid rule.
params = params_dict.ParamsDict(
{'a': 1, 'c': {'a': 1}}, ['a == c.a'])
# Overridding violates the existing rule, raise error upon validate.
params.override({'a': 11})
with self.assertRaises(KeyError):
params.validate()
class ParamsDictIOTest(tf.test.TestCase):
def write_temp_file(self, filename, text):
temp_file = os.path.join(self.get_temp_dir(), filename)
with tf.io.gfile.GFile(temp_file, 'w') as writer:
writer.write(text)
return temp_file
def test_save_params_dict_to_yaml(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
params_dict.save_params_dict_to_yaml(params, output_yaml_file)
with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
params_d = yaml.load(f)
self.assertEqual(params.a, params_d['a'])
self.assertEqual(params.b, params_d['b'])
self.assertEqual(params.c.c1, params_d['c']['c1'])
self.assertEqual(params.c.c2, params_d['c']['c2'])
def test_read_yaml_to_params_dict(self):
input_yaml_file = self.write_temp_file(
'params.yaml', r"""
a: 'aa'
b: 2
c:
c1: 10
c2: 20
""")
params = params_dict.read_yaml_to_params_dict(input_yaml_file)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c.c1, 10)
self.assertEqual(params.c.c2, 20)
def test_override_params_dict_using_dict(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
override_dict = {'b': 5.2, 'c': [30, 40]}
params = params_dict.override_params_dict(
params, override_dict, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(5.2, params.b)
self.assertEqual([30, 40], params.c)
self.assertEqual('hello', params.d)
self.assertEqual(False, params.e)
def test_override_params_dict_using_yaml_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
override_yaml_string = "'b': 5.2\n'c': [30, 40]"
params = params_dict.override_params_dict(
params, override_yaml_string, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(5.2, params.b)
self.assertEqual([30, 40], params.c)
self.assertEqual('hello', params.d)
self.assertEqual(False, params.e)
def test_override_params_dict_using_json_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
'd': {'d1': {'d2': 'hello'}}, 'e': False})
override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params = params_dict.override_params_dict(
params, override_json_string, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(2, params.b.b1)
self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi', params.d.d1.d2)
self.assertEqual(False, params.e)
def test_override_params_dict_using_csv_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
'd': {'d1': {'d2': 'hello'}}, 'e': False})
override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(2, params.b.b1)
self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi, world', params.d.d1.d2)
self.assertEqual('gs://test', params.e)
def test_override_params_dict_using_yaml_file(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
override_yaml_file = self.write_temp_file(
'params.yaml', r"""
b: 5.2
c: [30, 40]
""")
params = params_dict.override_params_dict(
params, override_yaml_file, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(5.2, params.b)
self.assertEqual([30, 40], params.c)
self.assertEqual('hello', params.d)
self.assertEqual(False, params.e)
class IOTest(tf.test.TestCase):
def test_basic_csv_str_to_json_str(self):
csv_str = 'a=1,b=2,c=3'
json_str = '{a : 1, b : 2, c : 3}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_basic_csv_str_load(self):
csv_str = 'a=1,b=2,c=3'
expected_output = {'a': 1, 'b': 2, 'c': 3}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_basic_nested_csv_str_to_json_str(self):
csv_str = 'a=1,b.b1=2'
json_str = '{a : 1, b : {b1 : 2}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_basic_nested_csv_str_load(self):
csv_str = 'a=1,b.b1=2,c.c1=3'
expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_complex_nested_csv_str_to_json_str(self):
csv_str = 'a.aa.aaa.aaaaa.a=1'
json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_complex_nested_csv_str_load(self):
csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
expected_output = {'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_csv_str_load_supported_datatypes(self):
csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertEqual(converted_dict['a'], 1)
self.assertEqual(converted_dict['b'], 2.)
self.assertEqual(converted_dict['c'], [1, 2, 3])
self.assertEqual(converted_dict['d'], 'hello, there')
self.assertEqual(converted_dict['e'], 'Hi.')
def test_csv_str_load_unsupported_datatypes(self):
csv_str = 'a=[[1,2,3],[4,5,6]]'
self.assertRaises(ValueError,
params_dict.nested_csv_str_to_json_str,
csv_str)
def test_csv_str_to_json_str_spacing(self):
csv_str1 = 'a=1,b=2,c=3'
csv_str2 = 'a = 1, b = 2, c = 3'
json_str = '{a : 1, b : 2, c : 3}'
converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
self.assertEqual(converted_csv_str1, converted_csv_str2)
self.assertEqual(converted_csv_str1, json_str)
self.assertEqual(converted_csv_str2, json_str)
def test_gcs_added_quotes(self):
csv_str = 'a=gs://abc, b=gs://def'
expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, expected_output)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment