Commit 88253ce5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
......@@ -171,26 +171,30 @@ class Task(tf.Module):
return []
def process_metrics(self, metrics, labels, model_outputs):
"""Process and update metrics. Called when using custom training loop API.
"""Process and update metrics.
Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects.
The return of function self.build_metrics.
metrics: a nested structure of metrics objects. The return of function
self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
"""
for metric in metrics:
metric.update_state(labels, model_outputs)
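A hedged, self-contained sketch of the contract above (the metric, labels, and outputs are made-up stand-ins, not values from this commit):

import tensorflow as tf

# process_metrics reduces to updating each metric in the nested structure
# with (labels, model_outputs); here the structure is a flat list.
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
labels = tf.constant([1, 0])
model_outputs = tf.constant([[0.1, 0.9], [0.8, 0.2]])
for metric in metrics:
    metric.update_state(labels, model_outputs)
print(metrics[0].result().numpy())  # 1.0: both argmax predictions match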
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
"""Process and update compiled_metrics. call when using compile/fit API.
"""Process and update compiled_metrics.
call when using compile/fit API.
Args:
compiled_metrics: the compiled metrics (model.compiled_metrics).
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
"""
compiled_metrics.update_state(labels, model_outputs)
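For the compile/fit path, a hedged sketch using only standard Keras API (not code from this commit) of where compiled_metrics comes from:

import tensorflow as tf

# compiled_metrics is the container Keras builds from the `metrics` argument
# of compile(); update_state(y_true, y_pred) feeds every metric in it.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])
y_pred = model(tf.ones([4, 3]))  # forward pass builds the model
model.compiled_metrics.update_state(tf.zeros([4, 2]), y_pred)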
......@@ -297,4 +301,3 @@ class Task(tf.Module):
def reduce_aggregated_logs(self, aggregated_logs):
"""Optional reduce of aggregated logs over validation steps."""
return {}
......@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""
import gin
import orbit
import tensorflow as tf
......@@ -28,7 +29,6 @@ from official.modeling import optimization
from official.modeling import performance
from official.modeling.hyperparams import config_definitions
ExperimentConfig = config_definitions.ExperimentConfig
......@@ -52,8 +52,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
Defaults to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
Defaults to True.
model: tf.keras.Model instance. If provided, it will be used instead
of building the model using task.build_model(). Defaults to None.
model: tf.keras.Model instance. If provided, it will be used instead of
building the model using task.build_model(). Defaults to None.
optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
be used instead of the optimizer from config. Defaults to None.
"""
......@@ -90,8 +90,10 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step, model=self.model,
optimizer=self.optimizer, **checkpoint_items)
global_step=self.global_step,
model=self.model,
optimizer=self.optimizer,
**checkpoint_items)
self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean(
......
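The checkpoint construction above is the standard tf.train.Checkpoint pattern; a self-contained sketch with made-up variables and directory:

import tensorflow as tf

# Extra trackables from the task are splatted in as keyword items so they
# are saved and restored together with the model and optimizer.
global_step = tf.Variable(0, dtype=tf.int64)
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD()
checkpoint_items = {}  # e.g. additional task-owned variables

checkpoint = tf.train.Checkpoint(
    global_step=global_step,
    model=model,
    optimizer=optimizer,
    **checkpoint_items)
tf.train.CheckpointManager(checkpoint, '/tmp/ckpt', max_to_keep=3).save()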
......@@ -15,6 +15,7 @@
# ==============================================================================
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
from absl.testing import parameterized
import tensorflow as tf
......@@ -42,13 +43,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
super().setUp()
self._config = cfg.ExperimentConfig(
trainer=cfg.TrainerConfig(
optimizer_config=cfg.OptimizationConfig(
{'optimizer': {
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}})))
}
})))
def create_test_trainer(self):
task = mock_task.MockTask()
......@@ -81,13 +83,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
runtime=cfg.RuntimeConfig(
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
trainer=cfg.TrainerConfig(
optimizer_config=cfg.OptimizationConfig(
{'optimizer': {
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}})))
}
})))
task = mock_task.MockTask()
trainer = trainer_lib.Trainer(config, task)
if mixed_precision_dtype != 'float16':
......
......@@ -136,7 +136,8 @@ class Config(params_dict.ParamsDict):
return subconfig_type
def __post_init__(self, default_params, restrictions, *args, **kwargs):
super().__init__(default_params=default_params,
super().__init__(
default_params=default_params,
restrictions=restrictions,
*args,
**kwargs)
......
......@@ -55,14 +55,14 @@ class DataConfig(base_config.Config):
exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default, the
returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped for
decoding when loading dataset from TFDS. Use comma to separate multiple
features. The main use case is to skip the image/video decoding for better
performance.
"""
input_path: str = ""
tfds_name: str = ""
......@@ -177,8 +177,8 @@ class TrainerConfig(base_config.Config):
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely.
This is only used in continuous_train_and_eval and continuous_eval modes.
checkpoints, if set to None, continuous eval will wait indefinitely. This
is only used in continuous_train_and_eval and continuous_eval modes.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
......@@ -217,4 +217,3 @@ class ExperimentConfig(base_config.Config):
task: TaskConfig = TaskConfig()
trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig()
......@@ -38,15 +38,12 @@ class OneOfConfig(base_config.Config):
if self.type is None:
return {'type': None}
elif self.__dict__['type'] not in self.__dict__:
raise ValueError(
'type: {!r} is not a valid key!'.format(self.__dict__['type']))
raise ValueError('type: {!r} is not a valid key!'.format(
self.__dict__['type']))
else:
chosen_type = self.type
chosen_value = self.__dict__[chosen_type]
return {
'type': self.type,
chosen_type: self._export_config(chosen_value)
}
return {'type': self.type, chosen_type: self._export_config(chosen_value)}
def get(self):
"""Returns selected config based on the value of type.
......@@ -57,6 +54,5 @@ class OneOfConfig(base_config.Config):
if chosen_type is None:
return None
if chosen_type not in self.__dict__:
raise ValueError(
'type: {!r} is not a valid key!'.format(self.type))
raise ValueError('type: {!r} is not a valid key!'.format(self.type))
return self.__dict__[chosen_type]
......@@ -48,11 +48,17 @@ class Network(base_config.Config):
class OneOfTest(tf.test.TestCase):
def test_to_dict(self):
network_params = {'backbone': {'type': 'resnet',
'resnet': {'model_depth': 50}
network_params = {
'backbone': {
'type': 'resnet',
'resnet': {
'model_depth': 50
}
},
'output_layer': {'type': 'single',
'single': 1000}
'output_layer': {
'type': 'single',
'single': 1000
}
}
network_config = Network(network_params)
self.assertEqual(network_config.as_dict(), network_params)
......
......@@ -30,7 +30,8 @@ import yaml
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and lists within brackets.
_PARAM_RE = re.compile(r"""
_PARAM_RE = re.compile(
r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
......@@ -138,8 +139,8 @@ class ParamsDict(object):
ValueError: if the ParamsDict instance has been locked.
"""
if k in ParamsDict.RESERVED_ATTR:
raise AttributeError('The key `{}` is reserved. No change is allowed. '
.format(k))
raise AttributeError(
'The key `{}` is reserved. No change is allowed. '.format(k))
if k not in self.__dict__.keys():
raise AttributeError('The key `{}` does not exist. '.format(k))
if self._locked:
......@@ -150,13 +151,13 @@ class ParamsDict(object):
"""Override the ParamsDict with a set of given params.
Args:
override_params: a dict or a ParamsDict specifying the parameters to
be overridden.
override_params: a dict or a ParamsDict specifying the parameters to be
overridden.
is_strict: a boolean specifying whether override is strict or not. If
True, keys in `override_params` must be present in the ParamsDict.
If False, keys in `override_params` can be different from what is
currently defined in the ParamsDict. In this case, the ParamsDict will
be extended to include the new keys.
True, keys in `override_params` must be present in the ParamsDict. If
False, keys in `override_params` can be different from what is currently
defined in the ParamsDict. In this case, the ParamsDict will be extended
to include the new keys.
"""
if self._locked:
raise ValueError('The ParamsDict has been locked. No change is allowed.')
......@@ -240,6 +241,7 @@ class ParamsDict(object):
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
"""
def _get_kv(dotted_string, params_dict):
"""Get keys and values indicated by dotted_string."""
if _CONST_VALUE_RE.match(dotted_string) is not None:
......@@ -270,38 +272,44 @@ class ParamsDict(object):
tokens = restriction.split('==')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v != right_v:
raise KeyError('Found inconsistency between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '!=' in restriction:
tokens = restriction.split('!=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v == right_v:
raise KeyError('Found inconsistency between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<' in restriction:
tokens = restriction.split('<')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v >= right_v:
raise KeyError('Found inconsistency between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<=' in restriction:
tokens = restriction.split('<=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v > right_v:
raise KeyError('Found inconsistency between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>' in restriction:
tokens = restriction.split('>')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v <= right_v:
raise KeyError('Found inconsistency between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>=' in restriction:
tokens = restriction.split('>=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v < right_v:
raise KeyError('Found inconsistency between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
else:
raise ValueError('Unsupported relation in restriction.')
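A small, hedged demonstration of the restriction grammar validated above (supported relations: ==, !=, <, <=, >, >=); the import path is assumed:

from official.modeling.hyperparams import params_dict

# Each restriction string relates two dotted keys, or a key and a constant.
params = params_dict.ParamsDict(
    {'train': {'batch_size': 64}, 'eval': {'batch_size': 64}},
    ['train.batch_size == eval.batch_size'])
params.validate()  # passes; raises KeyError if the two values diverge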
......@@ -316,10 +324,12 @@ def read_yaml_to_params_dict(file_path):
def save_params_dict_to_yaml(params, file_path):
"""Saves the input ParamsDict to a YAML file."""
with tf.io.gfile.GFile(file_path, 'w') as f:
def _my_list_rep(dumper, data):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return dumper.represent_sequence(
u'tag:yaml.org,2002:seq', data, flow_style=True)
yaml.add_representer(list, _my_list_rep)
yaml.dump(params.as_dict(), f, default_flow_style=False)
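A hedged round-trip sketch with the two YAML helpers in this file (the file path is a placeholder):

from official.modeling.hyperparams import params_dict

params = params_dict.ParamsDict({'a': 1, 'c': [3, 4]})
params_dict.save_params_dict_to_yaml(params, '/tmp/params.yaml')
restored = params_dict.read_yaml_to_params_dict('/tmp/params.yaml')
assert restored.a == 1  # lists are dumped in flow style via _my_list_rep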
......@@ -408,8 +418,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
Args:
params: a ParamsDict object to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
path to a YAML file specifying the parameters to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
Returns:
......
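Mirroring the tests later in this diff, a hedged example of two of the accepted override formats (a dict and a YAML string):

from official.modeling.hyperparams import params_dict

params = params_dict.ParamsDict({'b': 2.5, 'c': [3, 4]})
params = params_dict.override_params_dict(params, {'b': 5.2}, is_strict=True)
params = params_dict.override_params_dict(
    params, "'c': [30, 40]", is_strict=True)
assert params.b == 5.2 and params.c == [30, 40]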
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for params_dict.py."""
import os
......@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
def test_setattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
params.c = 'ccc'
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
......@@ -65,16 +63,22 @@ class ParamsDictTest(tf.test.TestCase):
def test_getattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c, None)
def test_delattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None, 'd': {'d1': 1, 'd2': 10}},
params.override({
'a': 'aa',
'b': 2,
'c': None,
'd': {
'd1': 1,
'd2': 10
}
},
is_strict=False)
del params.c
self.assertEqual(params.a, 'aa')
......@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):
def test_contains(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa'}, is_strict=False)
params.override({'a': 'aa'}, is_strict=False)
self.assertIn('a', params)
self.assertNotIn('b', params)
def test_get(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa'}, is_strict=False)
params.override({'a': 'aa'}, is_strict=False)
self.assertEqual(params.get('a'), 'aa')
self.assertEqual(params.get('b', 2), 2)
self.assertEqual(params.get('b'), None)
def test_override_is_strict_true(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}})
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 'cc',
'c2': 20
}
})
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c1, 'ccc')
......@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
params.override({'c': {'c3': 30}}, is_strict=True)
def test_override_is_strict_false(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 10,
'c2': 20
}
})
params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c3, 3000)
......@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
self.assertEqual(params.c.c4, 4444)
def test_as_dict(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 10,
'c2': 20
}
})
params_d = params.as_dict()
self.assertEqual(params_d['a'], 'aa')
self.assertEqual(params_d['b'], 2)
......@@ -134,21 +154,25 @@ class ParamsDictTest(tf.test.TestCase):
def test_validate(self):
# Raise error due to the unknown parameter.
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 1, 'b': {'a': 11}}, ['a == c'])
params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])
# OK to check equality of two nested dicts.
params = params_dict.ParamsDict(
{'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}, ['b == c'])
params = params_dict.ParamsDict({
'a': 1,
'b': {
'a': 10
},
'c': {
'a': 10
}
}, ['b == c'])
# Raise error due to inconsistency
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 1, 'c': {'a': 10}}, ['a == c.a'])
params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])
# Valid rule.
params = params_dict.ParamsDict(
{'a': 1, 'c': {'a': 1}}, ['a == c.a'])
params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
# Overriding violates the existing rule; raise an error upon validate.
params.override({'a': 11})
......@@ -156,12 +180,20 @@ class ParamsDictTest(tf.test.TestCase):
params.validate()
# Valid restrictions with constant.
params = params_dict.ParamsDict(
{'a': None, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
params = params_dict.ParamsDict({
'a': None,
'c': {
'a': 1
}
}, ['a == None', 'c.a == 1'])
params.validate()
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 4, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
params = params_dict.ParamsDict({
'a': 4,
'c': {
'a': 1
}
}, ['a == None', 'c.a == 1'])
class ParamsDictIOTest(tf.test.TestCase):
......@@ -173,8 +205,14 @@ class ParamsDictIOTest(tf.test.TestCase):
return temp_file
def test_save_params_dict_to_yaml(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 10,
'c2': 20
}
})
output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
params_dict.save_params_dict_to_yaml(params, output_yaml_file)
......@@ -203,7 +241,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_dict(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
'a': 1,
'b': 2.5,
'c': [3, 4],
'd': 'hello',
'e': False
})
override_dict = {'b': 5.2, 'c': [30, 40]}
params = params_dict.override_params_dict(
params, override_dict, is_strict=True)
......@@ -215,7 +258,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_yaml_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
'a': 1,
'b': 2.5,
'c': [3, 4],
'd': 'hello',
'e': False
})
override_yaml_string = "'b': 5.2\n'c': [30, 40]"
params = params_dict.override_params_dict(
params, override_yaml_string, is_strict=True)
......@@ -227,8 +275,18 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_json_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
'd': {'d1': {'d2': 'hello'}}, 'e': False})
'a': 1,
'b': {
'b1': 2,
'b2': [2, 3],
},
'd': {
'd1': {
'd2': 'hello'
}
},
'e': False
})
override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params = params_dict.override_params_dict(
params, override_json_string, is_strict=True)
......@@ -240,8 +298,18 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_csv_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
'd': {'d1': {'d2': 'hello'}}, 'e': False})
'a': 1,
'b': {
'b1': 2,
'b2': [2, 3],
},
'd': {
'd1': {
'd2': 'hello'
}
},
'e': False
})
override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
......@@ -253,7 +321,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_yaml_file(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
'a': 1,
'b': 2.5,
'c': [3, 4],
'd': 'hello',
'e': False
})
override_yaml_file = self.write_temp_file(
'params.yaml', r"""
b: 5.2
......@@ -321,8 +394,7 @@ class IOTest(tf.test.TestCase):
def test_csv_str_load_unsupported_datatypes(self):
csv_str = 'a=[[1,2,3],[4,5,6]]'
self.assertRaises(ValueError,
params_dict.nested_csv_str_to_json_str,
self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
csv_str)
def test_csv_str_to_json_str_spacing(self):
......
......@@ -50,16 +50,13 @@ class StepwiseLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to PiecewiseConstant.
boundaries: A list of ints of strictly increasing entries.
Defaults to None.
boundaries: A list of ints of strictly increasing entries. Defaults to None.
values: A list of floats that specifies the values for the intervals defined
by `boundaries`. It should have one more element than `boundaries`.
The learning rate is computed as follows:
[0, boundaries[0]] -> values[0]
[boundaries[0], boundaries[1]] -> values[1]
[boundaries[n-1], boundaries[n]] -> values[n]
[boundaries[n], end] -> values[n+1]
Defaults to None.
The learning rate is computed as follows:
[0, boundaries[0]] -> values[0]
[boundaries[0], boundaries[1]] -> values[1]
[boundaries[n-1], boundaries[n]] -> values[n]
[boundaries[n], end] -> values[n+1]
Defaults to None.
"""
name: str = 'PiecewiseConstantDecay'
boundaries: Optional[List[int]] = None
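Concretely, with the numbers used by the factory tests later in this diff, the mapping above yields the following (a hedged check against the underlying Keras schedule):

import tensorflow as tf

# boundaries=[10000, 20000], values=[0.1, 0.01, 0.001]:
#   step <= 10000          -> 0.1
#   10000 < step <= 20000  -> 0.01
#   step > 20000           -> 0.001
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[10000, 20000], values=[0.1, 0.01, 0.001])
assert abs(float(lr(10001)) - 0.01) < 1e-9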
......@@ -74,10 +71,9 @@ class ExponentialLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to ExponentialDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to
None.
decay_steps: A positive integer that is used for decay computation.
Defaults to None.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
decay_steps: A positive integer that is used for decay computation. Defaults
to None.
decay_rate: A float. Defaults to None.
staircase: A boolean, if true, learning rate is decreased at discrete
intervals. Defaults to False.
......@@ -97,10 +93,9 @@ class PolynomialLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to PolynomialDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to
None.
decay_steps: A positive integer that is used for decay computation.
Defaults to None.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
decay_steps: A positive integer that is used for decay computation. Defaults
to None.
end_learning_rate: A float. The minimal end learning rate.
power: A float. The power of the polynomial. Defaults to linear, 1.0.
cycle: A boolean, whether or not it should cycle beyond decay_steps.
......@@ -123,10 +118,9 @@ class CosineLrConfig(base_config.Config):
Attributes:
name: The name of the learning rate schedule. Defaults to CosineDecay.
initial_learning_rate: A float. The initial learning rate. Defaults to
None.
decay_steps: A positive integer that is used for decay computation.
Defaults to None.
initial_learning_rate: A float. The initial learning rate. Defaults to None.
decay_steps: A positive integer that is used for decay computation. Defaults
to None.
alpha: A float. Minimum learning rate value as a fraction of
initial_learning_rate.
"""
......@@ -173,4 +167,3 @@ class PolynomialWarmupConfig(base_config.Config):
name: str = 'polynomial'
power: float = 1
warmup_steps: Optional[int] = None
......@@ -50,12 +50,11 @@ class OptimizerConfigTest(tf.test.TestCase):
'type': 'linear'
}
})
self.assertEqual(opt_config.optimizer.get(),
opt_cfg.SGDConfig())
self.assertEqual(opt_config.optimizer.get(), opt_cfg.SGDConfig())
self.assertEqual(opt_config.learning_rate.get(),
lr_cfg.PolynomialLrConfig())
self.assertEqual(opt_config.warmup.get(),
lr_cfg.LinearWarmupConfig())
self.assertEqual(opt_config.warmup.get(), lr_cfg.LinearWarmupConfig())
if __name__ == '__main__':
tf.test.main()
......@@ -123,12 +123,11 @@ class LAMBConfig(base_config.Config):
epsilon: epsilon value used for numerical stability in LAMB optimizer.
weight_decay_rate: float. Weight decay rate. Default to 0.
exclude_from_weight_decay: List of regex patterns of variables excluded from
weight decay. Variables whose names contain a
substring matching the pattern will be excluded.
weight decay. Variables whose names contain a substring matching the
pattern will be excluded.
exclude_from_layer_adaptation: List of regex patterns of variables excluded
from layer adaptation. Variables whose names
contain a substring matching the pattern will
be excluded.
from layer adaptation. Variables whose names contain a substring matching
the pattern will be excluded.
"""
name: str = "LAMB"
beta_1: float = 0.9
......
......@@ -131,8 +131,9 @@ class OptimizerFactory(object):
rate built using self.build_lr() is passed as an argument to this method.
Args:
lr: A floating point value, or
a tf.keras.optimizers.schedules.LearningRateSchedule instance.
lr: A floating point value, or a
tf.keras.optimizers.schedules.LearningRateSchedule instance.
Returns:
tf.keras.optimizers.Optimizer instance.
"""
......@@ -142,4 +143,3 @@ class OptimizerFactory(object):
optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
return optimizer
......@@ -25,12 +25,7 @@ from official.modeling.optimization.configs import optimization_config
class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('sgd'),
('rmsprop'),
('adam'),
('adamw'),
('lamb'))
@parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'))
def test_optimizers(self, optimizer_type):
params = {
'optimizer': {
......@@ -56,20 +51,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_missing_types(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
}
}
params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
params = {
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
}
}
with self.assertRaises(ValueError):
......@@ -80,22 +72,20 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
}
expected_lr_step_values = [
[0, 0.1],
[5000, 0.1],
[10000, 0.1],
[10001, 0.01],
[20000, 0.01],
[20001, 0.001]
]
}
expected_lr_step_values = [[0, 0.1], [5000, 0.1], [10000, 0.1],
[10001, 0.01], [20000, 0.01], [20001, 0.001]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
......@@ -107,28 +97,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
},
'warmup': {
'type': 'linear',
'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
'linear': {
'warmup_steps': 500,
'warmup_learning_rate': 0.01
}
}
expected_lr_step_values = [
[0, 0.01],
[250, 0.055],
[500, 0.1],
[5500, 0.1],
[10000, 0.1],
[10001, 0.01],
[20000, 0.01],
[20001, 0.001]
]
}
expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5500, 0.1],
[10000, 0.1], [10001, 0.01], [20000, 0.01],
[20001, 0.001]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
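The 0.055 entry above follows from linear interpolation between warmup_learning_rate and the post-warmup rate; a hedged arithmetic check:

# lr(step) = warmup_lr + (target_lr - warmup_lr) * step / warmup_steps
warmup_lr, target_lr, warmup_steps = 0.01, 0.1, 500
assert abs(warmup_lr + (target_lr - warmup_lr) * 250 / warmup_steps
           - 0.055) < 1e-12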
......@@ -140,7 +130,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'exponential',
......@@ -170,7 +162,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'polynomial',
......@@ -194,7 +188,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'cosine',
......@@ -204,11 +200,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
}
}
}
expected_lr_step_values = [[0, 0.1],
[250, 0.08535534],
[500, 0.04999999],
[750, 0.01464466],
[1000, 0]]
expected_lr_step_values = [[0, 0.1], [250, 0.08535534], [500, 0.04999999],
[750, 0.01464466], [1000, 0]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
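The cosine entries above are consistent with the standard CosineDecay formula, assuming alpha=0 and decay_steps=1000:

import math

# lr(step) = initial_lr * 0.5 * (1 + cos(pi * step / decay_steps))
assert abs(0.1 * 0.5 * (1 + math.cos(math.pi * 250 / 1000))
           - 0.08535534) < 1e-7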
......@@ -220,7 +213,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'constant',
......@@ -250,28 +245,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
'stepwise': {
'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]
}
},
'warmup': {
'type': 'polynomial',
'polynomial': {'warmup_steps': 500, 'power': 2.}
'polynomial': {
'warmup_steps': 500,
'power': 2.
}
}
expected_lr_step_values = [
[0, 0.0],
[250, 0.025],
[500, 0.1],
[5500, 0.1],
[10000, 0.1],
[10001, 0.01],
[20000, 0.01],
[20001, 0.001]
]
}
expected_lr_step_values = [[0, 0.0], [250, 0.025], [500, 0.1], [5500, 0.1],
[10000, 0.1], [10001, 0.01], [20000, 0.01],
[20001, 0.001]]
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
......@@ -279,5 +274,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
if __name__ == '__main__':
tf.test.main()
......@@ -21,7 +21,7 @@ import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale="dynamic"):
loss_scale='dynamic'):
"""Configures optimizer object with performance options."""
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
......@@ -47,10 +47,9 @@ def set_mixed_precision_policy(dtype, loss_scale=None):
'mixed_float16', loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.float32:
tf.keras.mixed_precision.experimental.set_policy('float32')
else:
raise ValueError("Unexpected dtype: %s" % dtype)
raise ValueError('Unexpected dtype: %s' % dtype)
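A hedged sketch wiring the two helpers in this file together, assuming configure_optimizer returns the (possibly wrapped) optimizer:

import tensorflow as tf

# Set the global mixed-precision policy first, then wrap the optimizer
# with a loss-scale wrapper for float16 training.
set_mixed_precision_policy(tf.float16, loss_scale='dynamic')
optimizer = tf.keras.optimizers.SGD(0.01)
optimizer = configure_optimizer(
    optimizer, use_float16=True, loss_scale='dynamic')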
......@@ -29,8 +29,7 @@ from official.modeling import activations
None,
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
"input tensors. pack/unpack inputs to override __call__ is no longer "
"needed."
)
"needed.")
def pack_inputs(inputs):
"""Pack a list of `inputs` tensors to a tuple.
......@@ -55,8 +54,7 @@ def pack_inputs(inputs):
None,
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
"input tensors. pack/unpack inputs to override __call__ is no longer "
"needed."
)
"needed.")
def unpack_inputs(inputs):
"""unpack a tuple of `inputs` tensors to a tuple.
......
......@@ -133,15 +133,9 @@ class SummaryWriter(object):
class DistributedExecutor(object):
"""Interface to train and eval models with tf.distribute.Strategy.
"""
"""Interface to train and eval models with tf.distribute.Strategy."""
def __init__(self,
strategy,
params,
model_fn,
loss_fn,
is_multi_host=False):
def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
"""Constructor.
Args:
......@@ -293,8 +287,7 @@ class DistributedExecutor(object):
raise ValueError('steps should be a Tensor. A Python object may cause '
'retracing.')
per_replica_losses = strategy.run(
replicated_step, args=(next(iterator),))
per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.run(
replicated_step, args=(next(iterator),))
......@@ -368,6 +361,7 @@ class DistributedExecutor(object):
available checkpoints. If `False`, will do the evaluation once after the
final step.
save_config: bool. Whether to save params to model_dir.
Returns:
The training loss and eval metrics.
"""
......@@ -477,16 +471,15 @@ class DistributedExecutor(object):
# Step-0 operations
if current_step == 0 and not latest_checkpoint_file:
_save_checkpoint(
checkpoint, model_dir, checkpoint_name.format(step=current_step))
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(
test_step, current_step, eval_metric, eval_iterator)
logging.info(
'Step: %s evaluation metric = %s.', current_step, eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evaluation metric = %s.', current_step,
eval_metric_result)
test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
reset_states(eval_metric)
logging.info('Training started')
......@@ -519,8 +512,7 @@ class DistributedExecutor(object):
else:
train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
current_step, total_steps, train_loss,
train_metric_result)
current_step, total_steps, train_loss, train_metric_result)
train_summary_writer(
metrics=train_metric_result, step=optimizer.iterations)
......@@ -561,8 +553,7 @@ class DistributedExecutor(object):
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Final evaluation metric = %s.', eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
self.train_summary_writer.close()
self.eval_summary_writer.close()
......@@ -696,8 +687,7 @@ class DistributedExecutor(object):
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
current_step = reader.get_tensor(
'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
logging.info(
'Checkpoint file %s found and restoring from '
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', checkpoint_path)
status = checkpoint.restore(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
......@@ -755,8 +745,8 @@ class ExecutorBuilder(object):
"""
def __init__(self, strategy_type=None, strategy_config=None):
_ = distribution_utils.configure_cluster(
strategy_config.worker_hosts, strategy_config.task_index)
_ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
"""Constructor.
Args:
......
......@@ -26,10 +26,7 @@ from official.nlp.bert import configs
class AlbertConfig(configs.BertConfig):
"""Configuration for `ALBERT`."""
def __init__(self,
num_hidden_groups=1,
inner_group_num=1,
**kwargs):
def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
"""Constructs AlbertConfig.
Args:
......
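A hedged instantiation: since AlbertConfig extends configs.BertConfig, BertConfig keyword arguments such as vocab_size pass through **kwargs (the value here is illustrative):

albert_config = AlbertConfig(
    num_hidden_groups=1,  # parameter-sharing groups
    inner_group_num=1,    # inner repetitions per group
    vocab_size=30000)     # a BertConfig argument forwarded via **kwargs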
......@@ -18,6 +18,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import json
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
......