Commit 8793267f authored by Chris Shallue, committed by Christopher Shallue

Support evaluating over multiple datasets.

PiperOrigin-RevId: 205168785
parent 75d592e9
@@ -207,7 +207,7 @@ the second deepest transits).
To train a model to identify exoplanets, you will need to provide TensorFlow
with training data in
[TFRecord](https://www.tensorflow.org/guide/datasets) format. The
[TFRecord](https://www.tensorflow.org/programmers_guide/datasets) format. The
TFRecord format consists of a set of sharded files containing serialized
`tf.Example` [protocol buffers](https://developers.google.com/protocol-buffers/).
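For orientation, here is a minimal, hypothetical sketch of writing one serialized `tf.Example` to a TFRecord shard using the TF 1.x API. The feature names and values are invented for illustration only; the repository's own data-generation pipeline defines the real schema.

```
import tensorflow as tf

# Invented feature names and values, purely for illustration.
example = tf.train.Example(features=tf.train.Features(feature={
    "light_curve": tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.1, -0.2, 0.05])),
    "label": tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b"PC"])),
}))

# Each shard is simply a TFRecord file of serialized tf.Example protos.
with tf.python_io.TFRecordWriter("/tmp/train-00000-of-00008") as writer:
  writer.write(example.SerializeToString())
```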
@@ -343,7 +343,7 @@ bazel-bin/astronet/train \
--model_dir=${MODEL_DIR}
```
Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)
server in a separate process for real-time
monitoring of training progress and evaluation metrics.
@@ -112,11 +112,14 @@ def main(_):
file_pattern=FLAGS.eval_files,
input_config=config.inputs,
mode=tf.estimator.ModeKeys.EVAL)
eval_args = {
"val": (eval_input_fn, None) # eval_name: (input_fn, eval_steps)
}
for _ in estimator_util.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
eval_args=eval_args,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each
# training epoch. We don't do anything here.
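Since the point of this change is multi-dataset evaluation, `eval_args` can hold several named datasets at once. A hypothetical extension of the snippet above, assuming a `FLAGS.test_files` flag and a `create_eval_input_fn` helper that stands in for whatever factory built `eval_input_fn`, neither of which is part of this commit:

```
# Hypothetical: evaluate on both the validation set and a held-out test set.
test_input_fn = create_eval_input_fn(file_pattern=FLAGS.test_files)

eval_args = {
    "val": (eval_input_fn, None),   # None: run until end of input.
    "test": (test_input_fn, None),
}
```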
@@ -204,94 +204,106 @@ def create_estimator(model_class,
return estimator
def evaluate(estimator, input_fn, eval_steps=None, eval_name="val"):
def evaluate(estimator, eval_args):
"""Runs evaluation on the latest model checkpoint.
Args:
estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels).
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
Returns:
A dict of metric values from the evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
global_step: The global step of the checkpoint evaluated.
values: A dict of metric values from the evaluation. May be empty, e.g. if
the training job has not yet saved a checkpoint or the checkpoint is
deleted by the time the TPU worker initializes.
"""
values = {} # Default return value if evaluation fails.
# Default return values if evaluation fails.
global_step = None
values = {}
latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
return values
return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try:
values = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name)
for eval_name, (input_fn, eval_steps) in eval_args.items():
values[eval_name] = estimator.evaluate(
input_fn, steps=eval_steps, name=eval_name)
if global_step is None:
global_step = values[eval_name].get("global_step")
except tf.errors.NotFoundError:
# Expected under some conditions, e.g. TPU worker does not finish
# initializing until long after the CPU job tells it to start evaluating
# and the checkpoint file is deleted already.
# Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
latest_checkpoint)
return values
return global_step, values
def continuous_eval(estimator,
input_fn,
eval_args,
train_steps=None,
eval_steps=None,
eval_name="val"):
timeout_secs=None,
timeout_fn=None):
"""Runs evaluation whenever there's a new checkpoint.
Args:
estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels).
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training. If None, this
function will run forever.
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
will terminate once the model has finished training.
timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
indefinitely.
timeout_fn: Optional function to call after timeout. The iterator will exit
if and only if the function returns True.
Yields:
A (global_step, values) tuple for each evaluation, where global_step is the
global step of the checkpoint evaluated and values is a dict of metric values.
The dict may be empty, e.g. if the training job has not yet saved a checkpoint
or the checkpoint is deleted by the time the TPU worker initializes.
"""
for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
values = evaluate(estimator, input_fn, eval_steps, eval_name)
yield values
for _ in tf.contrib.training.checkpoints_iterator(
estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = values.get("global_step", 0)
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
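A hypothetical caller of `continuous_eval` using the new timeout arguments; the sentinel-file convention and the `FLAGS` names below are assumptions for illustration, not part of this change.

```
import os.path
import tensorflow as tf

def _training_finished():
  # Assumed convention: the trainer writes a sentinel file when it is done,
  # so the evaluator can stop polling for new checkpoints.
  return tf.gfile.Exists(os.path.join(FLAGS.model_dir, "TRAINING_COMPLETE"))

for global_step, values in estimator_util.continuous_eval(
    estimator=estimator,
    eval_args={"val": (eval_input_fn, None)},
    train_steps=FLAGS.train_steps,
    timeout_secs=600,                # Wait up to 10 minutes for a checkpoint,
    timeout_fn=_training_finished):  # then exit only if training has finished.
  tf.logging.info("Evaluated checkpoint at step %s: %s", global_step, values)
```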
def continuous_train_and_eval(estimator,
train_input_fn,
eval_input_fn,
eval_args,
local_eval_frequency=None,
train_hooks=None,
train_steps=None,
eval_steps=None,
eval_name="val"):
train_steps=None):
"""Alternates training and evaluation.
Args:
estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels).
eval_input_fn: Input function returning a tuple (features, labels).
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call.
train_steps: The total number of steps to train the model for.
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until eval_input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
@@ -301,10 +313,10 @@ def continuous_train_and_eval(estimator,
while True:
# We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted.
values = evaluate(estimator, eval_input_fn, eval_steps, eval_name)
yield values
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = values.get("global_step", 0)
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break