Commit 8793267f authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Support evaluating over multiple datasets.

PiperOrigin-RevId: 205168785
parent 75d592e9
...@@ -207,7 +207,7 @@ the second deepest transits). ...@@ -207,7 +207,7 @@ the second deepest transits).
To train a model to identify exoplanets, you will need to provide TensorFlow To train a model to identify exoplanets, you will need to provide TensorFlow
with training data in with training data in
[TFRecord](https://www.tensorflow.org/guide/datasets) format. The [TFRecord](https://www.tensorflow.org/programmers_guide/datasets) format. The
TFRecord format consists of a set of sharded files containing serialized TFRecord format consists of a set of sharded files containing serialized
`tf.Example` [protocol buffers](https://developers.google.com/protocol-buffers/). `tf.Example` [protocol buffers](https://developers.google.com/protocol-buffers/).
...@@ -343,7 +343,7 @@ bazel-bin/astronet/train \ ...@@ -343,7 +343,7 @@ bazel-bin/astronet/train \
--model_dir=${MODEL_DIR} --model_dir=${MODEL_DIR}
``` ```
Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)
server in a separate process for real-time server in a separate process for real-time
monitoring of training progress and evaluation metrics. monitoring of training progress and evaluation metrics.
......
...@@ -112,11 +112,14 @@ def main(_): ...@@ -112,11 +112,14 @@ def main(_):
file_pattern=FLAGS.eval_files, file_pattern=FLAGS.eval_files,
input_config=config.inputs, input_config=config.inputs,
mode=tf.estimator.ModeKeys.EVAL) mode=tf.estimator.ModeKeys.EVAL)
eval_args = {
"val": (eval_input_fn, None) # eval_name: (input_fn, eval_steps)
}
for _ in estimator_util.continuous_train_and_eval( for _ in estimator_util.continuous_train_and_eval(
estimator=estimator, estimator=estimator,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn, eval_args=eval_args,
train_steps=FLAGS.train_steps): train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each # continuous_train_and_eval() yields evaluation metrics after each
# training epoch. We don't do anything here. # training epoch. We don't do anything here.
......
...@@ -204,94 +204,106 @@ def create_estimator(model_class, ...@@ -204,94 +204,106 @@ def create_estimator(model_class,
return estimator return estimator
def evaluate(estimator, input_fn, eval_steps=None, eval_name="val"): def evaluate(estimator, eval_args):
"""Runs evaluation on the latest model checkpoint. """Runs evaluation on the latest model checkpoint.
Args: Args:
estimator: Instance of tf.Estimator. estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels). eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_steps: The number of steps for which to evaluate the model. If None, eval_name is the name of the evaluation set (e.g. "train" or "val"),
evaluates until input_fn raises an end-of-input exception. input_fn is an input function returning a tuple (features, labels), and
eval_name: Name of the evaluation set, e.g. "train" or "val". eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
Returns: Returns:
A dict of metric values from the evaluation. May be empty, e.g. if the global_step: The global step of the checkpoint evaluated.
training job has not yet saved a checkpoint or the checkpoint is deleted by values: A dict of metric values from the evaluation. May be empty, e.g. if
the time the TPU worker initializes. the training job has not yet saved a checkpoint or the checkpoint is
deleted by the time the TPU worker initializes.
""" """
values = {} # Default return value if evaluation fails. # Default return values if evaluation fails.
global_step = None
values = {}
latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir) latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
if not latest_checkpoint: if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint. # This is expected if the training job has not yet saved a checkpoint.
return values return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint) tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try: try:
values = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name) for eval_name, (input_fn, eval_steps) in eval_args.items():
values[eval_name] = estimator.evaluate(
input_fn, steps=eval_steps, name=eval_name)
if global_step is None:
global_step = values[eval_name].get("global_step")
except tf.errors.NotFoundError: except tf.errors.NotFoundError:
# Expected under some conditions, e.g. TPU worker does not finish # Expected under some conditions, e.g. checkpoint is already deleted by the
# initializing until long after the CPU job tells it to start evaluating # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# and the checkpoint file is deleted already. # in some cases.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation", tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
latest_checkpoint) latest_checkpoint)
return values return global_step, values
def continuous_eval(estimator, def continuous_eval(estimator,
input_fn, eval_args,
train_steps=None, train_steps=None,
eval_steps=None, timeout_secs=None,
eval_name="val"): timeout_fn=None):
"""Runs evaluation whenever there's a new checkpoint. """Runs evaluation whenever there's a new checkpoint.
Args: Args:
estimator: Instance of tf.Estimator. estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels). eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
train_steps: The number of steps the model will train for. This function train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training. If None, this will terminate once the model has finished training.
function will run forever. timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
eval_steps: The number of steps for which to evaluate the model. If None, indefinitely.
evaluates until input_fn raises an end-of-input exception. timeout_fn: Optional function to call after timeout. The iterator will exit
eval_name: Name of the evaluation set, e.g. "train" or "val". if and only if the function returns True.
Yields: Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes. the time the TPU worker initializes.
""" """
for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir): for _ in tf.contrib.training.checkpoints_iterator(
values = evaluate(estimator, input_fn, eval_steps, eval_name) estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
yield values global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = values.get("global_step", 0) global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps: if train_steps and global_step >= train_steps:
break break
def continuous_train_and_eval(estimator, def continuous_train_and_eval(estimator,
train_input_fn, train_input_fn,
eval_input_fn, eval_args,
local_eval_frequency=None, local_eval_frequency=None,
train_hooks=None, train_hooks=None,
train_steps=None, train_steps=None):
eval_steps=None,
eval_name="val"):
"""Alternates training and evaluation. """Alternates training and evaluation.
Args: Args:
estimator: Instance of tf.Estimator. estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels). train_input_fn: Input function returning a tuple (features, labels).
eval_input_fn: Input function returning a tuple (features, labels). eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
local_eval_frequency: The number of training steps between evaluations. If local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception. None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call. inside the training call.
train_steps: The total number of steps to train the model for. train_steps: The total number of steps to train the model for.
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until eval_input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Yields: Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the A dict of metric values from each evaluation. May be empty, e.g. if the
...@@ -301,10 +313,10 @@ def continuous_train_and_eval(estimator, ...@@ -301,10 +313,10 @@ def continuous_train_and_eval(estimator,
while True: while True:
# We run evaluation before training in this loop to prevent evaluation from # We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted. # being skipped if the process is interrupted.
values = evaluate(estimator, eval_input_fn, eval_steps, eval_name) global_step, values = evaluate(estimator, eval_args)
yield values yield global_step, values
global_step = values.get("global_step", 0) global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps: if train_steps and global_step >= train_steps:
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment