Commit 3d61d6b3 authored by qianyj

initial files for ResNet50

parent d3a70caf
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.actions.conditional_action."""
from orbit import actions
import tensorflow as tf
class ConditionalActionTest(tf.test.TestCase):
def test_conditional_action(self):
# Define a function to raise an AssertionError, since we can't in a lambda.
def raise_assertion(arg):
raise AssertionError(str(arg))
conditional_action = actions.ConditionalAction(
condition=lambda x: x['value'], action=raise_assertion)
conditional_action({'value': False}) # Nothing is raised.
with self.assertRaises(AssertionError) as ctx:
conditional_action({'value': True})
self.assertEqual(str(ctx.exception), "{'value': True}")
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the `ExportSavedModel` action and associated helper classes."""
import re
from typing import Callable, Optional
import tensorflow as tf
def _id_key(filename):
_, id_num = filename.rsplit('-', maxsplit=1)
return int(id_num)
def _find_managed_files(base_name):
r"""Returns all files matching '{base_name}-\d+', in sorted order."""
managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')
filenames = tf.io.gfile.glob(f'{base_name}-*')
filenames = filter(managed_file_regex.match, filenames)
return sorted(filenames, key=_id_key)
class _CounterIdFn:
"""Implements a counter-based ID function for `ExportFileManager`."""
def __init__(self, base_name: str):
managed_files = _find_managed_files(base_name)
self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0
def __call__(self):
output = self.value
self.value += 1
return output
class ExportFileManager:
"""Utility class that manages a group of files with a shared base name.
For actions like SavedModel exporting, there are potentially many different
file naming and cleanup strategies that may be desirable. This class provides
a basic interface allowing SavedModel export to be decoupled from these
details, and a default implementation that should work for many basic
scenarios. Users may subclass this class to alter behavior and define more
customized naming and cleanup strategies.
"""
def __init__(self,
base_name: str,
max_to_keep: int = 5,
next_id_fn: Optional[Callable[[], int]] = None):
"""Initializes the instance.
Args:
base_name: A shared base name for file names generated by this class.
max_to_keep: The maximum number of files matching `base_name` to keep
after each call to `clean_up`. The most recent (as determined by the
integer IDs appended to `base_name`) `max_to_keep` files are preserved;
the rest are deleted. If < 0, all files are preserved.
next_id_fn: An optional callable that returns integer IDs to append to the
base name (formatted as `'{base_name}-{id}'`). The order of these integers
is used to sort files and to determine the oldest ones deleted by
`clean_up`. If not supplied, a default ID based on an incrementing counter
is used. One common alternative may be to use the current global step
count, for instance passing `next_id_fn=global_step.numpy`.
"""
self._base_name = base_name
self._max_to_keep = max_to_keep
self._next_id_fn = next_id_fn or _CounterIdFn(base_name)
@property
def managed_files(self):
"""Returns all files managed by this instance, in sorted order.
Returns:
The list of files matching the `base_name` provided when constructing this
`ExportFileManager` instance, sorted in increasing integer order of the
IDs returned by `next_id_fn`.
"""
return _find_managed_files(self._base_name)
def clean_up(self):
"""Cleans up old files matching `{base_name}-*`.
The most recent `max_to_keep` files are preserved.
"""
if self._max_to_keep < 0:
return
for filename in self.managed_files[:-self._max_to_keep]:
tf.io.gfile.rmtree(filename)
def next_name(self) -> str:
"""Returns a new file name based on `base_name` and `next_id_fn()`."""
return f'{self._base_name}-{self._next_id_fn()}'
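# Usage sketch (not part of the library; `global_step` below is a hypothetical
# variable): the default counter-based IDs can be swapped for the global step,
# so each export directory is named after the training step that produced it.
#
#   global_step = tf.Variable(0, dtype=tf.int64)
#   manager = ExportFileManager(
#       '/tmp/exports/my_model', max_to_keep=3,
#       next_id_fn=lambda: int(global_step.numpy()))
#   export_dir = manager.next_name()  # e.g. '/tmp/exports/my_model-0'
#   manager.clean_up()                # keeps only the 3 highest-ID exports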
class ExportSavedModel:
"""Action that exports the given model as a SavedModel."""
def __init__(self,
model: tf.Module,
file_manager: ExportFileManager,
signatures,
options: Optional[tf.saved_model.SaveOptions] = None):
"""Initializes the instance.
Args:
model: The model to export.
file_manager: An instance of `ExportFileManager` (or a subclass), that
provides file naming and cleanup functionality.
signatures: The signatures to forward to `tf.saved_model.save()`.
options: Optional options to forward to `tf.saved_model.save()`.
"""
self.model = model
self.file_manager = file_manager
self.signatures = signatures
self.options = options
def __call__(self, _):
"""Exports the SavedModel."""
export_dir = self.file_manager.next_name()
tf.saved_model.save(self.model, export_dir, self.signatures, self.options)
self.file_manager.clean_up()
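# Usage sketch (illustrative only; `model`, `trainer`, and `global_step` are
# assumed placeholders): since `ExportSavedModel.__call__` accepts a single
# output argument, an instance can be passed directly as a train or eval
# action when constructing an `orbit.Controller`.
#
#   file_manager = ExportFileManager('/model/dir/export', max_to_keep=5)
#   export_action = ExportSavedModel(
#       model, file_manager=file_manager, signatures=model.__call__)
#   controller = orbit.Controller(
#       trainer=trainer, global_step=global_step, steps_per_loop=100,
#       train_actions=[export_action])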
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.actions.export_saved_model."""
import os
from orbit import actions
import tensorflow as tf
def _id_key(name):
_, id_num = name.rsplit('-', maxsplit=1)
return int(id_num)
def _id_sorted_file_base_names(dir_path):
return sorted(tf.io.gfile.listdir(dir_path), key=_id_key)
class TestModel(tf.Module):
def __init__(self):
self.value = tf.Variable(0)
@tf.function(input_signature=[])
def __call__(self):
return self.value
class ExportSavedModelTest(tf.test.TestCase):
def test_export_file_manager_default_ids(self):
directory = self.create_tempdir()
base_name = os.path.join(directory.full_path, 'basename')
manager = actions.ExportFileManager(base_name, max_to_keep=3)
self.assertLen(tf.io.gfile.listdir(directory.full_path), 0)
directory.create_file(manager.next_name())
manager.clean_up() # Shouldn't do anything...
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
directory.create_file(manager.next_name())
manager.clean_up() # Shouldn't do anything...
self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
directory.create_file(manager.next_name())
manager.clean_up() # Shouldn't do anything...
self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 4)
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-0', 'basename-1', 'basename-2', 'basename-3'])
manager.clean_up() # Should delete file with lowest ID.
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-1', 'basename-2', 'basename-3'])
manager = actions.ExportFileManager(base_name, max_to_keep=3)
self.assertEqual(os.path.basename(manager.next_name()), 'basename-4')
def test_export_file_manager_custom_ids(self):
directory = self.create_tempdir()
base_name = os.path.join(directory.full_path, 'basename')
id_num = 0
def next_id():
return id_num
manager = actions.ExportFileManager(
base_name, max_to_keep=2, next_id_fn=next_id)
self.assertLen(tf.io.gfile.listdir(directory.full_path), 0)
id_num = 30
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
manager.clean_up() # Shouldn't do anything...
self.assertEqual(
_id_sorted_file_base_names(directory.full_path), ['basename-30'])
id_num = 200
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
manager.clean_up() # Shouldn't do anything...
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-30', 'basename-200'])
id_num = 1000
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-30', 'basename-200', 'basename-1000'])
manager.clean_up() # Should delete file with lowest ID.
self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-200', 'basename-1000'])
def test_export_file_manager_managed_files(self):
directory = self.create_tempdir()
directory.create_file('basename-5')
directory.create_file('basename-10')
directory.create_file('basename-50')
directory.create_file('basename-1000')
directory.create_file('basename-9')
directory.create_file('basename-10-suffix')
base_name = os.path.join(directory.full_path, 'basename')
manager = actions.ExportFileManager(base_name, max_to_keep=3)
self.assertLen(manager.managed_files, 5)
self.assertEqual(manager.next_name(), f'{base_name}-1001')
manager.clean_up()
self.assertEqual(
manager.managed_files,
[f'{base_name}-10', f'{base_name}-50', f'{base_name}-1000'])
def test_export_saved_model(self):
directory = self.create_tempdir()
base_name = os.path.join(directory.full_path, 'basename')
file_manager = actions.ExportFileManager(base_name, max_to_keep=2)
model = TestModel()
export_action = actions.ExportSavedModel(
model, file_manager=file_manager, signatures=model.__call__)
model.value.assign(3)
self.assertEqual(model(), 3)
self.assertEmpty(file_manager.managed_files)
export_action({})
self.assertLen(file_manager.managed_files, 1)
reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
self.assertEqual(reloaded_model(), 3)
model.value.assign(5)
self.assertEqual(model(), 5)
export_action({})
self.assertLen(file_manager.managed_files, 2)
reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
self.assertEqual(reloaded_model(), 5)
model.value.assign(7)
self.assertEqual(model(), 7)
export_action({})
self.assertLen(file_manager.managed_files, 2) # Still 2, due to clean up.
reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
self.assertEqual(reloaded_model(), 7)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the `NewBestMetric` condition and associated helper classes."""
import json
import os
import sys
from typing import Any, Callable, Optional, Union
import uuid
from orbit import runner
from orbit import utils
import tensorflow as tf
MetricFn = Callable[[runner.Output], Union[float, tf.Tensor]]
class NewBestMetric:
"""Condition that is satisfied when a new best metric is achieved.
This class keeps track of the best metric value seen so far, optionally in a
persistent (preemption-safe) way.
Two methods are provided, which each satisfy the `Action` protocol: `test` for
only testing whether a new best metric is achieved by a given train/eval
output, and `commit`, which both tests and records the new best metric value
if it is achieved. These separate methods enable the same `NewBestMetric`
instance to be reused as a condition multiple times, and can also provide
additional preemption/failure safety. For example, to avoid updating the best
metric if a model export fails or is pre-empted:
new_best_metric = orbit.actions.NewBestMetric(
'accuracy', filename='/model/dir/best_metric')
action = orbit.actions.ConditionalAction(
condition=new_best_metric.test,
action=[
orbit.actions.ExportSavedModel(...),
new_best_metric.commit
])
The default `__call__` implementation is equivalent to `commit`.
This class is safe to use in multi-client settings if all clients can be
guaranteed to compute the same metric. However when saving metrics it may be
helpful to avoid unnecessary writes by setting the `write_value` parameter to
`False` for most clients.
Attributes:
metric: The metric passed to __init__ (may be a string key or a callable
that can be applied to train/eval output).
higher_is_better: Whether higher metric values are better.
"""
def __init__(self,
metric: Union[str, MetricFn],
higher_is_better: bool = True,
filename: Optional[str] = None,
write_metric=True):
"""Initializes the instance.
Args:
metric: Either a string key name to use to look up a metric (assuming the
train/eval output is a dictionary), or a callable that accepts the
train/eval output and returns a metric value.
higher_is_better: Whether higher metric values are better. If `True`, a
new best metric is achieved when the metric value is strictly greater
than the previous best metric. If `False`, a new best metric is achieved
when the metric value is strictly less than the previous best metric.
filename: A filename to use for storage of the best metric value seen so
far, to allow persistence of the value across preemptions. If `None`
(default), values aren't persisted.
write_metric: If `filename` is set, this controls whether this instance
will write new best metric values to the file, or just read from the
file to obtain the initial value. Setting this to `False` for most
clients in some multi-client setups can avoid unnecessary file writes.
Has no effect if `filename` is `None`.
"""
self.metric = metric
self.higher_is_better = higher_is_better
float_max = sys.float_info.max
self._best_value = JSONPersistedValue(
initial_value=-float_max if higher_is_better else float_max,
filename=filename,
write_value=write_metric)
def __call__(self, output: runner.Output) -> bool:
"""Tests `output` and updates the current best value if necessary.
This is equivalent to `commit` below.
Args:
output: The train or eval output to test.
Returns:
`True` if `output` contains a new best metric value, `False` otherwise.
"""
return self.commit(output)
def metric_value(self, output: runner.Output) -> float:
"""Computes the metric value for the given `output`."""
if callable(self.metric):
value = self.metric(output)
else:
value = output[self.metric]
return float(utils.get_value(value))
@property
def best_value(self) -> float:
"""Returns the best metric value seen so far."""
return self._best_value.read()
def test(self, output: runner.Output) -> bool:
"""Tests `output` to see if it contains a new best metric value.
If `output` does contain a new best metric value, this method does *not*
save it (i.e., calling this method multiple times in a row with the same
`output` will continue to return `True`).
Args:
output: The train or eval output to test.
Returns:
`True` if `output` contains a new best metric value, `False` otherwise.
"""
metric_value = self.metric_value(output)
if self.higher_is_better:
if metric_value > self.best_value:
return True
else: # Lower is better.
if metric_value < self.best_value:
return True
return False
def commit(self, output: runner.Output) -> bool:
"""Tests `output` and updates the current best value if necessary.
Unlike `test` above, if `output` does contain a new best metric value, this
method *does* save it (i.e., subsequent calls to this method with the same
`output` will return `False`).
Args:
output: The train or eval output to test.
Returns:
`True` if `output` contains a new best metric value, `False` otherwise.
"""
if self.test(output):
self._best_value.write(self.metric_value(output))
return True
return False
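# Illustration (a minimal sketch, not part of the module): `test` never
# records the metric value, while `commit` does, so repeated calls differ:
#
#   metric = NewBestMetric('accuracy')
#   metric.test({'accuracy': 0.9})    # True
#   metric.test({'accuracy': 0.9})    # True again; best value was not updated
#   metric.commit({'accuracy': 0.9})  # True; best value is now 0.9
#   metric.commit({'accuracy': 0.9})  # False; 0.9 is not strictly better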
class JSONPersistedValue:
"""Represents a value that is persisted via a file-based backing store.
The value must be JSON-serializable. Each time the value is updated, it will
be written to the backing file. It is only read from the file at
initialization.
"""
def __init__(self,
initial_value: Any,
filename: Optional[str],
write_value: bool = True):
"""Initializes the instance.
Args:
initial_value: The initial value to use if no backing file exists or was
given. This must be a JSON-serializable value (possibly nested
combination of lists, dicts, and primitive values).
filename: The path to use for persistent storage of the value. This may be
`None`, in which case the value is not stable across preemptions.
write_value: If `True`, new values will be written to `filename` on calls
to `write()`. If `False`, `filename` is only read once to restore any
persisted value, and new values will not be written to it. This can be
useful in certain multi-client settings to avoid race conditions or
excessive file writes. If `filename` is `None`, this parameter has no
effect.
"""
self._value = None
self._filename = filename
self._write_value = write_value
if self._filename is not None:
if tf.io.gfile.exists(self._filename):
if tf.io.gfile.stat(self._filename).length > 0:
with tf.io.gfile.GFile(self._filename, 'r') as f:
self._value = json.load(f)
elif self._write_value:
tf.io.gfile.makedirs(os.path.dirname(self._filename))
if self._value is None:
self.write(initial_value)
def read(self):
"""Returns the value."""
return self._value
def write(self, value):
"""Writes the value, updating the backing store if one was provided."""
self._value = value
if self._filename is not None and self._write_value:
# To achieve atomic writes, we first write to a temporary file, and then
# rename it to `self._filename`.
tmp_filename = f'{self._filename}.tmp.{uuid.uuid4().hex}'
with tf.io.gfile.GFile(tmp_filename, 'w') as f:
json.dump(self._value, f)
tf.io.gfile.rename(tmp_filename, self._filename, overwrite=True)
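# Usage sketch (illustrative; the path below is hypothetical): the stored value
# survives process restarts, which is what makes `NewBestMetric` preemption
# safe when a `filename` is supplied.
#
#   best_loss = JSONPersistedValue(
#       initial_value=1e9, filename='/model/dir/best_loss')
#   best_loss.write(0.25)  # persisted to disk via an atomic rename
#   # After a restart, a new instance reads 0.25 back from the file and
#   # ignores the `initial_value` argument:
#   best_loss = JSONPersistedValue(
#       initial_value=1e9, filename='/model/dir/best_loss')
#   best_loss.read()  # -> 0.25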
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.actions.new_best_metric."""
import os
from orbit import actions
import tensorflow as tf
class NewBestMetricTest(tf.test.TestCase):
def test_new_best_metric_higher_is_better(self):
new_best_metric = actions.NewBestMetric(
lambda x: x['value'], higher_is_better=True)
self.assertTrue(new_best_metric.test({'value': 0.0}))
self.assertTrue(new_best_metric.commit({'value': 0.0}))
self.assertFalse(new_best_metric.test({'value': 0.0}))
self.assertTrue(new_best_metric.test({'value': 1.0}))
def test_new_best_metric_lower_is_better(self):
new_best_metric = actions.NewBestMetric('value', higher_is_better=False)
self.assertTrue(new_best_metric.test({'value': 0.0}))
self.assertTrue(new_best_metric.commit({'value': 0.0}))
self.assertFalse(new_best_metric.test({'value': 0.0}))
self.assertTrue(new_best_metric.test({'value': -1.0}))
def test_new_best_metric_persistence(self):
backing_file = self.create_tempfile()
new_best_metric = actions.NewBestMetric(
'value',
higher_is_better=True,
filename=backing_file.full_path,
write_metric=False)
self.assertTrue(new_best_metric.test({'value': 0.0}))
self.assertTrue(new_best_metric.commit({'value': 0.0}))
self.assertFalse(new_best_metric.test({'value': 0.0}))
new_best_metric = actions.NewBestMetric(
'value', higher_is_better=True, filename=backing_file.full_path)
self.assertLess(new_best_metric.best_value, 0.0)
self.assertTrue(new_best_metric.commit({'value': 5.0}))
self.assertEqual(new_best_metric.best_value, 5.0)
new_best_metric = actions.NewBestMetric(
'value', higher_is_better=True, filename=backing_file.full_path)
self.assertEqual(new_best_metric.best_value, 5.0)
def test_json_persisted_value(self):
tempfile = self.create_tempfile().full_path
value = {'a': 1, 'b': 2}
persisted_value = actions.JSONPersistedValue(value, tempfile)
# The initial value is used since tempfile is empty.
self.assertEqual(persisted_value.read(), value)
persisted_value = actions.JSONPersistedValue('ignored', tempfile)
# Initial value of 'ignored' is ignored, since there's a value in tempfile.
self.assertEqual(persisted_value.read(), value)
value = [1, 2, 3]
persisted_value.write(value)
# Now that a new value is written, it gets read on initialization.
persisted_value = actions.JSONPersistedValue(['also ignored'], tempfile)
self.assertEqual(persisted_value.read(), value)
# Writes can be disabled.
persisted_value = actions.JSONPersistedValue(
'ignored', tempfile, write_value=False)
self.assertEqual(persisted_value.read(), value)
persisted_value.write("won't get persisted")
persisted_value = actions.JSONPersistedValue(
'ignored', tempfile, write_value=False)
self.assertEqual(persisted_value.read(), value)
def test_json_persisted_value_create_dirs(self):
tempfile = os.path.join(self.create_tempdir().full_path, 'subdir/value')
value = {'a': 1, 'b': 2}
# The directory is not created if write_value=False.
actions.JSONPersistedValue(value, tempfile, write_value=False)
self.assertFalse(tf.io.gfile.exists(os.path.dirname(tempfile)))
actions.JSONPersistedValue(value, tempfile)
self.assertTrue(tf.io.gfile.exists(tempfile))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a `Controller` class for managing the outer training loop."""
import pprint
import time
from typing import Callable, List, Optional, Union
from absl import logging
from orbit import runner
from orbit import utils
import psutil
import os
import tensorflow as tf
def _log(message: str):
"""Logs `message` to the `info` log, and also prints to stdout."""
logging.info(message)
print(message)
logging.ABSLLogger.register_frame_to_skip(__file__, _log.__name__)
def _format_output(output, indent=4):
"""Formats `output`, either on one line, or indented across multiple lines."""
formatted = pprint.pformat(output)
lines = formatted.splitlines()
if len(lines) == 1:
return formatted
lines = [" " * indent + line for line in lines]
return "\n" + "\n".join(lines)
Action = Callable[[runner.Output], None]
class Controller:
"""Class that controls the outer loop of model training and evaluation.
Orbit divides training and evaluation into "inner" and "outer" loops. Inner
loops are implemented by users in the form of `AbstractTrainer` and
`AbstractEvaluator` subclasses, and define how to run a given number of
training or evaluation steps. The outer loop is provided by this `Controller`,
and interleaves calls to the user-provided inner loops with additional actions
such as saving checkpoints, running evaluations, writing summaries, as well as
(optionally) user provided `Action`s (see below).
There are four top-level "outer loops" provided:
- `train`, which trains until a specified number of global steps is reached;
- `evaluate`, for one-off model evaluation;
- `train_and_evaluate`, for interleaved training and evaluation;
- `evaluate_continuously`, for monitoring a given directory and running
evaluations on new model checkpoints.
While this class attempts to provide out-of-the-box solutions for common
training and evaluation use cases, the internal details and method
implementations are also intended to be simple enough to make subclassing or
other custom outer loop implementations easy to achieve.
Some additional customization can be achieved by supplying `train_actions` or
`eval_actions` when constructing the `Controller`. These are just lists of
arbitrary callables that are applied by the `Controller` to the output of
train steps (after each inner loop of `steps_per_loop` steps) or an
evaluation. This provides a hook mechanism, enabling things like reporting
metrics to Vizier, model exporting, additional logging, etc. See the
`orbit.actions` package for a small handful of predefined actions and some
utility classes that may be useful in defining your own.
"""
def __init__(
self,
*, # Makes all args keyword only.
global_step: tf.Variable,
trainer: Optional[runner.AbstractTrainer] = None,
evaluator: Optional[runner.AbstractEvaluator] = None,
strategy: Optional[tf.distribute.Strategy] = None,
# Actions
train_actions: Optional[List[Action]] = None,
eval_actions: Optional[List[Action]] = None,
# Train related
steps_per_loop: Optional[int] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
# Summary related
summary_interval: Optional[int] = None,
summary_dir: Optional[str] = None,
# Evaluation related
eval_summary_dir: Optional[str] = None,
):
"""Initializes a `Controller` instance.
Note that if `checkpoint_manager` is provided and there are checkpoints in
the associated model directory, the model will be restored from the most
recent checkpoint during this `__init__` method.
Args:
global_step: An integer `tf.Variable` storing the global training step
number. Usually this can be obtained from the `iterations` property of
the model's optimizer (e.g. `trainer.optimizer.iterations`). In cases
where multiple optimizers are used, or if one model "step" corresponds
to more than one update to model parameters, users can create and
increment their own global step variable as well. In this case it is
recommended to create the `tf.Variable` inside the distribution strategy
scope, with `aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA` (see
also `orbit.utils.create_global_step()`).
trainer: An instance of `orbit.AbstractTrainer`, which implements the
inner training loop.
evaluator: An instance of `orbit.AbstractEvaluator`, which implements
evaluation.
strategy: An instance of `tf.distribute.Strategy`. If not provided, the
strategy will be initialized from the current in-scope strategy using
`tf.distribute.get_strategy()`.
train_actions: An optional list of `orbit.Action`s to call after each
block of `steps_per_loop` training steps are run. These will be called
with the output of `trainer.train`.
eval_actions: An optional list of `orbit.Action`s to call after each
evaluation. These will be called with the output of
`evaluator.evaluate`.
steps_per_loop: The number of steps to run in each inner loop of training
(passed as the `num_steps` parameter of `trainer.train`).
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
provided and there are checkpoints in the associated model directory,
the model will be restored from the most recent checkpoint inside this
`__init__` method. If not provided, the `Controller` will not
automatically save to or restore from checkpoints.
summary_interval: Step interval for training summaries. Note that this
argument only applies to `tf.summary` calls inside the `trainer.train`
function. Summaries written by the `Controller` (specifically
"steps_per_second" and output from the `trainer.train` method) will
always be enabled unless the `summary_dir` parameter is `None`. If set,
the value must be divisible by `steps_per_loop`.
summary_dir: The directory to write summaries to. To use the same
directory as for checkpointing, pass `checkpoint_manager.directory`. If
`None`, no training summaries will be written.
eval_summary_dir: The directory to write eval summaries to. If `None`, it
will be set to `summary_dir`. If both `summary_dir` and
`eval_summary_dir` are `None`, no eval summaries will be written.
Raises:
ValueError: If both `trainer` and `evaluator` are `None`.
ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `summary_interval` is not a positive integer or is not
divisible by `steps_per_loop`.
"""
if trainer is None and evaluator is None:
raise ValueError("`trainer` and `evaluator` should not both be `None`.")
if trainer is not None:
if steps_per_loop is None:
raise ValueError(
"`steps_per_loop` is required when `trainer` is provided.")
elif not isinstance(steps_per_loop, int) or steps_per_loop < 1:
raise ValueError(
f"`steps_per_loop` ({steps_per_loop}) must be a positive integer.")
if summary_interval is not None:
if summary_interval <= 0:
raise ValueError(
f"`summary_interval` ({summary_interval}) must be larger than 0.")
elif summary_interval % steps_per_loop != 0:
raise ValueError(
f"`summary interval` ({summary_interval}) must be a multiple "
f"of `steps_per_loop` ({steps_per_loop}).")
if not isinstance(global_step, tf.Variable):
raise ValueError("`global_step` must be a `tf.Variable`.")
self.trainer = trainer
self.evaluator = evaluator
self.strategy = strategy or tf.distribute.get_strategy()
self.train_actions = train_actions or []
self.eval_actions = eval_actions or []
self.global_step = global_step
self.checkpoint_manager = checkpoint_manager
if self.trainer is not None:
self.step_timer = None
self.steps_per_loop = steps_per_loop
self.summary_interval = summary_interval
self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step)
if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir
if eval_summary_dir == summary_dir and self.trainer is not None:
# Reuse the summary writer if train and evaluation summary directory
# are the same.
self.eval_summary_manager = self.summary_manager
else:
self.eval_summary_manager = utils.SummaryManager(
eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
tf.summary.experimental.set_step(self.global_step)
# Restores the model if needed.
if self.checkpoint_manager is not None:
restored_path = self.restore_checkpoint()
if restored_path:
_log(f"restored from checkpoint: {restored_path}")
def train(self, steps: int, checkpoint_at_completion: bool = True):
"""Runs training until the specified global step count has been reached.
This method makes calls to `self.trainer.train()` until the global step
count is equal to `steps`. It will additionally save checkpoints (if a
`CheckpointManager` was passed to `Controller.__init__`) and summarize
training output (if `summary_dir` is set).
Args:
steps: The global step count to train up to.
checkpoint_at_completion: Whether to save a checkpoint when this method
returns (regardless of the checkpointing interval). Defaults to `True`.
"""
self._require("trainer", for_method="train")
# TODO(momernick): Support steps=None or -1 (training to exhaustion).
current_step = self.global_step.numpy() # Cache, since this is expensive.
_log(f"train | step: {current_step: 6d} | training until step {steps}...")
while current_step < steps:
# Calculates steps to run for the next train loop.
num_steps = min(steps - current_step, self.steps_per_loop)
self._train_n_steps(num_steps)
self._maybe_save_checkpoint()
current_step = self.global_step.numpy()
if checkpoint_at_completion:
self._maybe_save_checkpoint(check_interval=False)
def evaluate(self, steps: int = -1) -> Optional[runner.Output]:
"""Runs evaluation for the given number of steps.
This method calls `self.evaluator.evaluate(steps)`, then writes the returned
summaries (if any).
Args:
steps: The number of evaluation steps to run. The value `-1` is reserved
as a special sentinel to indicate a "complete" evaluation that runs
until the underlying dataset is exhausted. Support for this is dependent
on the specific `evaluator` being used.
Returns:
The evaluation results as a dictionary mapping names to NumPy values.
Raises:
ValueError: If `evaluator` was not provided to `Controller.__init__`.
ValueError: If `steps` is not a positive value or -1.
"""
self._require("evaluator", for_method="evaluate")
if steps > 0:
steps_msg = f"running {steps} steps of evaluation..."
elif steps == -1:
steps_msg = "running complete evaluation..."
else:
raise ValueError(f"`steps` ({steps}) should be > 0, or == -1.")
current_step = self.global_step.numpy()
_log(f" eval | step: {current_step: 6d} | {steps_msg}")
start = time.time()
with self.eval_summary_manager.summary_writer().as_default():
steps_tensor = tf.convert_to_tensor(steps, dtype=tf.int32)
eval_output = self.evaluator.evaluate(steps_tensor)
elapsed = time.time() - start
eval_output = eval_output or {}
for action in self.eval_actions:
action(eval_output)
eval_output = tf.nest.map_structure(utils.get_value, eval_output)
_log(f" eval | step: {current_step: 6d} | "
f"eval time: {elapsed: 6.1f} sec | "
f"output: {_format_output(eval_output)}")
self.eval_summary_manager.write_summaries(eval_output)
self.eval_summary_manager.flush()
return eval_output
def train_and_evaluate(self,
train_steps: int,
eval_steps: int = -1,
eval_interval: Optional[int] = None) -> None:
"""Runs interleaved training and evaluation.
This method interleaves calls to `self.train()` and `self.evaluate()`,
training the model until the global step count equals `train_steps`, and
running an evaluation for `eval_steps` every `eval_interval` training steps.
In addition, this method will run a final evaluation at the end of the
training sequence.
Args:
train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If -1, this
method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evaluations. If
set, training will always stop every `eval_interval` steps, even if this
results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is
complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
self._require("trainer", for_method="train_and_evaluate")
self._require("evaluator", for_method="train_and_evaluate")
current_step = self.global_step.numpy() # Cache, since this is expensive.
eval_interval = eval_interval or (train_steps - current_step)
while current_step < train_steps:
interval = min(train_steps - current_step, eval_interval)
num_steps = current_step + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self.evaluate(steps=eval_steps)
current_step = self.global_step.numpy()
self._maybe_save_checkpoint(check_interval=False)
def evaluate_continuously(self,
steps: int = -1,
timeout: Optional[Union[int, float]] = None,
timeout_fn: Optional[Callable[[], bool]] = None):
"""Continuously monitors a directory and evaluates new checkpoints in it.
This method continuously monitors a directory as specified by this
Controller's CheckpointManager init arg and runs evaluation on the
checkpoints found there.
Args:
steps: The number of steps to run when evaluating. If -1, this method will
evaluate over the entire evaluation dataset.
timeout: The maximum number of seconds to wait between checkpoints. See
tf.train.checkpoints_iterator documentation.
timeout_fn: Optional callable to call after a timeout. If the function
returns True, then it means that no new checkpoints will be generated
and the iterator will exit.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` was not provided as a controller init arg.
"""
self._require("evaluator", for_method="evaluate_continuously")
self._require("checkpoint_manager", for_method="evaluate_continuously")
for checkpoint_path in tf.train.checkpoints_iterator(
self.checkpoint_manager.directory,
timeout=timeout,
timeout_fn=timeout_fn):
self.restore_checkpoint(checkpoint_path)
self.evaluate(steps)
def restore_checkpoint(self, checkpoint_path: Optional[str] = None):
"""Restores the model from a checkpoint.
Args:
checkpoint_path: An optional string specifying the checkpoint path to
restore from. If `None`, will restore from the most recent checkpoint
(or initialize the model using a custom `init_fn` if no checkpoints can
be found) using `self.checkpoint_manager.restore_or_initialize()`.
Returns:
The path to the restored checkpoint if a restore happened, or `None` if no
restore occurred.
"""
self._require("checkpoint_manager", for_method="restore_checkpoint")
with self.strategy.scope():
# Checkpoint restoring should be inside scope (b/139450638).
if checkpoint_path is not None:
_log(f"restoring model from {checkpoint_path}...")
self.checkpoint_manager.checkpoint.restore(checkpoint_path)
else:
_log("restoring or initializing model...")
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
if checkpoint_path is not None:
_log(f"restored model from {checkpoint_path}.")
else:
_log("initialized model.")
return checkpoint_path
def save_checkpoint(self):
"""Saves the model to a checkpoint.
This method will save a checkpoint containing the current state of the
model.
Raises:
ValueError: If no `checkpoint_manager` was provided to
`Controller.__init__`.
"""
self._require("checkpoint_manager", for_method="save_checkpoint")
self._maybe_save_checkpoint(check_interval=False)
def _train_n_steps(self, num_steps: int):
"""Runs training for `num_steps` steps.
Also prints/logs updates about training progress, and summarizes training
output (if output is returned from `self.trainer.train()`, and if
`self.summary_dir` is set).
Args:
num_steps: An integer specifying how many steps of training to run.
Note that if `global_step` is not incremented by exactly `num_steps` after
calling `self.trainer.train(num_steps)`, a warning is logged instead of an
exception being raised.
"""
if not self.step_timer:
self.step_timer = StepTimer(self.global_step)
current_step = self.global_step.numpy()
with self.summary_manager.summary_writer().as_default():
should_record = False # Allows static optimization in no-summary cases.
if self.summary_interval:
# Create a predicate to determine when summaries should be written.
should_record = lambda: (self.global_step % self.summary_interval == 0)
with tf.summary.record_if(should_record):
num_steps_tensor = tf.convert_to_tensor(num_steps, dtype=tf.int32)
train_output = self.trainer.train(num_steps_tensor)
# Verify that global_step was updated properly, then update current_step.
expected_step = current_step + num_steps
if self.global_step.numpy() != expected_step:
message = (
f"`trainer.train({num_steps})` did not update `global_step` by "
f"{num_steps}. Old value was {current_step}, expected updated value "
f"to be {expected_step}, but it was {self.global_step.numpy()}.")
logging.warning(message)
train_output = train_output or {}
for action in self.train_actions:
action(train_output)
train_output = tf.nest.map_structure(utils.get_value, train_output)
current_step = self.global_step.numpy()
steps_per_second = self.step_timer.steps_per_second()
_log(f"train | step: {current_step: 6d} | "
f"steps/sec: {steps_per_second: 6.1f} | "
f"output: {_format_output(train_output)}")
train_output["steps_per_second"] = steps_per_second
self.summary_manager.write_summaries(train_output)
self.summary_manager.flush()
def _maybe_save_checkpoint(self, check_interval: bool = True):
"""Conditionally saves a checkpoint.
A checkpoint is saved if a `CheckpointManager` is available, and if the
required number of steps has elapsed since the last checkpoint was saved
(although this condition can be disabled by setting `check_interval=False`).
Args:
check_interval: Whether to check if the checkpoint interval has fully
elapsed. If `False`, a checkpoint is saved regardless of the elapsed
steps since the most recent checkpoint, unless no `checkpoint_manager`
was provided to `Controller.__init__`.
Returns:
A boolean indicating whether a checkpoint was saved.
"""
if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=self.global_step.numpy(),
check_interval=check_interval)
if ckpt_path is not None:
_log(f"saved checkpoint to {ckpt_path}.")
return True
return False
def _require(self, attribute, for_method):
"""Utility method to raise an error if the given `attribute` is not set."""
if getattr(self, attribute, None) is None:
raise ValueError(
f"`{attribute}` is not set. Pass `{attribute}` to "
f"`Controller.__init__` before calling `{for_method}()`.")
class StepTimer:
"""Utility class for measuring steps/second."""
def __init__(self, step):
self.step = step
self.start()
def start(self):
self.last_iteration = self.step.numpy()
self.last_time = time.time()
def steps_per_second(self, restart=True):
value = ((self.step.numpy() - self.last_iteration) /
(time.time() - self.last_time))
if restart:
self.start()
return value
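# Usage sketch (illustrative only; `MyTrainer`, `MyEvaluator`, and the model
# directory are assumed placeholders, not part of this module):
#
#   trainer = MyTrainer(...)      # an orbit.AbstractTrainer subclass
#   evaluator = MyEvaluator(...)  # an orbit.AbstractEvaluator subclass
#   checkpoint_manager = tf.train.CheckpointManager(
#       tf.train.Checkpoint(model=trainer.model),
#       directory='/model/dir', max_to_keep=5,
#       step_counter=trainer.global_step, checkpoint_interval=1000)
#   controller = Controller(
#       trainer=trainer, evaluator=evaluator,
#       global_step=trainer.global_step, steps_per_loop=100,
#       checkpoint_manager=checkpoint_manager,
#       summary_dir='/model/dir/train', eval_summary_dir='/model/dir/eval')
#   controller.train_and_evaluate(
#       train_steps=10000, eval_steps=-1, eval_interval=1000)
#
#   # A separate evaluation job can instead watch the checkpoint directory and
#   # evaluate each new checkpoint as it appears:
#   controller.evaluate_continuously(steps=-1, timeout=60)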
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.controller."""
import os
from absl import logging
from absl.testing import parameterized
import numpy as np
from orbit import controller
from orbit import runner
from orbit import standard_runner
import tensorflow as tf
def create_model():
x = tf.keras.layers.Input(shape=(3,), name="input")
y = tf.keras.layers.Dense(4, name="dense")(x)
model = tf.keras.Model(x, y)
return model
def summaries_with_matching_keyword(keyword, summary_dir):
"""Returns summary protos matching given keyword from event file."""
matches = []
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
matches.append(event.summary)
return matches
def dataset_fn(ctx):
del ctx
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.ones((10, 4), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10, drop_remainder=True)
return dataset
class TestRunner(standard_runner.StandardTrainer,
standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self, return_numpy=False):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
self.return_numpy = return_numpy
train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
standard_runner.StandardTrainer.__init__(self, train_dataset)
standard_runner.StandardEvaluator.__init__(self, eval_dataset)
def train_step(self, iterator):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def train_loop_end(self):
train_loss = self.train_loss.result()
return {
"loss": train_loss.numpy() if self.return_numpy else train_loss,
}
def build_eval_dataset(self):
return self.strategy.distribute_datasets_from_function(dataset_fn)
def eval_begin(self):
self.eval_loss.reset_states()
def eval_step(self, iterator):
def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
self.eval_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def eval_end(self):
eval_loss = self.eval_loss.result()
return {
"eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
}
class TestEvaluator(standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
standard_runner.StandardEvaluator.__init__(self, eval_dataset)
def eval_reduce(self, state, output):
state.append(output)
return state
def eval_begin(self):
return []
def eval_step(self, iterator):
def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
return loss
per_replica_losses = self.strategy.run(
_replicated_step, args=(next(iterator),))
mean_loss = self.strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return mean_loss
def eval_end(self, outputs):
return {
"eval_loss": tf.reduce_mean(outputs),
}
class TestEvaluatorNoOutput(runner.AbstractEvaluator):
def evaluate(self, num_steps):
pass
class TestEvaluatorWithNestedSummary(standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
dataset2 = self.strategy.distribute_datasets_from_function(dataset_fn)
self.loss = tf.keras.metrics.Mean("loss", dtype=tf.float32)
self.accuracy = tf.keras.metrics.CategoricalAccuracy(
"accuracy", dtype=tf.float32)
self.loss2 = tf.keras.metrics.Mean("loss", dtype=tf.float32)
self.accuracy2 = tf.keras.metrics.CategoricalAccuracy(
"accuracy", dtype=tf.float32)
standard_runner.StandardEvaluator.__init__(
self, eval_dataset={
"dataset": dataset,
"dataset2": dataset2
})
def eval_step(self, iterator):
def _replicated_step(loss, accuracy, inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss.update_state(tf.keras.losses.MSE(targets, outputs))
accuracy.update_state(targets, outputs)
self.strategy.run(
lambda inputs: _replicated_step(self.loss, self.accuracy, inputs),
args=(next(iterator["dataset"]),))
self.strategy.run(
lambda inputs: _replicated_step(self.loss2, self.accuracy2, inputs),
args=(next(iterator["dataset2"]),))
def eval_end(self):
return {
"dataset": {
"loss": self.loss.result(),
"accuracy": self.accuracy.result()
},
"dataset2": {
"loss": self.loss2.result(),
"accuracy": self.accuracy2.result()
},
}
class TestTrainerWithSummaries(standard_runner.StandardTrainer):
"""A Trainer model with summaries for testing purposes."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
standard_runner.StandardTrainer.__init__(
self,
train_dataset,
options=standard_runner.StandardTrainerOptions(
use_tpu_summary_optimization=True))
def build_train_dataset(self):
return self.strategy.distribute_datasets_from_function(dataset_fn)
def train_step(self, iterator):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
tf.summary.scalar("loss", loss)
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self):
test_runner = TestRunner()
# No checkpoint manager and no strategy.
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# No checkpoint, so global step starts from 0.
test_runner.global_step.assign(0)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
def test_no_checkpoint_and_summaries(self):
test_runner = TestRunner()
# No checkpoint + summary directories.
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
def test_has_checkpoint_no_summaries(self):
test_runner = TestRunner()
# Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# No summaries are saved.
self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*")))
def test_has_checkpoint_eval_summary_only(self):
test_runner = TestRunner()
# Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# Training summaries are not saved.
self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*")))
# Evaluation summaries are saved.
self.assertNotEmpty(tf.io.gfile.glob(
os.path.join(self.model_dir, "summaries/eval/events.*")))
def test_restore_from_most_recent_checkpoint(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=5)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=5)
test_controller.train(20)
self.assertLen(checkpoint_manager.checkpoints, 4)
restored_path = test_controller.restore_checkpoint()
self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])
@parameterized.named_parameters(("return_numpy", True),
("return_tensor", False))
def test_train_and_evaluate(self, return_numpy):
test_runner = TestRunner(return_numpy=return_numpy)
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_only(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
)
test_controller.train(steps=10)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
def test_evaluate_only(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
eval_results = test_controller.evaluate(steps=2)
# Only eval summaries are written
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
self.assertIn("eval_loss", eval_results)
# Tests continuous eval with timeout and timeout_fn.
done_file = os.path.join(self.model_dir, "summaries/eval/Done")
def timeout_fn():
with tf.io.gfile.GFile(done_file, "w") as f:
f.write("DONE")
return True
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate_continuously(
timeout=1, timeout_fn=timeout_fn, steps=2)
self.assertNotEmpty(tf.io.gfile.glob(done_file))
def test_no_eval_steps(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager)
test_controller.evaluate()
def test_already_trained_model(self):
test_runner = TestRunner()
test_runner.global_step.assign(10)
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
# `global_step` is already `train_steps`.
test_controller.train(steps=10)
def test_summaries_inside_train_fn(self):
test_runner = TestTrainerWithSummaries()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
)
test_controller.train(steps=10)
    # No checkpoints are saved, since no `checkpoint_interval` was set.
self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_with_same_summary_dir(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries")))
def test_early_stop_on_eval_loss(self):
test_runner = TestRunner()
class EarlyStopController(controller.Controller):
"""A subclass of Controller that supports early stopping."""
def train_and_evaluate(self,
train_steps: int = None,
eval_steps: int = None,
eval_interval: int = None):
while self.global_step.numpy() < train_steps:
interval = min(train_steps - self.global_step.numpy(), eval_interval)
num_steps = self.global_step.numpy() + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self.evaluate(steps=eval_steps)
# Early stop condition.
if test_runner.eval_loss.result() < 0.1:
logging.info(
"Training early stopped as eval_loss %s is less than 0.1",
test_runner.eval_loss.result())
return
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = EarlyStopController(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=6, eval_interval=2)
self.assertLess(test_runner.global_step, 10)
def test_evaluate_with_loss_output(self):
test_evaluator = TestEvaluator()
checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, self.model_dir, max_to_keep=None)
test_controller = controller.Controller(
evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=5)
    # Only eval summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_evaluate_with_no_output(self):
test_controller = controller.Controller(
evaluator=TestEvaluatorNoOutput(),
global_step=tf.Variable(0, dtype=tf.int64),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
self.assertEqual(test_controller.evaluate(steps=5), {})
def test_train_and_evaluate_reset_datasets(self):
test_runner = TestRunner()
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
train_dataset = (
test_runner.strategy.distribute_datasets_from_function(dataset_fn))
eval_dataset = (
test_runner.strategy.distribute_datasets_from_function(dataset_fn))
test_runner.train_dataset = train_dataset
test_runner.eval_dataset = eval_dataset
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
def test_eval_and_checkpoint_interval(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=5)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=10,
checkpoint_manager=checkpoint_manager,
summary_dir=self.model_dir)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5)
    # Expect 2 checkpoints to be saved, at steps 5 and 10.
self.assertLen(
tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 2)
    # Expect evaluation to be performed 2 times, at steps 5 and 10.
self.assertLen(
summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
def test_evaluate_with_nested_summaries(self):
test_evaluator = TestEvaluatorWithNestedSummary()
test_controller = controller.Controller(
evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64),
eval_summary_dir=self.model_dir)
test_controller.evaluate(steps=5)
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "dataset")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"accuracy", os.path.join(self.model_dir, "dataset")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset2")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "dataset2")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"accuracy", os.path.join(self.model_dir, "dataset2")))
def test_actions(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
class OutputRecorderAction:
"""Simple `Action` that just saves the outputs passed to `__call__`."""
def __init__(self):
self.outputs = []
def __call__(self, output):
self.outputs.append(output)
train_output_recorder = OutputRecorderAction()
eval_output_recorder = OutputRecorderAction()
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
train_actions=[train_output_recorder],
eval_actions=[eval_output_recorder],
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertLen(train_output_recorder.outputs, 5)
for output in train_output_recorder.outputs:
self.assertIn("loss", output)
self.assertGreaterEqual(output["loss"], 0)
self.assertLen(eval_output_recorder.outputs, 2)
for output in eval_output_recorder.outputs:
self.assertIn("eval_loss", output)
self.assertGreaterEqual(output["eval_loss"], 0)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides AbstractTrainer/Evaluator base classes, defining train/eval APIs."""
import abc
from typing import Dict, Optional, Union
import numpy as np
import tensorflow as tf
Output = Dict[str, Union[tf.Tensor, float, np.number, np.ndarray, 'Output']] # pytype: disable=not-supported-yet
class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the API required for training."""
@abc.abstractmethod
def train(self, num_steps: tf.Tensor) -> Optional[Output]:
"""Implements `num_steps` steps of training.
This method will be called by the `Controller` to perform the "inner loop"
of training. This inner loop amortizes the cost of bookkeeping associated
with checkpointing, evaluation, and writing summaries. Additionally, the
inner loop can be implemented (if desired) using TensorFlow's looping
constructs (e.g. a `for` loop over a `tf.range` inside a `tf.function`),
which can be necessary for getting optimal performance when running on TPU.
For cases that don't require peak performance, a simple Python loop can be
used instead for simplicity.
Args:
num_steps: The number of training steps to run. Note that it is up to the
model what constitutes a "step", which may involve more than one update
to model parameters (e.g., if training a GAN).
Returns:
Either `None`, or a dictionary mapping names to `Tensor`s or NumPy values.
If a dictionary is returned, it will be written to logs and as TensorBoard
summaries. The dictionary may also be nested, which will generate a
hierarchy of summary directories.
"""
pass
class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the API required for evaluation."""
@abc.abstractmethod
def evaluate(self, num_steps: tf.Tensor) -> Optional[Output]:
"""Implements `num_steps` steps of evaluation.
    This method will be called by the `Controller` to perform an evaluation. The
`num_steps` parameter specifies the number of steps of evaluation to run,
which is specified by the user when calling one of the `Controller`'s
evaluation methods. A special sentinel value of `-1` is reserved to indicate
evaluation should run until the underlying data source is exhausted.
Args:
num_steps: The number of evaluation steps to run. Note that it is up to
the model what constitutes a "step". Evaluations may also want to
support "complete" evaluations when `num_steps == -1`, running until a
given data source is exhausted.
Returns:
Either `None`, or a dictionary mapping names to `Tensor`s or NumPy values.
If a dictionary is returned, it will be written to logs and as TensorBoard
summaries. The dictionary may also be nested, which will generate a
hierarchy of summary directories.
"""
pass
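# Illustrative sketch (not part of this module): as the docstrings above note,
# the "inner loop" of an `AbstractTrainer` can be wrapped in a `tf.function`
# with a loop over `tf.range` for better device performance. The `self.model`,
# `self.optimizer`, and `self.iterator` attributes below are hypothetical.
#
#   class MyTrainer(AbstractTrainer):
#
#     @tf.function
#     def _inner_loop(self, num_steps):
#       for _ in tf.range(num_steps):
#         features, labels = next(self.iterator)
#         with tf.GradientTape() as tape:
#           loss = tf.reduce_mean(
#               tf.keras.losses.mse(labels, self.model(features, training=True)))
#         grads = tape.gradient(loss, self.model.trainable_variables)
#         self.optimizer.apply_gradients(
#             zip(grads, self.model.trainable_variables))
#
#     def train(self, num_steps):
#       self._inner_loop(num_steps)  # `num_steps` arrives as a `tf.Tensor`.
#       return None  # Or a dict of metric results gathered eagerly here.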
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AbstractTrainer/Evaluator subclasses with added functionality.
The classes in this module provide some additional structure to the bare
`AbstractTrainer`/`AbstractEvaluator` APIs.
Both `StandardTrainer` and `StandardEvaluator` split the train/eval loops into
"begin", "step", and "end" methods, and provide an implementation of the loop
itself that makes calls to the relevant step method.
`StandardTrainer` supports running the loop using the TF while loop construct
for added performance (particularly on TPUs). It additionally provides some
functionality to make writing summaries from inside a model more performant when
running on TPUs.
These classes are intended to work well in common settings, however there may
be use cases these classes don't support (for instance, `StandardEvaluator` in
particular doesn't support running full evaluations over multiple different eval
datasets). Users are encouraged to simply fall back to custom `AbstractTrainer`
and `AbstractEvaluator` subclasses in these cases.
"""
import abc
from typing import Any, Optional
import dataclasses
from orbit import runner
from orbit.utils import loop_fns
import tensorflow as tf
@dataclasses.dataclass(frozen=True)
class StandardTrainerOptions:
"""Advanced options for `orbit.StandardTrainer`.
Attributes:
use_tf_function: A boolean indicating whether to apply `tf.function` to the
training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tpu_summary_optimization: A boolean indicating whether to enable a
performance optimization for summaries in TPUs. Writing summaries
conditionally with outside compilation on TPUs can be extremely slow. If
`True`, this optimization creates two `tf.function`s with two XLA programs
(one with summary calls, and one without). The program with summaries runs
only for one step when summaries should be recorded.
"""
use_tf_function: bool = True
use_tf_while_loop: bool = True
use_tpu_summary_optimization: bool = False
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements standard functionality on top of the AbstractTrainer API.
This class structures the training "inner loop" roughly as follows:
train_loop_begin()
for _ in range(num_steps):
train_step(train_iterator)
return train_loop_end()
Calls to `train_loop_begin` and `train_loop_end` are always done in eager
mode, while the loop/`train_step` may be implemented using `tf.while` and/or
`tf.function`, as determined by the `options` passed to `__init__`.
"""
def __init__(self,
train_dataset,
options: Optional[StandardTrainerOptions] = None):
"""Initializes the `StandardTrainer` instance.
Args:
train_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or
`DistributedDataset`.
options: An `orbit.StandardTrainerOptions` instance.
"""
options = options or StandardTrainerOptions()
if options.use_tf_while_loop and not options.use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported")
if options.use_tpu_summary_optimization and not options.use_tf_while_loop:
raise ValueError("`use_tpu_summary_optimization=True` and "
"`use_tf_while_loop=False` is not supported")
self._train_options = options
self._train_dataset = train_dataset
self._train_iter = None
self._train_loop_fn = None
def create_train_loop_fn(self):
"""Creates a training loop from the current step function and options.
Returns:
The train loop function, i.e. wrapper of multiple train steps.
"""
train_step_fn = self.train_step
if self._train_options.use_tf_while_loop:
loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
if self._train_options.use_tpu_summary_optimization:
loop_fn = loop_fns.LoopFnWithSummaries(loop_fn)
else:
loop_fn = tf.function(loop_fn)
else:
if self._train_options.use_tf_function:
train_step_fn = tf.function(train_step_fn)
loop_fn = loop_fns.create_loop_fn(train_step_fn)
return loop_fn
def train(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
"""Implements `num_steps` steps of training.
Args:
num_steps: The number of training steps to run. This corresponds directly
to the number of calls made to `train_step`.
Returns:
The output of `train_loop_end`.
"""
self.train_loop_begin()
if self._train_loop_fn is None:
self._train_loop_fn = self.create_train_loop_fn()
if self._train_iter is None:
self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
self._train_loop_fn(self._train_iter, num_steps)
return self.train_loop_end()
def train_loop_begin(self):
"""Called once at the beginning of the training loop.
This method is always called in eager mode, and is a good place to reset
metrics that accumulate values over multiple steps of training.
Note that this method is called before dataset iterator creation.
"""
pass
@abc.abstractmethod
def train_step(self, iterator):
"""Implements one step of training.
What a "step" consists of is up to the implementer. When using distribution
strategies, the call to this method takes place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.run`.
Note that if `use_tf_function=True`, all the code inside `train_step` should
be compatible with `tf.function` tracing (and in particular, any state
modifications involving `self` should be avoided). In some cases, non-
`tf.function` compatible code can be moved to `train_loop_begin` or
`train_loop_end`, which always execute eagerly.
Args:
iterator: A `tf.nest`-compatible structure of `tf.data.Iterator` or
`DistributedIterator`. The structure of this input matches the structure
of `train_dataset` as passed to `__init__`.
"""
pass
def train_loop_end(self) -> Optional[runner.Output]:
"""Called once at the end of the training loop.
This method is always called in eager mode, and is a good place to get
metric results. The value returned from this function will be returned as-is
from the `train` method implementation provided by `StandardTrainer`.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries. It can also be a
nested dictionary, yielding a hierarchy of summary directories.
"""
pass
@property
def train_dataset(self):
"""The current training dataset."""
return self._train_dataset
@train_dataset.setter
def train_dataset(self, train_dataset):
"""Sets a new training dataset, replacing the current one.
Any unprocessed examples in the current dataset are discarded.
Args:
train_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or
`DistributedDataset`.
"""
self._train_dataset = train_dataset
self._train_iter = None
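# Illustrative sketch (not part of this module): a `StandardTrainer` subclass
# only needs to implement `train_step`, plus optional `train_loop_begin` and
# `train_loop_end` hooks; the loop itself is provided by this class. The
# `model`, `optimizer`, and `dataset` arguments below are hypothetical
# user-provided objects.
#
#   class MyTrainer(StandardTrainer):
#
#     def __init__(self, model, optimizer, dataset):
#       self.model = model
#       self.optimizer = optimizer
#       self.train_loss = tf.keras.metrics.Mean("train_loss")
#       super().__init__(train_dataset=dataset)
#
#     def train_loop_begin(self):
#       self.train_loss.reset_states()  # Always runs eagerly.
#
#     def train_step(self, iterator):
#       features, labels = next(iterator)
#       with tf.GradientTape() as tape:
#         loss = tf.reduce_mean(
#             tf.keras.losses.mse(labels, self.model(features, training=True)))
#       grads = tape.gradient(loss, self.model.trainable_variables)
#       self.optimizer.apply_gradients(
#           zip(grads, self.model.trainable_variables))
#       self.train_loss.update_state(loss)
#
#     def train_loop_end(self):
#       return {"train_loss": self.train_loss.result()}  # Logged by Controller.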
@dataclasses.dataclass(frozen=True)
class StandardEvaluatorOptions:
"""Advanced options for the `orbit.StandardEvaluator`.
Attributes:
use_tf_function: A boolean indicating whether to apply `tf.function` to the
evaluation loop. This will only affect the body of the loop (involving
`eval_step`); `eval_loop_begin` and `eval_loop_end` will always be run
in eager mode.
use_tf_while_loop: A boolean indicating whether to run the evaluation loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
recreate_iterator_for_each_eval: A boolean indicating whether to recreate a
new iterator for the evaluation dataset before each round of evaluation,
which implies each round of evaluation starts from the beginning of
the evaluation dataset. For example, the evaluation dataset is
`[1, 2, 3, 4]`, batch size is 1 and evaluation steps is 2. If `True`, the
data to be evaluated is [1, 2] every time. If `False`, the iterator
state is maintained between calls to `StandardEvaluator.evaluate()`.
"""
use_tf_function: bool = True
use_tf_while_loop: bool = False
recreate_iterator_for_each_eval: bool = True
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs.
This class structures evaluation roughly as follows:
state = eval_begin()
for _ in range(num_steps):
step_outputs = eval_step(eval_iterator)
state = eval_reduce(state, step_outputs)
return eval_end(state)
Calls to `eval_begin` and `eval_end` are always done in eager
mode, while `eval_step` may be compiled with `tf.function` as determined by
the `options` passed to `__init__`. `eval_reduce` is in eager mode if
`use_tf_while_loop=False` in `StandardEvaluatorOptions`, but in graph mode if
`use_tf_while_loop=True`.
This class does not support completely evaluating multiple different datasets
(i.e., where every example of each dataset should be processed, as opposed to
running for a fixed number of evaluation steps). A custom `AbstractEvaluator`
is recommended in this case.
"""
def __init__(self,
eval_dataset,
options: Optional[StandardEvaluatorOptions] = None):
"""Initializes the `StandardEvaluator` instance.
Args:
eval_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or
        `DistributedDataset`. On TPUs, if users want to exhaust the dataset
        without specifying the number of eval steps, it is recommended to set
`drop_remainder=False` when batching the dataset, so the infrastructure
can handle the last partial batch properly.
options: An `orbit.StandardEvaluatorOptions` instance.
"""
options = options or StandardEvaluatorOptions()
if options.use_tf_while_loop and not options.use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported")
self._eval_options = options
self._eval_dataset = eval_dataset
self._eval_iter = None
self._eval_loop_fn = None
def create_eval_loop_fn(self, has_state: bool):
"""Creates an eval loop from the current step function and options.
Args:
      has_state: Whether the step function has state; if so, the state is
        kept in the loop.
Returns:
The eval loop function, i.e. wrapper of multiple eval steps.
"""
eval_step_fn = self.eval_step
if self._eval_options.use_tf_while_loop:
# TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
      # even when it is not used inside the loop. To work around this limitation,
# we have to build two tf.functions for it.
if has_state:
loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
else:
loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
loop_fn = tf.function(loop_fn)
else:
if self._eval_options.use_tf_function:
eval_step_fn = tf.function(eval_step_fn)
loop_fn = loop_fns.create_loop_fn(eval_step_fn)
return loop_fn
def evaluate(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
"""Implements `num_steps` steps of evaluation.
Args:
num_steps: The number of evaluation steps to run. When this is -1,
evaluation proceeds until a call to `eval_step` raises a `StopIteration`
or `tf.errors.OutOfRangeError`.
Returns:
The output of `self.eval_end()`.
Raises:
ValueError: If `options.use_tf_while_loop` is `True` and `num_steps` is
unspecified.
"""
if self._eval_options.use_tf_while_loop and num_steps == -1:
raise ValueError("Looping until exhausted is not supported if "
"`options.use_tf_while_loop` is `True`")
outputs = self.eval_begin() # pylint: disable=assignment-from-no-return
has_state = outputs is not None
if self._eval_loop_fn is None:
self._eval_loop_fn = self.create_eval_loop_fn(has_state)
# If `recreate_iterator_for_each_eval` is `True`, `self._eval_iter` is
# always None.
if self._eval_iter is None:
eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
if not self._eval_options.recreate_iterator_for_each_eval:
self._eval_iter = eval_iter
else:
eval_iter = self._eval_iter
if self._eval_options.use_tf_while_loop and not has_state:
self._eval_loop_fn(eval_iter, num_steps)
else:
outputs = self._eval_loop_fn(
eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)
if outputs is None:
return self.eval_end()
else:
return self.eval_end(outputs)
def eval_begin(self) -> Any:
"""Called once at the beginning of the evaluation.
This method is always called in eager mode, and is a good place to reset
metrics that accumulate values over the course of evaluation.
Note that this method is called before dataset iterator creation.
Returns:
      A value to pass as the `state` argument to `eval_reduce`.
"""
pass
@abc.abstractmethod
def eval_step(self, iterator) -> Any:
"""Implements one step of evaluation.
What a "step" consists of is up to the implementer. When using distribution
strategies, the call to this method takes place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.run`.
Note that if `use_tf_function=True`, all the code inside `eval_step` should
be compatible with `tf.function` tracing (and in particular, any state
modifications involving `self` should be avoided). In some cases, non-
    `tf.function` compatible code can be moved to `eval_begin`, `eval_reduce`,
    or `eval_end`; the first and last always execute eagerly, while
    `eval_reduce` executes eagerly when `use_tf_while_loop=False`.
Args:
iterator: A `tf.nest`-compatible structure of `tf.data.Iterator` or
`DistributedIterator`.
Returns:
An output which is passed as `step_outputs` argument into `eval_reduce`
function.
"""
pass
def eval_end(self, *args) -> Optional[runner.Output]:
"""Called at the end of the evaluation.
Called once at the end of evaluation.
This method is always called in eager mode, and is a good place to get
metric results. The value returned from this function will be returned as-is
from the `evaluate` method implementation provided by `StandardEvaluator`.
Args:
*args: The outputs from `eval_reduce` for the last eval step, if they are
non-`None` (if they are `None`, nothing is passed).
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries. It can also be a
nested dictionary, yielding a hierarchy of summary directories.
"""
pass
def eval_reduce(self,
state: Optional[Any] = None,
step_outputs: Optional[runner.Output] = None) -> Any:
"""A function to perform per-step reduction on the evaluation outputs.
This is useful for passing state throughout evaluation, especially in cases
where maintaining or accumulating state is hard to accomplish using
`tf.metrics.Metric` or other `tf.Variable`-based approaches. For instance,
it can be used to easily accumulate all per-example losses from the full
evaluation for subsequent processing in `eval_end()`.
Args:
      state: The state maintained throughout the evaluation.
step_outputs: Outputs from the current evaluation step.
Returns:
An output which is passed as the `state` argument to this function for the
      next step. After evaluation is finished, the output from the last step
      will be passed to `eval_end`.
"""
pass
@property
def eval_dataset(self):
"""The current evaluation dataset."""
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, eval_dataset):
"""Sets a new eval dataset, replacing the current one.
Any unprocessed examples in the current dataset are discarded.
Args:
eval_dataset: A `tf.nest`-compatible structure of `tf.data.Dataset` or
`DistributedDataset`.
"""
self._eval_dataset = eval_dataset
self._eval_iter = None
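# Illustrative sketch (not part of this module): with the default
# `StandardEvaluatorOptions` (`use_tf_while_loop=False`), `eval_reduce` runs
# eagerly and can accumulate per-step outputs in ordinary Python structures,
# e.g. collecting per-example losses for post-processing in `eval_end`. The
# `model` and `dataset` arguments below are hypothetical user-provided objects.
#
#   class MyEvaluator(StandardEvaluator):
#
#     def __init__(self, model, dataset):
#       self.model = model
#       super().__init__(eval_dataset=dataset)
#
#     def eval_begin(self):
#       return []  # Initial `state` passed to `eval_reduce`.
#
#     def eval_step(self, iterator):
#       features, labels = next(iterator)
#       return tf.keras.losses.mse(labels, self.model(features, training=False))
#
#     def eval_reduce(self, state, step_outputs):
#       state.append(step_outputs)  # Collect per-example losses.
#       return state
#
#     def eval_end(self, state):
#       return {"eval_loss": tf.reduce_mean(tf.concat(state, axis=0))}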
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.standard_runner."""
from absl.testing import parameterized
from orbit import standard_runner
from orbit import utils
import tensorflow as tf
def dataset_fn(input_context=None):
del input_context
def dummy_data(_):
return tf.zeros((1, 1), dtype=tf.float32)
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
class TestTrainer(standard_runner.StandardTrainer):
"""A StandardTrainer subclass for tests."""
def __init__(self, options=None):
self.strategy = tf.distribute.get_strategy()
self.global_step = utils.create_global_step()
dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
super().__init__(train_dataset=dataset, options=options)
def train_loop_begin(self):
self.global_step.assign(0)
def train_step(self, iterator):
def replica_step(_):
self.global_step.assign_add(1)
self.strategy.run(replica_step, args=(next(iterator),))
def train_loop_end(self):
return self.global_step.numpy()
class TestEvaluator(standard_runner.StandardEvaluator):
"""A StandardEvaluator subclass for tests."""
def __init__(self, options=None):
self.strategy = tf.distribute.get_strategy()
self.global_step = utils.create_global_step()
dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
super().__init__(eval_dataset=dataset, options=options)
def eval_begin(self):
self.global_step.assign(0)
def eval_step(self, iterator):
def replica_step(_):
self.global_step.assign_add(1)
self.strategy.run(replica_step, args=(next(iterator),))
def eval_end(self):
return self.global_step.numpy()
class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
"""A StandardEvaluator subclass for tests."""
def __init__(self, options=None):
self.strategy = tf.distribute.get_strategy()
dataset = self.strategy.distribute_datasets_from_function(
lambda _: tf.data.Dataset.range(10))
super().__init__(eval_dataset=dataset, options=options)
def eval_begin(self):
return {"logits": tf.constant((0.0,))}
def eval_reduce(self, state, step_outputs):
state["logits"] = tf.concat([state["logits"], step_outputs], 0)
return state
def eval_step(self, iterator):
def replica_step(x):
x = tf.cast(x, tf.float32)
return tf.reduce_sum(x)
return self.strategy.experimental_local_results(
self.strategy.run(replica_step, args=(next(iterator),)))
def eval_end(self, outputs):
return tf.reduce_sum(outputs["logits"])
class StandardRunnerTest(parameterized.TestCase):
def test_default_trainer(self):
trainer = TestTrainer()
self.assertEqual(trainer.train(tf.constant(10)), 10)
def test_trainer_with_tpu_summary_optimization(self):
options = standard_runner.StandardTrainerOptions(
use_tpu_summary_optimization=True)
trainer = TestTrainer(options)
self.assertEqual(trainer.train(tf.constant(10)), 10)
@parameterized.named_parameters(("use_tf_while_loop", True), ("", False))
def test_default_evaluator(self, use_tf_while_loop):
options = standard_runner.StandardEvaluatorOptions(
use_tf_while_loop=use_tf_while_loop)
evaluator = TestEvaluator(options)
self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)
@parameterized.named_parameters(("use_tf_while_loop", True), ("", False))
def test_evaluator_with_outputs_aggregation(self, use_tf_while_loop):
options = standard_runner.StandardEvaluatorOptions(
use_tf_while_loop=use_tf_while_loop)
evaluator = TestEvaluatorWithOutputsAggregation(options)
self.assertEqual(evaluator.evaluate(tf.constant(10)), 45)
@parameterized.named_parameters(
("recreate_iterator_for_each_eval", True, 10, 10),
("not_recreate_iterator_for_each_eval", False, 10, 35))
def test_evaluator_with_repeat_dataset(self, recreate_iterator_for_each_eval,
sum_for_1st_time, sum_for_2nd_time):
options = standard_runner.StandardEvaluatorOptions(
recreate_iterator_for_each_eval=recreate_iterator_for_each_eval)
evaluator = TestEvaluatorWithOutputsAggregation(options)
self.assertEqual(evaluator.evaluate(tf.constant(5)), sum_for_1st_time)
self.assertEqual(evaluator.evaluate(tf.constant(5)), sum_for_2nd_time)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines exported symbols for the `orbit.utils` package."""
from orbit.utils.common import create_global_step
from orbit.utils.common import get_value
from orbit.utils.common import make_distributed_dataset
from orbit.utils.epoch_helper import EpochHelper
from orbit.utils.loop_fns import create_loop_fn
from orbit.utils.loop_fns import create_tf_while_loop_fn
from orbit.utils.loop_fns import LoopFnWithSummaries
from orbit.utils.summary_manager import SummaryManager
from orbit.utils.tpu_summaries import OptionalSummariesFunction
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some layered modules/functions to help users writing custom training loop."""
import inspect
import tensorflow as tf
def create_global_step() -> tf.Variable:
"""Creates a `tf.Variable` suitable for use as a global step counter.
Creating and managing a global step variable may be necessary for
`AbstractTrainer` subclasses that perform multiple parameter updates per
`Controller` "step", or use different optimizers on different steps.
In these cases, an `optimizer.iterations` property generally can't be used
directly, since it would correspond to parameter updates instead of iterations
in the `Controller`'s training loop. Such use cases should simply call
`step.assign_add(1)` at the end of each step.
Returns:
A non-trainable scalar `tf.Variable` of dtype `tf.int64`, with only the
first replica's value retained when synchronizing across replicas in
a distributed setting.
"""
return tf.Variable(
0,
dtype=tf.int64,
name="global_step",
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
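# Illustrative usage sketch; the training-step body below is hypothetical and
# only shows where the counter is advanced relative to parameter updates:
#
#   global_step = create_global_step()
#   for _ in range(num_steps):
#     ...  # One or more `optimizer.apply_gradients` calls per "step".
#     global_step.assign_add(1)  # Advance once per `Controller` step.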
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
"""A utility function to help create a `tf.distribute.DistributedDataset`.
Args:
strategy: An instance of `tf.distribute.Strategy`.
    dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
returning a `tf.data.Dataset`. If it is a function, it may optionally have
an argument named `input_context` which will be passed a
`tf.distribute.InputContext` instance.
*args: Any positional arguments to pass through to `dataset_or_fn`.
**kwargs: Any keyword arguments to pass through to `dataset_or_fn`, except
that the `input_options` keyword is used to specify a
`tf.distribute.InputOptions` for making the distributed dataset.
Returns:
A distributed Dataset.
"""
if strategy is None:
strategy = tf.distribute.get_strategy()
input_options = kwargs.pop("input_options", None)
if isinstance(dataset_or_fn, tf.data.Dataset):
return strategy.experimental_distribute_dataset(dataset_or_fn,
input_options)
if not callable(dataset_or_fn):
raise ValueError("`dataset_or_fn` should be either callable or an instance "
"of `tf.data.Dataset`.")
def dataset_fn(input_context):
"""Wraps `dataset_or_fn` for strategy.distribute_datasets_from_function."""
# If `dataset_or_fn` is a function and has an argument named
# `input_context`, pass through the given `input_context`. Otherwise
# `input_context` will be ignored.
argspec = inspect.getfullargspec(dataset_or_fn)
arg_names = argspec.args
if "input_context" in arg_names:
kwargs["input_context"] = input_context
return dataset_or_fn(*args, **kwargs)
return strategy.distribute_datasets_from_function(dataset_fn, input_options)
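# Illustrative usage sketch; `my_dataset_fn` below is a hypothetical user
# function (note the optional `input_context` argument, which this utility
# fills in automatically when present):
#
#   def my_dataset_fn(batch_size, input_context=None):
#     dataset = tf.data.Dataset.range(100)
#     if input_context is not None:
#       dataset = dataset.shard(input_context.num_input_pipelines,
#                               input_context.input_pipeline_id)
#     return dataset.batch(batch_size)
#
#   strategy = tf.distribute.MirroredStrategy()
#   distributed = make_distributed_dataset(
#       strategy, my_dataset_fn, batch_size=8)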
def get_value(x):
"""Returns input values, converting any TensorFlow values to NumPy values.
Args:
x: The input. May be a `tf.Tensor` or `tf.Variable`.
Returns:
If the input is a TensorFlow `Tensor`, returns the `Tensor`'s equivalent
NumPy value. Otherwise, just returns the input.
"""
if not tf.is_tensor(x):
return x
return x.numpy()
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.utils.common."""
from orbit.utils import common
import tensorflow as tf
class UtilsTest(tf.test.TestCase):
def test_create_global_step(self):
step = common.create_global_step()
self.assertEqual(step.name, "global_step:0")
self.assertEqual(step.dtype, tf.int64)
self.assertEqual(step, 0)
step.assign_add(1)
self.assertEqual(step, 1)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a utility class for training in epochs."""
import tensorflow as tf
class EpochHelper:
"""A helper class handle bookkeeping of epochs in custom training loops."""
def __init__(self, epoch_steps: int, global_step: tf.Variable):
"""Initializes the `EpochHelper` instance.
Args:
epoch_steps: An integer indicating how many steps are in an epoch.
global_step: A `tf.Variable` providing the current global step.
"""
self._epoch_steps = epoch_steps
self._global_step = global_step
self._current_epoch = None
self._epoch_start_step = None
self._in_epoch = False
def epoch_begin(self):
"""Returns whether a new epoch should begin."""
if self._in_epoch:
return False
current_step = self._global_step.numpy()
self._epoch_start_step = current_step
self._current_epoch = current_step // self._epoch_steps
self._in_epoch = True
return True
def epoch_end(self):
"""Returns whether the current epoch should end."""
if not self._in_epoch:
raise ValueError("`epoch_end` can only be called inside an epoch.")
current_step = self._global_step.numpy()
epoch = current_step // self._epoch_steps
if epoch > self._current_epoch:
self._in_epoch = False
return True
return False
@property
def batch_index(self):
"""Index of the next batch within the current epoch."""
return self._global_step.numpy() - self._epoch_start_step
@property
  def current_epoch(self):
    """The current epoch number."""
    return self._current_epoch
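# Illustrative usage sketch inside a hypothetical custom training loop;
# `train_one_step`, `run_eval`, `global_step`, and `total_steps` below are
# placeholders, not part of this module:
#
#   helper = EpochHelper(epoch_steps=1000, global_step=global_step)
#   while global_step.numpy() < total_steps:
#     if helper.epoch_begin():
#       print("Starting epoch", helper.current_epoch)
#     train_one_step()  # Expected to advance `global_step` by one.
#     if helper.epoch_end():
#       run_eval()  # E.g. end-of-epoch evaluation or checkpointing.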