Commit 999fae62 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 94561082
@@ -171,26 +171,30 @@ class Task(tf.Module):
     return []

   def process_metrics(self, metrics, labels, model_outputs):
-    """Process and update metrics. Called when using custom training loop API.
+    """Process and update metrics.
+
+    Called when using custom training loop API.

     Args:
-      metrics: a nested structure of metrics objects.
-        The return of function self.build_metrics.
+      metrics: a nested structure of metrics objects. The return of function
+        self.build_metrics.
       labels: a tensor or a nested structure of tensors.
-      model_outputs: a tensor or a nested structure of tensors.
-        For example, output of the keras model built by self.build_model.
+      model_outputs: a tensor or a nested structure of tensors. For example,
+        output of the keras model built by self.build_model.
     """
     for metric in metrics:
       metric.update_state(labels, model_outputs)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
-    """Process and update compiled_metrics. call when using compile/fit API.
+    """Process and update compiled_metrics.
+
+    call when using compile/fit API.

     Args:
       compiled_metrics: the compiled metrics (model.compiled_metrics).
       labels: a tensor or a nested structure of tensors.
-      model_outputs: a tensor or a nested structure of tensors.
-        For example, output of the keras model built by self.build_model.
+      model_outputs: a tensor or a nested structure of tensors. For example,
+        output of the keras model built by self.build_model.
     """
     compiled_metrics.update_state(labels, model_outputs)

@@ -297,4 +301,3 @@ class Task(tf.Module):
   def reduce_aggregated_logs(self, aggregated_logs):
     """Optional reduce of aggregated logs over validation steps."""
     return {}
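For readers skimming the hunk above, a minimal, self-contained sketch of what `process_metrics` does inside a custom training loop; the metric, labels, and outputs below are illustrative and not taken from the commit:

import tensorflow as tf

# Each metric in the (possibly nested) structure is updated with the same
# (labels, model_outputs) pair, mirroring the loop in process_metrics.
metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
labels = tf.constant([0, 1])
model_outputs = tf.constant([[0.9, 0.1], [0.2, 0.8]])

for metric in metrics:
  metric.update_state(labels, model_outputs)

print({m.name: m.result().numpy() for m in metrics})  # {'accuracy': 1.0}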
@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 `StandardEvaluable` interfaces. Trainers inside this project should be
 interchangable and independent on model architectures and tasks.
 """
 import gin
 import orbit
 import tensorflow as tf

@@ -28,7 +29,6 @@ from official.modeling import optimization
 from official.modeling import performance
 from official.modeling.hyperparams import config_definitions
 ExperimentConfig = config_definitions.ExperimentConfig

@@ -52,8 +52,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
       default to True.
     evaluate: bool, whether or not this trainer will be used for evaluation.
       default to True.
-    model: tf.keras.Model instance. If provided, it will be used instead
-      of building model using task.build_model(). Default to None.
+    model: tf.keras.Model instance. If provided, it will be used instead of
+      building model using task.build_model(). Default to None.
     optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
       used instead of the optimizer from config. Default to None.
   """

@@ -90,8 +90,10 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     else:
       checkpoint_items = {}
     self._checkpoint = tf.train.Checkpoint(
-        global_step=self.global_step, model=self.model,
-        optimizer=self.optimizer, **checkpoint_items)
+        global_step=self.global_step,
+        model=self.model,
+        optimizer=self.optimizer,
+        **checkpoint_items)
     self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
     self._validation_loss = tf.keras.metrics.Mean(
......
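For context, a standalone sketch of the checkpoint bundle the Trainer assembles in the last hunk above; the model and optimizer here are placeholders, and any task-provided extras would arrive via `**checkpoint_items`:

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD()
checkpoint_items = {}  # extra objects a task may want saved alongside the model

checkpoint = tf.train.Checkpoint(
    global_step=global_step,
    model=model,
    optimizer=optimizer,
    **checkpoint_items)
save_path = checkpoint.save('/tmp/demo_ckpt/ckpt')  # illustrative path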
@@ -15,6 +15,7 @@
 # ==============================================================================
 """Tests for tensorflow_models.core.trainers.trainer."""
 # pylint: disable=g-direct-tensorflow-import
 from absl.testing import parameterized
 import tensorflow as tf

@@ -42,13 +43,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     super().setUp()
     self._config = cfg.ExperimentConfig(
         trainer=cfg.TrainerConfig(
-            optimizer_config=cfg.OptimizationConfig(
-                {'optimizer': {
-                    'type': 'sgd'
-                },
-                 'learning_rate': {
-                     'type': 'constant'
-                 }})))
+            optimizer_config=cfg.OptimizationConfig({
+                'optimizer': {
+                    'type': 'sgd'
+                },
+                'learning_rate': {
+                    'type': 'constant'
+                }
+            })))

   def create_test_trainer(self):
     task = mock_task.MockTask()

@@ -81,13 +83,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         runtime=cfg.RuntimeConfig(
             mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
         trainer=cfg.TrainerConfig(
-            optimizer_config=cfg.OptimizationConfig(
-                {'optimizer': {
-                    'type': 'sgd'
-                },
-                 'learning_rate': {
-                     'type': 'constant'
-                 }})))
+            optimizer_config=cfg.OptimizationConfig({
+                'optimizer': {
+                    'type': 'sgd'
+                },
+                'learning_rate': {
+                    'type': 'constant'
+                }
+            })))
     task = mock_task.MockTask()
     trainer = trainer_lib.Trainer(config, task)
     if mixed_precision_dtype != 'float16':
......
@@ -136,10 +136,11 @@ class Config(params_dict.ParamsDict):
     return subconfig_type

   def __post_init__(self, default_params, restrictions, *args, **kwargs):
-    super().__init__(default_params=default_params,
-                     restrictions=restrictions,
-                     *args,
-                     **kwargs)
+    super().__init__(
+        default_params=default_params,
+        restrictions=restrictions,
+        *args,
+        **kwargs)

   def _set(self, k, v):
     """Overrides same method in ParamsDict.
......
@@ -55,14 +55,14 @@ class DataConfig(base_config.Config):
       exhaust all the examples in the dataset.
     tfds_data_dir: A str specifying the directory to read/write TFDS data.
     tfds_download: A bool to indicate whether to download data using TFDS.
-    tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
-      the returned tf.data.Dataset will have a 2-tuple structure (input, label)
-      according to builder.info.supervised_keys; if False, the default,
-      the returned tf.data.Dataset will have a dictionary with all the features.
-    tfds_skip_decoding_feature: A str to indicate which features are skipped
-      for decoding when loading dataset from TFDS. Use comma to separate
-      multiple features. The main use case is to skip the image/video decoding
-      for better performance.
+    tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
+      returned tf.data.Dataset will have a 2-tuple structure (input, label)
+      according to builder.info.supervised_keys; if False, the default, the
+      returned tf.data.Dataset will have a dictionary with all the features.
+    tfds_skip_decoding_feature: A str to indicate which features are skipped for
+      decoding when loading dataset from TFDS. Use comma to separate multiple
+      features. The main use case is to skip the image/video decoding for better
+      performance.
   """
   input_path: str = ""
   tfds_name: str = ""

@@ -177,8 +177,8 @@ class TrainerConfig(base_config.Config):
     checkpoint_interval: number of steps between checkpoints.
     max_to_keep: max checkpoints to keep.
     continuous_eval_timeout: maximum number of seconds to wait between
-      checkpoints, if set to None, continuous eval will wait indefinitely.
-      This is only used continuous_train_and_eval and continuous_eval modes.
+      checkpoints, if set to None, continuous eval will wait indefinitely. This
+      is only used continuous_train_and_eval and continuous_eval modes.
     train_steps: number of train steps.
     validation_steps: number of eval steps. If `None`, the entire eval dataset
       is used.

@@ -217,4 +217,3 @@ class ExperimentConfig(base_config.Config):
   task: TaskConfig = TaskConfig()
   trainer: TrainerConfig = TrainerConfig()
   runtime: RuntimeConfig = RuntimeConfig()
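A hedged illustration of how these dataclass-style configs compose; the field names come from the docstrings above, while the import alias, the values, and the override call are illustrative assumptions of this sketch:

from official.modeling.hyperparams import config_definitions as cfg

config = cfg.ExperimentConfig(
    trainer=cfg.TrainerConfig(
        train_steps=20000,
        checkpoint_interval=1000,
        max_to_keep=5),
    runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'))

# Config subclasses ParamsDict, so overrides follow the same strict/non-strict
# rules documented further down in this change.
config.override({'trainer': {'validation_steps': 100}}, is_strict=True)
print(config.trainer.train_steps)  # 20000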
@@ -38,15 +38,12 @@ class OneOfConfig(base_config.Config):
     if self.type is None:
       return {'type': None}
     elif self.__dict__['type'] not in self.__dict__:
-      raise ValueError(
-          'type: {!r} is not a valid key!'.format(self.__dict__['type']))
+      raise ValueError('type: {!r} is not a valid key!'.format(
+          self.__dict__['type']))
     else:
       chosen_type = self.type
       chosen_value = self.__dict__[chosen_type]
-      return {
-          'type': self.type,
-          chosen_type: self._export_config(chosen_value)
-      }
+      return {'type': self.type, chosen_type: self._export_config(chosen_value)}

   def get(self):
     """Returns selected config based on the value of type.

@@ -57,6 +54,5 @@ class OneOfConfig(base_config.Config):
     if chosen_type is None:
       return None
     if chosen_type not in self.__dict__:
-      raise ValueError(
-          'type: {!r} is not a valid key!'.format(self.type))
+      raise ValueError('type: {!r} is not a valid key!'.format(self.type))
     return self.__dict__[chosen_type]
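The one-of pattern above is easiest to see with the optimization config used elsewhere in this change; the snippet below is a sketch assembled from the test fixtures in this commit, not an excerpt of them:

from official.modeling.optimization.configs import optimization_config

opt_config = optimization_config.OptimizationConfig({
    'optimizer': {
        'type': 'sgd',
        'sgd': {
            'momentum': 0.9
        }
    },
    'learning_rate': {
        'type': 'constant'
    }
})
# get() returns the sub-config selected by 'type'; as_dict() keeps only that
# branch plus the 'type' key, as implemented in OneOfConfig above.
print(opt_config.optimizer.get())
print(opt_config.as_dict()['optimizer'])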
@@ -48,12 +48,18 @@ class Network(base_config.Config):
 class OneOfTest(tf.test.TestCase):

   def test_to_dict(self):
-    network_params = {'backbone': {'type': 'resnet',
-                                   'resnet': {'model_depth': 50}
-                                  },
-                      'output_layer': {'type': 'single',
-                                       'single': 1000}
-                     }
+    network_params = {
+        'backbone': {
+            'type': 'resnet',
+            'resnet': {
+                'model_depth': 50
+            }
+        },
+        'output_layer': {
+            'type': 'single',
+            'single': 1000
+        }
+    }
     network_config = Network(network_params)
     self.assertEqual(network_config.as_dict(), network_params)
......
@@ -30,7 +30,8 @@ import yaml
 # key-value pair string. It splits each k-v pair on the = sign, and
 # matches on values that are within single quotes, double quotes, single
 # values (e.g. floats, ints, etc.), and a lists within brackets.
-_PARAM_RE = re.compile(r"""
+_PARAM_RE = re.compile(
+    r"""
   (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
   \s*=\s*
   ((?P<val>\'(.*?)\'           # single quote

@@ -138,8 +139,8 @@ class ParamsDict(object):
       ValueError: if the ParamsDict instance has been locked.
     """
     if k in ParamsDict.RESERVED_ATTR:
-      raise AttributeError('The key `{}` is reserved. No change is allowes. '
-                           .format(k))
+      raise AttributeError(
+          'The key `{}` is reserved. No change is allowes. '.format(k))
     if k not in self.__dict__.keys():
       raise AttributeError('The key `{}` does not exist. '.format(k))
     if self._locked:

@@ -150,13 +151,13 @@ class ParamsDict(object):
     """Override the ParamsDict with a set of given params.

     Args:
-      override_params: a dict or a ParamsDict specifying the parameters to
-        be overridden.
+      override_params: a dict or a ParamsDict specifying the parameters to be
+        overridden.
       is_strict: a boolean specifying whether override is strict or not. If
-        True, keys in `override_params` must be present in the ParamsDict.
-        If False, keys in `override_params` can be different from what is
-        currently defined in the ParamsDict. In this case, the ParamsDict will
-        be extended to include the new keys.
+        True, keys in `override_params` must be present in the ParamsDict. If
+        False, keys in `override_params` can be different from what is currently
+        defined in the ParamsDict. In this case, the ParamsDict will be extended
+        to include the new keys.
     """
     if self._locked:
       raise ValueError('The ParamsDict has been locked. No change is allowed.')

@@ -240,6 +241,7 @@ class ParamsDict(object):
       (2) any inconsistency violating the restriction is found.
       ValueError: if the restriction defined in the string is not supported.
     """
     def _get_kv(dotted_string, params_dict):
       """Get keys and values indicated by dotted_string."""
       if _CONST_VALUE_RE.match(dotted_string) is not None:
@@ -270,38 +272,44 @@ class ParamsDict(object):
       tokens = restriction.split('==')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v != right_v:
-        raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
-                       .format(tokens[0], tokens[1]))
+        raise KeyError(
+            'Found inconsistncy between key `{}` and key `{}`.'.format(
+                tokens[0], tokens[1]))
     elif '!=' in restriction:
       tokens = restriction.split('!=')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v == right_v:
-        raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
-                       .format(tokens[0], tokens[1]))
+        raise KeyError(
+            'Found inconsistncy between key `{}` and key `{}`.'.format(
+                tokens[0], tokens[1]))
     elif '<' in restriction:
       tokens = restriction.split('<')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v >= right_v:
-        raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
-                       .format(tokens[0], tokens[1]))
+        raise KeyError(
+            'Found inconsistncy between key `{}` and key `{}`.'.format(
+                tokens[0], tokens[1]))
     elif '<=' in restriction:
       tokens = restriction.split('<=')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v > right_v:
-        raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
-                       .format(tokens[0], tokens[1]))
+        raise KeyError(
+            'Found inconsistncy between key `{}` and key `{}`.'.format(
+                tokens[0], tokens[1]))
     elif '>' in restriction:
       tokens = restriction.split('>')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v <= right_v:
-        raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
-                       .format(tokens[0], tokens[1]))
+        raise KeyError(
+            'Found inconsistncy between key `{}` and key `{}`.'.format(
+                tokens[0], tokens[1]))
     elif '>=' in restriction:
       tokens = restriction.split('>=')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v < right_v:
-        raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
-                       .format(tokens[0], tokens[1]))
+        raise KeyError(
+            'Found inconsistncy between key `{}` and key `{}`.'.format(
+                tokens[0], tokens[1]))
     else:
       raise ValueError('Unsupported relation in restriction.')
@@ -316,10 +324,12 @@ def read_yaml_to_params_dict(file_path):

 def save_params_dict_to_yaml(params, file_path):
   """Saves the input ParamsDict to a YAML file."""
   with tf.io.gfile.GFile(file_path, 'w') as f:
     def _my_list_rep(dumper, data):
       # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
       return dumper.represent_sequence(
           u'tag:yaml.org,2002:seq', data, flow_style=True)
     yaml.add_representer(list, _my_list_rep)
     yaml.dump(params.as_dict(), f, default_flow_style=False)

@@ -408,8 +418,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):

   Args:
     params: a ParamsDict object to be overridden.
-    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
-      path to a YAML file specifying the parameters to be overridden.
+    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
+      a YAML file specifying the parameters to be overridden.
     is_strict: a boolean specifying whether override is strict or not.

   Returns:
......
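Pulling the behaviours touched above into one place, a small sketch of ParamsDict overrides, restrictions, and the dict/YAML/JSON/CSV entry point `override_params_dict`; the import path is assumed from the tests that follow:

from official.modeling.hyperparams import params_dict

# A restriction ties two dotted keys together and is re-checked by validate().
params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
params.override({'b': 2}, is_strict=False)  # non-strict override may add keys
params.validate()                           # passes while a == c.a holds

# override_params_dict also accepts CSV-style strings with dotted keys.
params = params_dict.override_params_dict(params, 'a=5, c.a=5', is_strict=True)
print(params.as_dict())  # {'a': 5, 'b': 2, 'c': {'a': 5}}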
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Tests for params_dict.py."""
 import os

@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):

   def test_setattr(self):
     params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
     params.c = 'ccc'
     self.assertEqual(params.a, 'aa')
     self.assertEqual(params.b, 2)

@@ -65,17 +63,23 @@ class ParamsDictTest(tf.test.TestCase):

   def test_getattr(self):
     params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
     self.assertEqual(params.a, 'aa')
     self.assertEqual(params.b, 2)
     self.assertEqual(params.c, None)

   def test_delattr(self):
     params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa', 'b': 2, 'c': None, 'd': {'d1': 1, 'd2': 10}},
-        is_strict=False)
+    params.override({
+        'a': 'aa',
+        'b': 2,
+        'c': None,
+        'd': {
+            'd1': 1,
+            'd2': 10
+        }
+    },
+                    is_strict=False)
     del params.c
     self.assertEqual(params.a, 'aa')
     self.assertEqual(params.b, 2)

@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):

   def test_contains(self):
     params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa'}, is_strict=False)
+    params.override({'a': 'aa'}, is_strict=False)
     self.assertIn('a', params)
     self.assertNotIn('b', params)

   def test_get(self):
     params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa'}, is_strict=False)
+    params.override({'a': 'aa'}, is_strict=False)
     self.assertEqual(params.get('a'), 'aa')
     self.assertEqual(params.get('b', 2), 2)
     self.assertEqual(params.get('b'), None)

   def test_override_is_strict_true(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 'cc',
+            'c2': 20
+        }
+    })
     params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
     self.assertEqual(params.a, 2)
     self.assertEqual(params.c.c1, 'ccc')
@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
       params.override({'c': {'c3': 30}}, is_strict=True)

   def test_override_is_strict_false(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 10,
+            'c2': 20
+        }
+    })
     params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
     self.assertEqual(params.a, 2)
     self.assertEqual(params.c.c3, 3000)

@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
     self.assertEqual(params.c.c4, 4444)

   def test_as_dict(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 10,
+            'c2': 20
+        }
+    })
     params_d = params.as_dict()
     self.assertEqual(params_d['a'], 'aa')
     self.assertEqual(params_d['b'], 2)

@@ -134,21 +154,25 @@ class ParamsDictTest(tf.test.TestCase):

   def test_validate(self):
     # Raise error due to the unknown parameter.
     with self.assertRaises(KeyError):
-      params = params_dict.ParamsDict(
-          {'a': 1, 'b': {'a': 11}}, ['a == c'])
+      params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])

     # OK to check equality of two nested dicts.
-    params = params_dict.ParamsDict(
-        {'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}, ['b == c'])
+    params = params_dict.ParamsDict({
+        'a': 1,
+        'b': {
+            'a': 10
+        },
+        'c': {
+            'a': 10
+        }
+    }, ['b == c'])

     # Raise error due to inconsistency
     with self.assertRaises(KeyError):
-      params = params_dict.ParamsDict(
-          {'a': 1, 'c': {'a': 10}}, ['a == c.a'])
+      params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])

     # Valid rule.
-    params = params_dict.ParamsDict(
-        {'a': 1, 'c': {'a': 1}}, ['a == c.a'])
+    params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])

     # Overridding violates the existing rule, raise error upon validate.
     params.override({'a': 11})

@@ -156,12 +180,20 @@ class ParamsDictTest(tf.test.TestCase):
       params.validate()

     # Valid restrictions with constant.
-    params = params_dict.ParamsDict(
-        {'a': None, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
+    params = params_dict.ParamsDict({
+        'a': None,
+        'c': {
+            'a': 1
+        }
+    }, ['a == None', 'c.a == 1'])
     params.validate()
     with self.assertRaises(KeyError):
-      params = params_dict.ParamsDict(
-          {'a': 4, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
+      params = params_dict.ParamsDict({
+          'a': 4,
+          'c': {
+              'a': 1
+          }
+      }, ['a == None', 'c.a == 1'])


 class ParamsDictIOTest(tf.test.TestCase):
@@ -173,8 +205,14 @@ class ParamsDictIOTest(tf.test.TestCase):
     return temp_file

   def test_save_params_dict_to_yaml(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 10,
+            'c2': 20
+        }
+    })
     output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
     params_dict.save_params_dict_to_yaml(params, output_yaml_file)

@@ -203,7 +241,12 @@ class ParamsDictIOTest(tf.test.TestCase):

   def test_override_params_dict_using_dict(self):
     params = params_dict.ParamsDict({
-        'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+        'a': 1,
+        'b': 2.5,
+        'c': [3, 4],
+        'd': 'hello',
+        'e': False
+    })
     override_dict = {'b': 5.2, 'c': [30, 40]}
     params = params_dict.override_params_dict(
         params, override_dict, is_strict=True)

@@ -215,7 +258,12 @@ class ParamsDictIOTest(tf.test.TestCase):

   def test_override_params_dict_using_yaml_string(self):
     params = params_dict.ParamsDict({
-        'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+        'a': 1,
+        'b': 2.5,
+        'c': [3, 4],
+        'd': 'hello',
+        'e': False
+    })
     override_yaml_string = "'b': 5.2\n'c': [30, 40]"
     params = params_dict.override_params_dict(
         params, override_yaml_string, is_strict=True)

@@ -227,8 +275,18 @@ class ParamsDictIOTest(tf.test.TestCase):

   def test_override_params_dict_using_json_string(self):
     params = params_dict.ParamsDict({
-        'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
-        'd': {'d1': {'d2': 'hello'}}, 'e': False})
+        'a': 1,
+        'b': {
+            'b1': 2,
+            'b2': [2, 3],
+        },
+        'd': {
+            'd1': {
+                'd2': 'hello'
+            }
+        },
+        'e': False
+    })
     override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
     params = params_dict.override_params_dict(
         params, override_json_string, is_strict=True)

@@ -240,8 +298,18 @@ class ParamsDictIOTest(tf.test.TestCase):

   def test_override_params_dict_using_csv_string(self):
     params = params_dict.ParamsDict({
-        'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
-        'd': {'d1': {'d2': 'hello'}}, 'e': False})
+        'a': 1,
+        'b': {
+            'b1': 2,
+            'b2': [2, 3],
+        },
+        'd': {
+            'd1': {
+                'd2': 'hello'
+            }
+        },
+        'e': False
+    })
     override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
     params = params_dict.override_params_dict(
         params, override_csv_string, is_strict=True)

@@ -253,7 +321,12 @@ class ParamsDictIOTest(tf.test.TestCase):

   def test_override_params_dict_using_yaml_file(self):
     params = params_dict.ParamsDict({
-        'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+        'a': 1,
+        'b': 2.5,
+        'c': [3, 4],
+        'd': 'hello',
+        'e': False
+    })
     override_yaml_file = self.write_temp_file(
         'params.yaml', r"""
       b: 5.2

@@ -321,8 +394,7 @@ class IOTest(tf.test.TestCase):

   def test_csv_str_load_unsupported_datatypes(self):
     csv_str = 'a=[[1,2,3],[4,5,6]]'
-    self.assertRaises(ValueError,
-                      params_dict.nested_csv_str_to_json_str,
-                      csv_str)
+    self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
+                      csv_str)

   def test_csv_str_to_json_str_spacing(self):
......
@@ -50,16 +50,13 @@ class StepwiseLrConfig(base_config.Config):

   Attributes:
     name: The name of the learning rate schedule. Defaults to PiecewiseConstant.
-    boundaries: A list of ints of strictly increasing entries.
-      Defaults to None.
+    boundaries: A list of ints of strictly increasing entries. Defaults to None.
     values: A list of floats that specifies the values for the intervals defined
       by `boundaries`. It should have one more element than `boundaries`.
-      The learning rate is computed as follows:
-        [0, boundaries[0]] -> values[0]
-        [boundaries[0], boundaries[1]] -> values[1]
-        [boundaries[n-1], boundaries[n]] -> values[n]
-        [boundaries[n], end] -> values[n+1]
-      Defaults to None.
+      The learning rate is computed as follows: [0, boundaries[0]] ->
+      values[0] [boundaries[0], boundaries[1]] -> values[1]
+      [boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
+      end] -> values[n+1] Defaults to None.
   """
   name: str = 'PiecewiseConstantDecay'
   boundaries: Optional[List[int]] = None
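The boundaries/values mapping in the docstring above can be checked directly against the Keras schedule the config is named after (the mapping of this config onto `tf.keras.optimizers.schedules.PiecewiseConstantDecay` is an assumption of this sketch):

import tensorflow as tf

lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[10000, 20000], values=[0.1, 0.01, 0.001])
print(lr(0).numpy())      # 0.1   -> values[0] up to boundaries[0]
print(lr(15000).numpy())  # 0.01  -> values[1] between the two boundaries
print(lr(25000).numpy())  # 0.001 -> values[2] after the last boundary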
@@ -74,13 +71,12 @@ class ExponentialLrConfig(base_config.Config):

   Attributes:
     name: The name of the learning rate schedule. Defaults to ExponentialDecay.
-    initial_learning_rate: A float. The initial learning rate. Defaults to
-      None.
-    decay_steps: A positive integer that is used for decay computation.
-      Defaults to None.
+    initial_learning_rate: A float. The initial learning rate. Defaults to None.
+    decay_steps: A positive integer that is used for decay computation. Defaults
+      to None.
     decay_rate: A float. Defaults to None.
     staircase: A boolean, if true, learning rate is decreased at discreate
       intervals. Defaults to False.
   """
   name: str = 'ExponentialDecay'
   initial_learning_rate: Optional[float] = None

@@ -97,14 +93,13 @@ class PolynomialLrConfig(base_config.Config):

   Attributes:
     name: The name of the learning rate schedule. Defaults to PolynomialDecay.
-    initial_learning_rate: A float. The initial learning rate. Defaults to
-      None.
-    decay_steps: A positive integer that is used for decay computation.
-      Defaults to None.
+    initial_learning_rate: A float. The initial learning rate. Defaults to None.
+    decay_steps: A positive integer that is used for decay computation. Defaults
+      to None.
     end_learning_rate: A float. The minimal end learning rate.
     power: A float. The power of the polynomial. Defaults to linear, 1.0.
     cycle: A boolean, whether or not it should cycle beyond decay_steps.
       Defaults to False.
   """
   name: str = 'PolynomialDecay'
   initial_learning_rate: Optional[float] = None

@@ -123,12 +118,11 @@ class CosineLrConfig(base_config.Config):

   Attributes:
     name: The name of the learning rate schedule. Defaults to CosineDecay.
-    initial_learning_rate: A float. The initial learning rate. Defaults to
-      None.
-    decay_steps: A positive integer that is used for decay computation.
-      Defaults to None.
+    initial_learning_rate: A float. The initial learning rate. Defaults to None.
+    decay_steps: A positive integer that is used for decay computation. Defaults
+      to None.
     alpha: A float. Minimum learning rate value as a fraction of
       initial_learning_rate.
   """
   name: str = 'CosineDecay'
   initial_learning_rate: Optional[float] = None

@@ -173,4 +167,3 @@ class PolynomialWarmupConfig(base_config.Config):
   name: str = 'polynomial'
   power: float = 1
   warmup_steps: Optional[int] = None
@@ -50,12 +50,11 @@ class OptimizerConfigTest(tf.test.TestCase):
             'type': 'linear'
         }
     })
-    self.assertEqual(opt_config.optimizer.get(),
-                     opt_cfg.SGDConfig())
+    self.assertEqual(opt_config.optimizer.get(), opt_cfg.SGDConfig())
     self.assertEqual(opt_config.learning_rate.get(),
                      lr_cfg.PolynomialLrConfig())
-    self.assertEqual(opt_config.warmup.get(),
-                     lr_cfg.LinearWarmupConfig())
+    self.assertEqual(opt_config.warmup.get(), lr_cfg.LinearWarmupConfig())


 if __name__ == '__main__':
   tf.test.main()
@@ -72,7 +72,7 @@ class AdamConfig(base_config.Config):
     beta_2: decay rate for 2st order moments.
     epsilon: epsilon value used for numerical stability in Adam optimizer.
     amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
       the paper "On the Convergence of Adam and beyond".
   """
   name: str = "Adam"
   beta_1: float = 0.9

@@ -91,12 +91,12 @@ class AdamWeightDecayConfig(base_config.Config):
     beta_2: decay rate for 2st order moments.
     epsilon: epsilon value used for numerical stability in the optimizer.
     amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
       the paper "On the Convergence of Adam and beyond".
     weight_decay_rate: float. Weight decay rate. Default to 0.
     include_in_weight_decay: list[str], or None. List of weight names to include
       in weight decay.
     include_in_weight_decay: list[str], or None. List of weight names to not
       include in weight decay.
   """
   name: str = "AdamWeightDecay"
   beta_1: float = 0.9

@@ -123,12 +123,11 @@ class LAMBConfig(base_config.Config):
     epsilon: epsilon value used for numerical stability in LAMB optimizer.
     weight_decay_rate: float. Weight decay rate. Default to 0.
     exclude_from_weight_decay: List of regex patterns of variables excluded from
-      weight decay. Variables whose name contain a
-      substring matching the pattern will be excluded.
+      weight decay. Variables whose name contain a substring matching the
+      pattern will be excluded.
     exclude_from_layer_adaptation: List of regex patterns of variables excluded
-      from layer adaptation. Variables whose name
-      contain a substring matching the pattern will
-      be excluded.
+      from layer adaptation. Variables whose name contain a substring matching
+      the pattern will be excluded.
   """
   name: str = "LAMB"
   beta_1: float = 0.9
......
@@ -131,8 +131,9 @@ class OptimizerFactory(object):
     rate built using self.build_lr() is passed as an argument to this method.

     Args:
-      lr: A floating point value, or
-        a tf.keras.optimizers.schedules.LearningRateSchedule instance.
+      lr: A floating point value, or a
+        tf.keras.optimizers.schedules.LearningRateSchedule instance.

     Returns:
       tf.keras.optimizers.Optimizer instance.
     """

@@ -142,4 +143,3 @@ class OptimizerFactory(object):
     optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
     return optimizer
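A condensed version of the flow the tests below exercise; the params dict mirrors the test fixtures in this change, while the import path for `optimizer_factory` and the `build_optimizer` method name are assumptions of this sketch:

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

params = {
    'optimizer': {
        'type': 'sgd',
        'sgd': {
            'momentum': 0.9
        }
    },
    'learning_rate': {
        'type': 'constant'
    }
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()       # float or LearningRateSchedule
optimizer = opt_factory.build_optimizer(lr)  # tf.keras.optimizers.Optimizer
print(optimizer.get_config())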
@@ -25,12 +25,7 @@ from official.modeling.optimization.configs import optimization_config

 class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):

-  @parameterized.parameters(
-      ('sgd'),
-      ('rmsprop'),
-      ('adam'),
-      ('adamw'),
-      ('lamb'))
+  @parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'))
   def test_optimizers(self, optimizer_type):
     params = {
         'optimizer': {

@@ -56,20 +51,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(expected_optimizer_config, optimizer.get_config())

   def test_missing_types(self):
-    params = {
-        'optimizer': {
-            'type': 'sgd',
-            'sgd': {'momentum': 0.9}
-        }
-    }
+    params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
     with self.assertRaises(ValueError):
       optimizer_factory.OptimizerFactory(
           optimization_config.OptimizationConfig(params))
     params = {
         'learning_rate': {
             'type': 'stepwise',
-            'stepwise': {'boundaries': [10000, 20000],
-                         'values': [0.1, 0.01, 0.001]}
+            'stepwise': {
+                'boundaries': [10000, 20000],
+                'values': [0.1, 0.01, 0.001]
+            }
         }
     }
     with self.assertRaises(ValueError):

@@ -80,22 +72,20 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
-            'sgd': {'momentum': 0.9}
+            'sgd': {
+                'momentum': 0.9
+            }
         },
         'learning_rate': {
             'type': 'stepwise',
-            'stepwise': {'boundaries': [10000, 20000],
-                         'values': [0.1, 0.01, 0.001]}
+            'stepwise': {
+                'boundaries': [10000, 20000],
+                'values': [0.1, 0.01, 0.001]
+            }
         }
     }
-    expected_lr_step_values = [
-        [0, 0.1],
-        [5000, 0.1],
-        [10000, 0.1],
-        [10001, 0.01],
-        [20000, 0.01],
-        [20001, 0.001]
-    ]
+    expected_lr_step_values = [[0, 0.1], [5000, 0.1], [10000, 0.1],
+                               [10001, 0.01], [20000, 0.01], [20001, 0.001]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
@@ -107,28 +97,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
-            'sgd': {'momentum': 0.9}
+            'sgd': {
+                'momentum': 0.9
+            }
         },
         'learning_rate': {
             'type': 'stepwise',
-            'stepwise': {'boundaries': [10000, 20000],
-                         'values': [0.1, 0.01, 0.001]}
+            'stepwise': {
+                'boundaries': [10000, 20000],
+                'values': [0.1, 0.01, 0.001]
+            }
         },
         'warmup': {
             'type': 'linear',
-            'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
+            'linear': {
+                'warmup_steps': 500,
+                'warmup_learning_rate': 0.01
+            }
         }
     }
-    expected_lr_step_values = [
-        [0, 0.01],
-        [250, 0.055],
-        [500, 0.1],
-        [5500, 0.1],
-        [10000, 0.1],
-        [10001, 0.01],
-        [20000, 0.01],
-        [20001, 0.001]
-    ]
+    expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5500, 0.1],
+                               [10000, 0.1], [10001, 0.01], [20000, 0.01],
+                               [20001, 0.001]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()

@@ -140,7 +130,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
-            'sgd': {'momentum': 0.9}
+            'sgd': {
+                'momentum': 0.9
+            }
         },
         'learning_rate': {
             'type': 'exponential',

@@ -170,7 +162,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
-            'sgd': {'momentum': 0.9}
+            'sgd': {
+                'momentum': 0.9
+            }
         },
         'learning_rate': {
             'type': 'polynomial',

@@ -194,7 +188,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
-            'sgd': {'momentum': 0.9}
+            'sgd': {
+                'momentum': 0.9
+            }
         },
         'learning_rate': {
             'type': 'cosine',

@@ -204,11 +200,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
             }
         }
     }
-    expected_lr_step_values = [[0, 0.1],
-                               [250, 0.08535534],
-                               [500, 0.04999999],
-                               [750, 0.01464466],
-                               [1000, 0]]
+    expected_lr_step_values = [[0, 0.1], [250, 0.08535534], [500, 0.04999999],
+                               [750, 0.01464466], [1000, 0]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
@@ -220,7 +213,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
-            'sgd': {'momentum': 0.9}
+            'sgd': {
+                'momentum': 0.9
+            }
         },
         'learning_rate': {
             'type': 'constant',

@@ -250,28 +245,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
-            'sgd': {'momentum': 0.9}
+            'sgd': {
+                'momentum': 0.9
+            }
         },
         'learning_rate': {
             'type': 'stepwise',
-            'stepwise': {'boundaries': [10000, 20000],
-                         'values': [0.1, 0.01, 0.001]}
+            'stepwise': {
+                'boundaries': [10000, 20000],
+                'values': [0.1, 0.01, 0.001]
+            }
         },
         'warmup': {
             'type': 'polynomial',
-            'polynomial': {'warmup_steps': 500, 'power': 2.}
+            'polynomial': {
+                'warmup_steps': 500,
+                'power': 2.
+            }
         }
     }
-    expected_lr_step_values = [
-        [0, 0.0],
-        [250, 0.025],
-        [500, 0.1],
-        [5500, 0.1],
-        [10000, 0.1],
-        [10001, 0.01],
-        [20000, 0.01],
-        [20001, 0.001]
-    ]
+    expected_lr_step_values = [[0, 0.0], [250, 0.025], [500, 0.1], [5500, 0.1],
+                               [10000, 0.1], [10001, 0.01], [20000, 0.01],
+                               [20001, 0.001]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()

@@ -279,5 +274,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     for step, value in expected_lr_step_values:
       self.assertAlmostEqual(lr(step).numpy(), value)


 if __name__ == '__main__':
   tf.test.main()
@@ -21,7 +21,7 @@ import tensorflow as tf
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
-                        loss_scale="dynamic"):
+                        loss_scale='dynamic'):
   """Configures optimizer object with performance options."""
   if use_float16:
     # Wraps optimizer with a LossScaleOptimizer. This is done automatically

@@ -47,10 +47,9 @@ def set_mixed_precision_policy(dtype, loss_scale=None):
         'mixed_float16', loss_scale=loss_scale)
     tf.keras.mixed_precision.experimental.set_policy(policy)
   elif dtype == tf.bfloat16:
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'mixed_bfloat16')
+    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
     tf.keras.mixed_precision.experimental.set_policy(policy)
   elif dtype == tf.float32:
     tf.keras.mixed_precision.experimental.set_policy('float32')
   else:
-    raise ValueError("Unexpected dtype: %s" % dtype)
+    raise ValueError('Unexpected dtype: %s' % dtype)
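Typical use of the two helpers in this file, as a sketch; only the function signatures come from the hunks above, while the call ordering and values are illustrative:

import tensorflow as tf

from official.modeling import performance

# Set the global Keras mixed-precision policy, then wrap the optimizer so the
# loss is scaled dynamically when running in float16.
performance.set_mixed_precision_policy(tf.float16, loss_scale='dynamic')
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
optimizer = performance.configure_optimizer(
    optimizer, use_float16=True, loss_scale='dynamic')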
@@ -29,8 +29,7 @@ from official.modeling import activations
     None,
     "tf.keras.layers.Layer supports multiple positional args and kwargs as "
     "input tensors. pack/unpack inputs to override __call__ is no longer "
-    "needed."
-)
+    "needed.")
 def pack_inputs(inputs):
   """Pack a list of `inputs` tensors to a tuple.

@@ -55,8 +54,7 @@ def pack_inputs(inputs):
     None,
     "tf.keras.layers.Layer supports multiple positional args and kwargs as "
     "input tensors. pack/unpack inputs to override __call__ is no longer "
-    "needed."
-)
+    "needed.")
 def unpack_inputs(inputs):
   """unpack a tuple of `inputs` tensors to a tuple.
......
@@ -133,15 +133,9 @@ class SummaryWriter(object):

 class DistributedExecutor(object):
-  """Interface to train and eval models with tf.distribute.Strategy.
-  """
+  """Interface to train and eval models with tf.distribute.Strategy."""

-  def __init__(self,
-               strategy,
-               params,
-               model_fn,
-               loss_fn,
-               is_multi_host=False):
+  def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
     """Constructor.

     Args:

@@ -293,8 +287,7 @@ class DistributedExecutor(object):
       raise ValueError('steps should be an Tensor. Python object may cause '
                        'retracing.')

-      per_replica_losses = strategy.run(
-          replicated_step, args=(next(iterator),))
+      per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
       for _ in tf.range(num_steps - 1):
         per_replica_losses = strategy.run(
             replicated_step, args=(next(iterator),))

@@ -368,6 +361,7 @@ class DistributedExecutor(object):
         available checkpoints. If `False`, will do the evaluation once after the
         final step.
       save_config: bool. Whether to save params to model_dir.

     Returns:
       The training loss and eval metrics.
     """
@@ -477,16 +471,15 @@ class DistributedExecutor(object):

     # Step-0 operations
     if current_step == 0 and not latest_checkpoint_file:
-      _save_checkpoint(
-          checkpoint, model_dir, checkpoint_name.format(step=current_step))
+      _save_checkpoint(checkpoint, model_dir,
+                       checkpoint_name.format(step=current_step))
       if test_step:
         eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
-        eval_metric_result = self._run_evaluation(
-            test_step, current_step, eval_metric, eval_iterator)
-        logging.info(
-            'Step: %s evalation metric = %s.', current_step, eval_metric_result)
-        test_summary_writer(
-            metrics=eval_metric_result, step=optimizer.iterations)
+        eval_metric_result = self._run_evaluation(test_step, current_step,
+                                                  eval_metric, eval_iterator)
+        logging.info('Step: %s evalation metric = %s.', current_step,
+                     eval_metric_result)
+        test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
         reset_states(eval_metric)

     logging.info('Training started')

@@ -519,8 +512,7 @@ class DistributedExecutor(object):
       else:
         train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
       logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
-                   current_step, total_steps, train_loss,
-                   train_metric_result)
+                   current_step, total_steps, train_loss, train_metric_result)
       train_summary_writer(
           metrics=train_metric_result, step=optimizer.iterations)

@@ -561,8 +553,7 @@ class DistributedExecutor(object):
       eval_metric_result = self._run_evaluation(test_step, current_step,
                                                 eval_metric, eval_iterator)
       logging.info('Final evaluation metric = %s.', eval_metric_result)
-      test_summary_writer(
-          metrics=eval_metric_result, step=optimizer.iterations)
+      test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
     self.train_summary_writer.close()
     self.eval_summary_writer.close()

@@ -696,9 +687,8 @@ class DistributedExecutor(object):
       reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
       current_step = reader.get_tensor(
           'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
-      logging.info(
-          'Checkpoint file %s found and restoring from '
-          'checkpoint', checkpoint_path)
+      logging.info('Checkpoint file %s found and restoring from '
+                   'checkpoint', checkpoint_path)
       status = checkpoint.restore(checkpoint_path)
       status.expect_partial().assert_existing_objects_matched()

@@ -755,8 +745,8 @@ class ExecutorBuilder(object):
   """

   def __init__(self, strategy_type=None, strategy_config=None):
-    _ = distribution_utils.configure_cluster(
-        strategy_config.worker_hosts, strategy_config.task_index)
+    _ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
+                                             strategy_config.task_index)
     """Constructor.

     Args:
......
@@ -26,10 +26,7 @@ from official.nlp.bert import configs
 class AlbertConfig(configs.BertConfig):
   """Configuration for `ALBERT`."""

-  def __init__(self,
-               num_hidden_groups=1,
-               inner_group_num=1,
-               **kwargs):
+  def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
     """Constructs AlbertConfig.

     Args:
......
@@ -18,6 +18,7 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

+# Import libraries
 from absl import app
 from absl import flags
 import tensorflow as tf
......
@@ -21,6 +21,7 @@ from __future__ import print_function
 import json
 import os

+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
......