Commit 5d09fb27 authored by A. Unique TensorFlower, committed by saberkun

Internal change

PiperOrigin-RevId: 320461113
parent 8f1bce15
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
![TensorFlow Requirement: 2.x](https://img.shields.io/badge/TensorFlow%20Requirement-2.x-brightgreen)
# Orbit
Orbit is a customized training loop library built on top of TensorFlow 2. It
provides a flexible, lightweight library that users can easily use or fork when
writing [customized training loop code](https://www.tensorflow.org/tutorials/distribute/custom_training)
in TF2. It integrates seamlessly with `tf.distribute` and supports running on
different device types (CPU, GPU, and TPU).
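As a rough sketch (not part of this change) of how these pieces fit together, the example below wires a minimal `orbit.StandardTrainer` subclass into an `orbit.Controller` under the default (no-op) distribution strategy; the `ToyTrainer` class, the toy dataset, and the `/tmp/orbit_demo` summary directory are made-up placeholders.

```python
import orbit
import tensorflow as tf


class ToyTrainer(orbit.StandardTrainer):
  """A toy trainer that fits one dense layer to a constant target."""

  def __init__(self):
    # Build the model eagerly so no variables are created inside the
    # tf.function/tf.while_loop that StandardTrainer wraps around train_step.
    inputs = tf.keras.Input(shape=(3,))
    self.model = tf.keras.Model(inputs, tf.keras.layers.Dense(4)(inputs))
    self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    # Reuse the optimizer's step counter as the global step.
    self.global_step = self.optimizer.iterations
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.zeros((8, 3)), tf.ones((8, 4)))).repeat().batch(4)
    super().__init__(train_dataset=dataset)

  def train_step(self, iterator):
    inputs, targets = next(iterator)
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(
          tf.keras.losses.MSE(targets, self.model(inputs, training=True)))
    grads = tape.gradient(loss, self.model.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))


trainer = ToyTrainer()
controller = orbit.Controller(
    trainer=trainer,
    global_step=trainer.global_step,
    steps_per_loop=2,
    summary_dir="/tmp/orbit_demo")  # Placeholder directory.
controller.train(steps=10)
```

Under a real `tf.distribute` strategy, the trainer would instead build a distributed dataset and call `strategy.run` inside `train_step`, as the tests in this change do.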
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from orbit import utils
from orbit.controller import Controller
from orbit.runner import *
from orbit.standard_runner import *
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A light weight utilities to train TF2 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import time
from typing import Callable, Optional, Text, Union
from absl import logging
from orbit import runner
from orbit import utils
import tensorflow as tf
def _log_info(message: Text):
"""Logs `message` to the `info` log, and also prints to stdout."""
logging.info(message)
print(message)
def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name: str):
if interval and steps_per_loop and (interval % steps_per_loop != 0):
raise ValueError("The {} interval ({}) must be a multiple "
"of the steps_per_loop ({})".format(
interval_name, interval, steps_per_loop))
class Controller(object):
"""Class that facilitates training and evaluation of models."""
def __init__(
self,
strategy: Optional[tf.distribute.Strategy] = None,
trainer: Optional[runner.AbstractTrainer] = None,
evaluator: Optional[runner.AbstractEvaluator] = None,
global_step: Optional[tf.Variable] = None,
# Train related
steps_per_loop: Optional[int] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
# Summary related
summary_interval: Optional[int] = None,
summary_dir: Optional[Text] = None,
# Evaluation related
eval_summary_dir: Optional[Text] = None):
"""Constructs a `Controller` instance.
Args:
strategy: An instance of `tf.distribute.Strategy`.
trainer: An instance of `orbit.AbstractTrainer`, which represents model
training details.
evaluator: An instance of `orbit.AbstractEvaluator`, which represents
model evaluation details.
global_step: An integer `tf.Variable` indicating the global training step
number. Usually this can be obtained from the `iterations` property of the
model's optimizer (e.g. `self.optimizer.iterations`), or users can create
their own global step variable. If creating their own, it is recommended
to create the `tf.Variable` inside the strategy scope, with
`aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA`.
steps_per_loop: The number of steps to run in each "inner loop" of
training (passed to the `num_steps` parameter of `trainer.train`).
checkpoint_manager: An instance of `tf.train.CheckpointManager`.
summary_interval: Step interval for training summaries. Note that this
argument only applies to summaries written inside the `trainer.train`
function. Summaries written outside it, such as "steps_per_second" and
the outputs returned from `trainer.train`, are always enabled. If set,
the value must be a multiple of `steps_per_loop`.
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`.
Raises:
ValueError: If both `trainer` and `evaluator` are None.
ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `summary_interval` is not a positive integer or is not a
multiple of `steps_per_loop`.
"""
if trainer is None and evaluator is None:
raise ValueError("`trainer` and `evaluator` should not both be None")
if trainer is not None:
if steps_per_loop is None:
raise ValueError("`steps_per_loop` is required when `trainer` is "
"provided.")
if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
raise ValueError("`steps_per_loop` should be a positive integer")
if summary_interval is not None:
if summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0")
_validate_interval(
summary_interval, steps_per_loop, interval_name="summary")
self.trainer = trainer
self.evaluator = evaluator
self.strategy = strategy or tf.distribute.get_strategy()
self.global_step = global_step
self.checkpoint_manager = checkpoint_manager
if summary_dir is None and checkpoint_manager:
summary_dir = checkpoint_manager.directory
if self.trainer is not None:
self.step_timer = None
self.steps_per_loop = steps_per_loop
self.summary_interval = summary_interval
self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step)
eval_summary_writer = None
if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir
if eval_summary_dir == summary_dir and self.trainer is not None:
# Reuse the summary writer if train and evaluation summary directory
# are the same.
self.eval_summary_manager = self.summary_manager
else:
self.eval_summary_manager = utils.SummaryManager(
eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
if self.global_step is not None:
tf.summary.experimental.set_step(self.global_step)
# Restores the model if needed.
# TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None:
checkpoint_interval = self.checkpoint_manager.checkpoint_interval
_validate_interval(
checkpoint_interval, steps_per_loop, interval_name="checkpoint")
model_restored = self.restore_checkpoint()
if not model_restored and checkpoint_interval:
# If the model is not restored from a checkpoint, save an initial
# checkpoint.
self.save_checkpoint()
def train(self, steps: int, checkpoint_at_completion: bool = True):
"""Runs training.
This method calls the `train` method on the Trainable object until the
global step count is equal to `steps`. It will optionally save checkpoints,
if a CheckpointManager was passed to the Controller instance's `__init__`.
Args:
steps: The global step count to train up to.
checkpoint_at_completion: Whether to save a checkpoint when this method
returns. Defaults to True (write the checkpoint). This is always
triggered, regardless of the checkpointing interval.
"""
if self.trainer is None:
raise ValueError("`self.trainer` is required when calling `train` "
"method.")
if self.global_step is None:
raise ValueError("`self.global_step` is required when calling `train` "
"method.")
# TODO(momernick): Support steps=None or -1 (training to exhaustion).
current_step = self.global_step.numpy() # This is an expensive access.
while current_step < steps:
logging.info("Train at step %s of %s", current_step, steps)
# Calculates steps to run for the next train loop.
num_steps = min(steps - current_step, self.steps_per_loop)
self._train_n_steps(num_steps)
self._maybe_save_checkpoint()
current_step = self.global_step.numpy() # This is an expensive access.
if checkpoint_at_completion:
self.save_checkpoint()
def evaluate(self, steps: int = None):
"""Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps`
steps, then writes the returned summaries (if any).
Args:
steps: The number of steps to evaluate for.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided.
"""
if self.evaluator is None:
raise ValueError("`evaluator` must be provided to call `evaluate()` "
"method.")
steps = steps or -1
current_step = self.global_step.numpy()
if steps > 0:
logging.info("Running %s steps of evaluation at train step: %s", steps,
current_step)
steps = tf.convert_to_tensor(steps, dtype=tf.int32)
else:
logging.info("Evaluating at train step: %s", current_step)
with self.eval_summary_manager.summary_writer.as_default():
eval_outputs = self.evaluator.evaluate(steps)
if eval_outputs:
eval_outputs = tf.nest.map_structure(utils.get_value, eval_outputs)
info = "step: {} evaluation metric: {}".format(
current_step, eval_outputs)
_log_info(info)
self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush()
def restore_checkpoint(self, checkpoint_path: Text = None):
"""Restore or initialize the model.
Args:
checkpoint_path: An optional string indicating the checkpoint path to
restore from. If None, restores from `self.checkpoint_manager`.
Returns:
The path to the restored checkpoint if a restore happened, or None
if no restore occurred.
"""
with self.strategy.scope():
# Checkpoint restoring should be inside scope. b/139450638
if checkpoint_path is not None:
self.checkpoint_manager.checkpoint.restore(checkpoint_path)
return checkpoint_path
return self.checkpoint_manager.restore_or_initialize()
def save_checkpoint(self):
"""Checkpoint the model.
This method will write a checkpoint containing the current state of the
model.
Raises:
ValueError: if no CheckpointManager was provided to this Controller's
init args.
"""
self._maybe_save_checkpoint(force_trigger=True)
def train_and_evaluate(self,
train_steps: int = None,
eval_steps: int = None,
eval_interval: int = None):
"""Train and evaluate in an interleaved manner.
This method will train the model until the global step count equals
`train_steps`, running an evaluation for `eval_steps` every `eval_interval`
training steps. In addition, this method will run a final evaluation at the
end of the training sequence.
Args:
train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If None,
this method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evaluations.
Must be a multiple of the controller's `steps_per_loop` init arg. If
None, evaluation will only be performed after training is complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
_validate_interval(eval_interval, self.steps_per_loop, interval_name="eval")
current_step = self.global_step.numpy() # This is an expensive access.
eval_interval = eval_interval or (train_steps - current_step)
while current_step < train_steps:
interval = min(train_steps - current_step, eval_interval)
num_steps = current_step + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self.evaluate(steps=eval_steps)
current_step = self.global_step.numpy() # This is an expensive access.
self.save_checkpoint()
def evaluate_continuously(self,
steps: int = None,
timeout: Optional[Union[int, float]] = None,
timeout_fn: Optional[Callable[[], bool]] = None):
"""Monitor a directory and evaluate on checkpoints in it.
This method continuously monitors a directory as specified by this
Controller's CheckpointManager init arg and runs evaluation on the
checkpoints found there.
Args:
steps: The number of steps to run when evaluating.
timeout: The maximum number of seconds to wait between checkpoints. See
tf.train.checkpoints_iterator documentation.
timeout_fn: Optional callable to call after a timeout. If the function
returns True, then it means that no new checkpoints will be generated
and the iterator will exit.
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` was not provided as a controller init arg.
"""
for checkpoint_path in tf.train.checkpoints_iterator(
self.checkpoint_manager.directory,
timeout=timeout,
timeout_fn=timeout_fn):
self.restore_checkpoint(checkpoint_path)
self.evaluate(steps)
def _train_n_steps(self, num_steps: int):
"""Run training for `num_steps`.
It will also write training outputs to summaries if there are any.
Args:
num_steps: An integer indicating how many steps to run for this training
loop.
Raises:
RuntimeError: If `global_step` is not updated correctly in
`trainer.train`.
"""
if not self.step_timer:
self.step_timer = StepTimer(self.global_step)
# Calculates steps to run for the next train loop.
current_step = self.global_step.numpy()
logging.info("Entering training loop at step %s of %s", current_step,
num_steps)
current_step += num_steps
num_steps = tf.convert_to_tensor(num_steps, dtype=tf.int32)
with self.summary_manager.summary_writer.as_default():
# Create a lambda that returns true when summaries should be written.
should_record = False # Allows static optimization in no-summary cases.
if self.summary_interval:
should_record = lambda: (self.global_step % self.summary_interval == 0)
with tf.summary.record_if(should_record):
train_outputs = self.trainer.train(num_steps)
# Updates and verifies the current step after a training loop finishes.
if current_step != self.global_step.numpy():
raise RuntimeError("`trainer.train` function is not updating "
"`global_step` correctly, expected: %s, actual: %s" %
(current_step, self.global_step.numpy()))
# Print information like metrics and steps_per_second after a training
# loop.
if train_outputs:
train_outputs = tf.nest.map_structure(utils.get_value, train_outputs)
train_outputs = train_outputs or {}
steps_per_second = self.step_timer.steps_per_second()
info = "step: {} steps_per_second: {:.2f} {}".format(
current_step, steps_per_second, train_outputs)
_log_info(info)
train_outputs["steps_per_second"] = steps_per_second
self.summary_manager.write_summaries(train_outputs)
def _maybe_save_checkpoint(self, force_trigger: bool = False):
"""Save checkpoints if necessary.
Args:
force_trigger: A boolean indicating whether to force saving checkpoints
regardless of the checkpoint interval.
Returns:
A boolean indicating whether a checkpoint was saved.
"""
if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=self.global_step.numpy(),
check_interval=not force_trigger)
if ckpt_path is not None:
logging.info("Saved checkpoints in %s", ckpt_path)
return True
return False
class StepTimer(object):
"""Utility class for measuring steps/second."""
def __init__(self, step):
self.step = step
self.start()
def start(self):
self.last_iteration = self.step.numpy()
self.last_time = time.time()
def steps_per_second(self, restart=True):
value = ((self.step.numpy() - self.last_iteration) /
(time.time() - self.last_time))
if restart:
self.start()
return value
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.controller."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
from absl.testing import parameterized
import numpy as np
from orbit import controller
from orbit import standard_runner
import tensorflow as tf
def create_model():
x = tf.keras.layers.Input(shape=(3,), name="input")
y = tf.keras.layers.Dense(4, name="dense")(x)
model = tf.keras.Model(x, y)
return model
def summaries_with_matching_keyword(keyword, summary_dir):
"""Yields summary protos matching given keyword from event file."""
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
logging.info(event)
yield event.summary
def check_eventfile_for_keyword(keyword, summary_dir):
"""Checks event files for the keyword."""
return any(summaries_with_matching_keyword(keyword, summary_dir))
def dataset_fn(ctx):
del ctx
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.ones((10, 4), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10, drop_remainder=True)
return dataset
class TestRunner(standard_runner.StandardTrainer,
standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self, return_numpy=False):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
self.return_numpy = return_numpy
train_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
eval_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
standard_runner.StandardTrainer.__init__(self, train_dataset)
standard_runner.StandardEvaluator.__init__(self, eval_dataset)
def train_step(self, iterator):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def train_loop_end(self):
train_loss = self.train_loss.result()
return {
"loss": train_loss.numpy() if self.return_numpy else train_loss,
}
def build_eval_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def eval_begin(self):
self.eval_loss.reset_states()
def eval_step(self, iterator):
def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
self.eval_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def eval_end(self):
eval_loss = self.eval_loss.result()
return {
"eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
}
class TestEvaluator(standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
eval_dataset = self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
standard_runner.StandardEvaluator.__init__(self, eval_dataset)
def eval_reduce(self, state, output):
state.append(output)
return state
def eval_begin(self):
return []
def eval_step(self, iterator):
def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
return loss
per_replica_losses = self.strategy.run(
_replicated_step, args=(next(iterator),))
mean_loss = self.strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return mean_loss
def eval_end(self, outputs):
return {
"eval_loss": tf.reduce_mean(outputs),
}
class TestTrainerWithSummaries(standard_runner.StandardTrainer):
"""A Trainer model with summaries for testing purposes."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
train_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
standard_runner.StandardTrainer.__init__(
self, train_dataset, use_tpu_summary_optimization=True)
def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def train_step(self, iterator):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
tf.summary.scalar("loss", loss)
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ControllerTest, self).setUp()
self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self):
test_runner = TestRunner()
# No checkpoint manager and no strategy.
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# No checkpoint, so global step starts from 0.
test_runner.global_step.assign(0)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
def test_no_checkpoint_and_summaries(self):
test_runner = TestRunner()
# No checkpoint + summary directories.
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
@parameterized.named_parameters(("return_numpy", True),
("return_tensor", False))
def test_train_and_evaluate(self, return_numpy):
test_runner = TestRunner(return_numpy=return_numpy)
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_only(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
)
test_controller.train(steps=10)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
def test_evaluate_only(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=2)
# Only eval summaries are written
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# Tests continuous eval with timeout and timeout_fn.
done_file = os.path.join(self.model_dir, "summaries/eval/Done")
def timeout_fn():
with tf.io.gfile.GFile(done_file, "w") as f:
f.write("DONE")
return True
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate_continuously(
timeout=1, timeout_fn=timeout_fn, steps=2)
self.assertNotEmpty(tf.io.gfile.glob(done_file))
def test_no_eval_steps(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager)
test_controller.evaluate()
def test_already_trained_model(self):
test_runner = TestRunner()
test_runner.global_step.assign(10)
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
# `global_step` is already `train_steps`.
test_controller.train(steps=10)
def test_summaries_inside_train_fn(self):
test_runner = TestTrainerWithSummaries()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
)
test_controller.train(steps=10)
# No checkpoint is saved, since no checkpoint_interval was set.
self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_with_same_summary_dir(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries"),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("loss",
os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("eval_loss",
os.path.join(self.model_dir, "summaries")))
def test_early_stop_on_eval_loss(self):
test_runner = TestRunner()
class EarlyStopController(controller.Controller):
"""A subclass of Controller supports early stopping."""
def train_and_evaluate(self,
train_steps: int = None,
eval_steps: int = None,
eval_interval: int = None):
while self.global_step.numpy() < train_steps:
interval = min(train_steps - self.global_step.numpy(), eval_interval)
num_steps = self.global_step.numpy() + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self.evaluate(steps=eval_steps)
# Early stop condition.
if test_runner.eval_loss.result() < 0.1:
logging.info(
"Training early stopped as eval_loss %s is less than 0.1",
test_runner.eval_loss.result())
return
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
test_controller = EarlyStopController(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=6, eval_interval=2)
self.assertLess(test_runner.global_step, 10)
def test_evaluate_with_loss_outputs(self):
test_evaluator = TestEvaluator()
checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, self.model_dir, max_to_keep=None)
test_controller = controller.Controller(
evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64),
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.evaluate(steps=5)
# Only eval summaries are written
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_reset_datasets(self):
test_runner = TestRunner()
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
train_dataset = (
test_runner.strategy.experimental_distribute_datasets_from_function(
dataset_fn))
eval_dataset = (
test_runner.strategy.experimental_distribute_datasets_from_function(
dataset_fn))
test_runner.train_dataset = train_dataset
test_runner.eval_dataset = eval_dataset
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
if __name__ == "__main__":
tf.test.main()
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
from typing import Dict, Optional, Text
import six
import tensorflow as tf
@six.add_metaclass(abc.ABCMeta)
class AbstractTrainer(tf.Module):
"""An abstract class defining the APIs required for training."""
@abc.abstractmethod
def train(self,
num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""Implements model training with multiple steps.
In training, it is common to break the total training steps into several
training loops, so users can do checkpointing, write summaries, and run some
python callbacks in between. This is necessary for getting good performance
in TPU training, as the overhead of launching a multi-worker tf.function may
be large in eager mode. It is usually encouraged to create a host training
loop (e.g. using a `tf.range` wrapping `strategy.run` inside a `tf.function`)
in the TPU case. For cases that don't require a host training loop to achieve
peak performance, users can just implement a simple python loop to drive
each step.
Args:
num_steps: A guideline for how many training steps to run. Note that it is
up to the model what constitutes a "step" (this may involve more than
one update to model parameters, e.g. if training a GAN).
Returns:
The function may return a dictionary of `Tensors` or numpy arrays, which
will be written to logs and as TensorBoard summaries.
"""
pass
@six.add_metaclass(abc.ABCMeta)
class AbstractEvaluator(tf.Module):
"""An abstract class defining the APIs required for evaluation."""
@abc.abstractmethod
def evaluate(
self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""Implements model evaluation.
Args:
num_steps: A guideline for how many evaluation steps to run. Note that it
is up to the model what constitutes a "step". Generally, it may be
desirable to support both a limited number of eval steps and iterating
over a full dataset (however many steps are required) when `num_steps`
is `None`.
Returns:
The function may return a dictionary of `Tensors` or numpy arrays, which
will be written to logs and as TensorBoard summaries.
"""
pass
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
from typing import Any, Dict, Optional, Text
from orbit import runner
from orbit import utils
import six
import tensorflow as tf
@six.add_metaclass(abc.ABCMeta)
class StandardTrainer(runner.AbstractTrainer):
"""Implements the standard functionality of AbstractTrainer APIs."""
def __init__(self,
train_dataset,
use_tf_while_loop=True,
use_tf_function=True,
use_tpu_summary_optimization=False):
"""Construct a `StandardTrainer` object.
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_while_loop: A boolean indicating whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicating whether a `tf.function` will be used.
If False, training will run in pure eager mode.
use_tpu_summary_optimization: A boolean indicating whether to enable the
performance optimization for summaries on TPUs. On TPUs, writing
summaries with outside compilation inside the train step is slow. If
True, this creates two `tf.function`s with two XLA programs: one with
summaries and one without, and runs the program with summaries (the slow
one) only when necessary.
"""
if use_tf_while_loop and not use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported")
if use_tpu_summary_optimization and not use_tf_while_loop:
raise ValueError("`use_tpu_summary_optimization=True` and "
"`use_tf_while_loop=False` is not supported")
self._use_tf_while_loop = use_tf_while_loop
self._use_tf_function = use_tf_function
self._train_dataset = train_dataset
self._train_iter = None
self._train_loop_fn = None
self._use_tpu_summary_optimization = use_tpu_summary_optimization
def train(self,
num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""See base class."""
self.train_loop_begin()
if self._train_iter is None:
self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
if self._train_loop_fn is None:
train_fn = self.train_step
if self._use_tf_while_loop:
self._train_loop_fn = utils.create_tf_while_loop_fn(train_fn)
if self._use_tpu_summary_optimization:
self._train_loop_fn = utils.train_function_with_summaries(
self._train_loop_fn)
else:
self._train_loop_fn = tf.function(self._train_loop_fn)
else:
if self._use_tf_function:
train_fn = tf.function(train_fn)
self._train_loop_fn = utils.create_loop_fn(train_fn)
self._train_loop_fn(self._train_iter, num_steps)
return self.train_loop_end()
def train_loop_begin(self):
"""Called once at the beginning of the training loop.
This method is called before the dataset iterators are created.
This is a good place to reset metrics that accumulate values over multiple
steps of training.
"""
pass
@abc.abstractmethod
def train_step(self, iterator):
"""Implements one step of training.
What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.run`.
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
DistributedIterator.
"""
pass
def train_loop_end(self) -> Optional[Dict[Text, tf.Tensor]]:
"""Called at the end of the training loop.
This is a good place to get metric results. The value returned from this
function will be returned as-is from the train() method.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
@property
def train_dataset(self):
"""Returns the train_dataset instance."""
return self._train_dataset
@train_dataset.setter
def train_dataset(self, train_dataset):
"""Set a new train dataset and replace with the existing one.
Any unfinished work in the previous dataset will be discarded.
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
"""
self._train_dataset = train_dataset
self._train_iter = None
@six.add_metaclass(abc.ABCMeta)
class StandardEvaluator(runner.AbstractEvaluator):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def __init__(self, eval_dataset, use_tf_function=True):
"""Construct a `StandardEvaluator` object.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_function: A boolean indicating whether a `tf.function` will be used.
If False, evaluation will run in pure eager mode.
"""
self._eval_use_tf_function = use_tf_function
self._eval_dataset = eval_dataset
self._eval_loop_fn = None
def evaluate(
self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
"""See base class."""
outputs = self.eval_begin() # pylint: disable=assignment-from-no-return
eval_iter = tf.nest.map_structure(iter, self._eval_dataset)
if self._eval_loop_fn is None:
eval_fn = self.eval_step
if self._eval_use_tf_function:
eval_fn = tf.function(eval_fn)
self._eval_loop_fn = utils.create_loop_fn(eval_fn)
outputs = self._eval_loop_fn(
eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)
if outputs is None:
return self.eval_end()
else:
return self.eval_end(outputs)
def eval_begin(self) -> Any:
"""Called once at the beginning of the evaluation.
This method is called before the dataset iterators are created.
This is a good place to reset metrics that accumulate values over the entire
evaluation.
Returns:
An output which is passed as the `state` argument into the `eval_reduce` function.
"""
pass
@abc.abstractmethod
def eval_step(self, iterator) -> Any:
"""Implements one step of evaluation.
What a "step" consists of is up to the implementer. If using distribution
strategies, the call to this method should take place in the "cross-replica
context" for generality, to allow e.g. multiple iterator dequeues and calls
to `strategy.run`.
Args:
iterator: A tf.nest-compatible structure of tf.data Iterator or
DistributedIterator.
Returns:
An output which is passed as `step_outputs` argument into `eval_reduce`
function.
"""
pass
def eval_end(self, *args) -> Optional[Dict[Text, tf.Tensor]]:
"""Called at the end of the evaluation.
This is a good place to get metric results. The value returned from this
function will be returned as-is from the evaluate() method.
Args:
*args: the outputs from `eval_reduce` for the last eval step.
Returns:
The function may return a dictionary of `Tensors`, which will be
written to logs and as TensorBoard summaries.
"""
pass
def eval_reduce(self, state=None, step_outputs=None) -> Any:
"""A function to do the reduction on the evaluation outputs per step.
This is useful for passing state throughout evaluation. E.g. it can be used
to maintain the output losses from all the evaluation steps and compute the
mean loss in the `eval_end` function.
Args:
state: A maintained state throughout the evaluation.
step_outputs: Outputs from the current evaluation step.
Returns:
An output which is passed as the `state` argument into the `eval_reduce`
function for the next step. After evaluation finishes, the output from the
last step is passed into the `eval_end` function.
"""
pass
@property
def eval_dataset(self):
"""Returns the train_datase instance."""
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, eval_dataset):
"""Set a new eval dataset and replace with the existing one.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
"""
self._eval_dataset = eval_dataset
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for orbit.standard_runner."""
# pylint: disable=g-bad-import-order
from orbit import standard_runner
import tensorflow as tf
def dataset_fn(input_context=None):
del input_context
def dummy_data(_):
return tf.zeros((1, 1), dtype=tf.float32)
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
class TestRunner(standard_runner.StandardTrainer,
standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for tests."""
def __init__(self):
self.strategy = tf.distribute.get_strategy()
self.global_step = tf.Variable(
0,
trainable=False,
dtype=tf.int64,
name='global_step',
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
standard_runner.StandardTrainer.__init__(self, train_dataset=None)
standard_runner.StandardEvaluator.__init__(self, eval_dataset=None)
def train_loop_begin(self):
self.train_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
def train_step(self, iterator):
def _replicated_step(_):
self.global_step.assign_add(1)
self.strategy.run(_replicated_step, args=(next(iterator),))
def train_loop_end(self):
return self.global_step.numpy()
def eval_begin(self):
self.eval_dataset = self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def eval_step(self, iterator):
def _replicated_step(_):
self.global_step.assign_add(1)
self.strategy.run(_replicated_step, args=(next(iterator),))
def eval_end(self):
return self.global_step.numpy()
class StandardRunnerTest(tf.test.TestCase):
def test_train(self):
test_runner = TestRunner()
self.assertEqual(
test_runner.train(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
def test_eval(self):
test_runner = TestRunner()
self.assertEqual(
test_runner.evaluate(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import contextlib
import functools
import inspect
import numpy as np
import six
import tensorflow as tf
def create_loop_fn(step_fn):
"""Creates a multiple steps function driven by the python while loop.
Args:
step_fn: A function which takes `iterator` as input.
Returns:
A callable following the `loop_fn` definition below.
"""
def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
"""A loop function with multiple steps.
Args:
iterator: A nested structure of tf.data `Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. If `num_steps == -1`, the loop
will iterate until the iterator is exhausted.
state: An optional initial state before running the loop.
reduce_fn: A callable defined as `def reduce_fn(state, value)`, where
`value` is the outputs from `step_fn`.
Returns:
The updated state.
"""
try:
step = 0
# To make sure the OutOfRangeError exception can be handled well with
# async remote eager, we need to wrap the loop body in an `async_scope`.
with tf.experimental.async_scope():
while (num_steps == -1 or step < num_steps):
outputs = step_fn(iterator)
if reduce_fn is not None:
state = reduce_fn(state, outputs)
step += 1
return state
except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
return state
return loop_fn
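# A minimal usage sketch (illustrative, not part of the library): build a loop
# function from a per-step function and use `reduce_fn` to accumulate the
# per-step outputs. The dataset and helper names below are assumptions made
# only for this example.
def _example_create_loop_fn_usage():
  dataset = tf.data.Dataset.range(8).map(lambda x: tf.cast(x, tf.float32))
  iterator = iter(dataset)

  def _example_step(it):
    # Pretend the per-step "loss" is simply the next element of the dataset.
    return next(it)

  def _sum_outputs(state, value):
    return state + value

  loop_fn = create_loop_fn(_example_step)
  # Runs 4 steps and returns 0 + 1 + 2 + 3 = 6.0 as the accumulated state.
  return loop_fn(
      iterator, num_steps=4, state=tf.constant(0.0), reduce_fn=_sum_outputs)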
def create_tf_while_loop_fn(step_fn):
"""Create a multiple steps function driven by tf.while_loop on the host.
Args:
step_fn: A function which takes `iterator` as input.
Returns:
A callable following the `loop_fn` definition below.
"""
def loop_fn(iterator, num_steps):
"""A loop function with multiple steps.
Args:
iterator: A nested structure of tf.data `Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. Must be a tf.Tensor.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError("`num_steps` should be an `tf.Tensor`. Python object "
"may cause retracing.")
for _ in tf.range(num_steps):
step_fn(iterator)
return loop_fn
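# Usage sketch (illustrative): the tf.while_loop-based variant is typically
# wrapped in a `tf.function` and driven with a `tf.Tensor` step count so that
# the traced graph can be reused across calls. All names below are assumptions
# made only for this example.
def _example_create_tf_while_loop_fn_usage():
  iterator = iter(tf.data.Dataset.range(100).repeat())
  counter = tf.Variable(0, dtype=tf.int64, trainable=False)

  def _example_step(it):
    # Ignore the iterator and just bump the counter; nothing is returned
    # because the tf.while_loop variant does not support `reduce_fn`.
    del it
    counter.assign_add(1)

  loop_fn = tf.function(create_tf_while_loop_fn(_example_step))
  # `num_steps` must be a tensor; passing a Python int raises a ValueError.
  loop_fn(iterator, tf.constant(10, dtype=tf.int32))
  return counter.numpy()  # == 10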
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
"""A helper function to create distributed dataset.
Args:
strategy: An instance of `tf.distribute.Strategy`.
dataset_or_fn: An instance of `tf.data.Dataset`, or a function that takes a
`tf.distribute.InputContext` as input and returns a `tf.data.Dataset`. If
it is a function, it may optionally have an argument named `input_context`
of type `tf.distribute.InputContext`.
*args: The list of arguments to be passed to dataset_or_fn.
**kwargs: Any keyword arguments to be passed.
Returns:
A distributed Dataset.
"""
if strategy is None:
strategy = tf.distribute.get_strategy()
if isinstance(dataset_or_fn, tf.data.Dataset):
return strategy.experimental_distribute_dataset(dataset_or_fn)
if not callable(dataset_or_fn):
raise ValueError("`dataset_or_fn` should be either callable or an instance "
"of `tf.data.Dataset`")
def dataset_fn(ctx):
"""Wrapped dataset function for creating distributed dataset.."""
# If `dataset_or_fn` is a function and has `input_context` as argument
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
if six.PY3:
argspec = inspect.getfullargspec(dataset_or_fn)
else:
argspec = inspect.getargspec(dataset_or_fn) # pylint: disable=deprecated-method
args_names = argspec.args
if "input_context" in args_names:
kwargs["input_context"] = ctx
ds = dataset_or_fn(*args, **kwargs)
return ds
return strategy.experimental_distribute_datasets_from_function(dataset_fn)
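# A small usage sketch (illustrative only): a dataset function that accepts the
# optional `input_context` argument, distributed with the current strategy. The
# function and argument values below are assumptions made for this example.
def _example_make_distributed_dataset_usage():

  def _example_dataset_fn(batch_size, input_context=None):
    dataset = tf.data.Dataset.range(64)
    if input_context is not None:
      # Shard the data across input pipelines when running distributed.
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    return dataset.batch(batch_size)

  strategy = tf.distribute.get_strategy()
  # `batch_size` is forwarded through *args; `input_context` is filled in
  # automatically because the function declares that argument.
  return make_distributed_dataset(strategy, _example_dataset_fn, 8)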
class SummaryManager(object):
"""A class manages writing summaries."""
def __init__(self, summary_dir, summary_fn, global_step=None):
"""Construct a summary manager object.
Args:
summary_dir: the directory to write summaries.
summary_fn: A callable defined as `def summary_fn(name, tensor,
step=None)`, which describes the summary operation.
global_step: A `tf.Variable` instance for the global step.
"""
self._enabled = (summary_dir is not None)
self._summary_dir = summary_dir
self._summary_fn = summary_fn
self._summary_writer = None
if global_step is None:
self._global_step = tf.summary.experimental.get_step()
else:
self._global_step = global_step
@property
def summary_writer(self):
"""Returns the underlying summary writer."""
if self._summary_writer is not None:
return self._summary_writer
if self._enabled:
self._summary_writer = tf.summary.create_file_writer(self._summary_dir)
else:
self._summary_writer = tf.summary.create_noop_writer()
return self._summary_writer
def flush(self):
"""Flush the underlying summary writer."""
if self._enabled:
tf.summary.flush(self.summary_writer)
def write_summaries(self, items):
"""Write a bulk of summaries.
Args:
items: a dictionary of `Tensors` for writing summaries.
"""
# TODO(rxsang): Support writing summaries with nested structure, so users
# can split the summaries into different directories for nicer visualization
# in TensorBoard, such as train and eval metrics.
if not self._enabled:
return
with self.summary_writer.as_default():
for name, tensor in items.items():
self._summary_fn(name, tensor, step=self._global_step)
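# Usage sketch (illustrative): write scalar summaries keyed to a global step.
# The directory and metric names below are assumptions made for this example.
def _example_summary_manager_usage(summary_dir="/tmp/example_summaries"):
  global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
  manager = SummaryManager(
      summary_dir=summary_dir,
      summary_fn=tf.summary.scalar,
      global_step=global_step)
  global_step.assign(100)
  # Writes `loss` and `accuracy` scalars at step 100, then flushes to disk.
  manager.write_summaries(
      {"loss": tf.constant(0.25), "accuracy": tf.constant(0.9)})
  manager.flush()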
@six.add_metaclass(abc.ABCMeta)
class Trigger(object):
"""An abstract class representing a "trigger" for some event."""
@abc.abstractmethod
def __call__(self, value: float, force_trigger=False):
"""Maybe trigger the event based on the given value.
Args:
value: The value for triggering.
force_trigger: Whether to force-trigger the event.
Returns:
`True` if the trigger is triggered on the given `value`, and
`False` otherwise.
"""
@abc.abstractmethod
def reset(self):
"""Reset states in the trigger."""
class IntervalTrigger(Trigger):
"""Triggers on every fixed interval."""
def __init__(self, interval, start=0):
"""Constructs the IntervalTrigger.
Args:
interval: The triggering interval.
start: The initial last-triggered value.
"""
self._interval = interval
self._last_trigger_value = start
def __call__(self, value, force_trigger=False):
"""Maybe trigger the event based on the given value.
Args:
value: The value for triggering.
force_trigger: If True, the event is force-triggered unless the last
trigger value is equal to `value`.
Returns:
`True` if the trigger is triggered on the given `value`, and
`False` otherwise.
"""
if force_trigger and value != self._last_trigger_value:
self._last_trigger_value = value
return True
if self._interval and self._interval > 0:
if value >= self._last_trigger_value + self._interval:
self._last_trigger_value = value
return True
return False
def reset(self):
"""See base class."""
self._last_trigger_value = 0
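# Usage sketch (illustrative): fire an action (e.g. saving a checkpoint) every
# 1000 steps. The `save_fn` callback is an assumption made for this example.
def _example_interval_trigger_usage(save_fn=lambda: None):
  trigger = IntervalTrigger(interval=1000, start=0)
  for global_step in (500, 1000, 1500, 2000):
    if trigger(global_step):
      # Fires at steps 1000 and 2000 in this example.
      save_fn()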
class EpochHelper(object):
"""A Helper class to handle epochs in Customized Training Loop."""
def __init__(self, epoch_steps, global_step):
"""Constructs the EpochHelper.
Args:
epoch_steps: An integer indicating how many steps are in an epoch.
global_step: A `tf.Variable` instance indicating the current global step.
"""
self._epoch_steps = epoch_steps
self._global_step = global_step
self._current_epoch = None
self._epoch_start_step = None
self._in_epoch = False
def epoch_begin(self):
"""Returns whether a new epoch should begin."""
if self._in_epoch:
return False
current_step = self._global_step.numpy()
self._epoch_start_step = current_step
self._current_epoch = current_step // self._epoch_steps
self._in_epoch = True
return True
def epoch_end(self):
"""Returns whether the current epoch should end."""
if not self._in_epoch:
raise ValueError("`epoch_end` can only be called inside an epoch")
current_step = self._global_step.numpy()
epoch = current_step // self._epoch_steps
if epoch > self._current_epoch:
self._in_epoch = False
return True
return False
@property
def batch_index(self):
"""Index of the next batch within the current epoch."""
return self._global_step.numpy() - self._epoch_start_step
@property
def current_epoch(self):
return self._current_epoch
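# Usage sketch (illustrative): drive epoch bookkeeping from a step-based
# training loop. `run_one_step` is a placeholder for the real training step and
# is an assumption made for this example.
def _example_epoch_helper_usage(run_one_step=lambda: None):
  global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
  helper = EpochHelper(epoch_steps=100, global_step=global_step)
  for _ in range(250):
    if helper.epoch_begin():
      print("Starting epoch", helper.current_epoch)
    run_one_step()
    global_step.assign_add(1)
    if helper.epoch_end():
      # `current_epoch` still refers to the epoch that just finished here.
      print("Finished epoch", helper.current_epoch)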
@contextlib.contextmanager
def _soft_device_placement():
"""Context manager for soft device placement, allowing summaries on CPU."""
original_setting = tf.config.get_soft_device_placement()
try:
tf.config.set_soft_device_placement(True)
yield
finally:
tf.config.set_soft_device_placement(original_setting)
def train_function_with_summaries(*args, **kwargs):
"""Utility function to support TPU summaries via multiple `tf.function`s.
This permits interleaving summaries inside TPU-compatible code, but without
any performance impact on steps that do not write summaries.
Usage is as a decorator, similar to `tf.function`, and any `tf.function`
arguments will be passed through if supplied:
@trainer.train_function_with_summaries
def train(self, num_steps):
...
The decorated function is assumed to be a loop method accepting a `num_steps`
parameter, as for instance would be called within the `Controller`'s outer
train loop. The implementation here assumes that `summary_frequency` is
divisible by `steps_per_loop`. The decorated method should accept two
arguments, `self` and `num_steps`.
Two `tf.function` versions of `train_fn` are created: one inside a summary
writer scope with soft device placement enabled (used on steps that require
summary writing), and one with no summary writer present and soft device
placement disabled (used on all other steps).
Args:
*args: Arguments to pass through to `tf.function`.
**kwargs: Keyword arguments to pass through to `tf.function`.
Returns:
If the first argument is a callable, returns the decorated callable.
Otherwise, returns a decorator.
"""
def decorator(train_fn):
# TODO(dhr): Validate the signature of train_fn?
train_fn_with_summaries = tf.function(train_fn, *args, **kwargs)
train_fn_without_summaries = tf.function(train_fn, *args, **kwargs)
@functools.wraps(train_fn)
def wrapper(self, num_steps):
if tf.summary.should_record_summaries():
with _soft_device_placement():
output = train_fn_with_summaries(self, tf.constant(1))
num_steps -= 1
if num_steps >= 1:
with tf.summary.record_if(False):
output = train_fn_without_summaries(self, num_steps)
return output
return wrapper
if args and callable(args[0]):
train_fn, args = args[0], args[1:]
return decorator(train_fn)
return decorator
def get_value(x) -> np.ndarray:
"""Returns the value of a variable/tensor.
Args:
x: The input variable or tensor.
Returns:
A NumPy array (if `x` is not a tensor, it is returned as is).
"""
if not tf.is_tensor(x):
return x
return x.numpy()
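# Tiny usage sketch (illustrative): `get_value` is convenient when a metric may
# be either a plain Python number or a `tf.Tensor`.
def _example_get_value_usage():
  assert get_value(3.5) == 3.5
  assert get_value(tf.constant(2.0)) == 2.0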