Commit bfa9364a authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Factor out evaluate(), continuous_eval(), and continuous_train_and_eval() from...

Factor out evaluate(), continuous_eval(), and continuous_train_and_eval() from estimator_util.py into estimator_runner.py.

PiperOrigin-RevId: 212903406
parent 44f05013
......@@ -25,6 +25,7 @@ py_binary(
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
],
)
......@@ -37,6 +38,7 @@ py_binary(
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
],
)
......
......@@ -26,6 +26,7 @@ import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
......@@ -86,7 +87,9 @@ def main(_):
# Run evaluation. This will log the result to stderr and also write a summary
# file in the model_dir.
estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
eval_steps = None # Evaluate over all examples in the file.
eval_args = {FLAGS.eval_name: (input_fn, eval_steps)}
estimator_runner.evaluate(estimator, eval_args)
if __name__ == "__main__":
......
......@@ -26,6 +26,7 @@ import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
......@@ -116,7 +117,7 @@ def main(_):
"val": (eval_input_fn, None) # eval_name: (input_fn, eval_steps)
}
for _ in estimator_util.continuous_train_and_eval(
for _ in estimator_runner.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_args=eval_args,
......
......@@ -32,6 +32,12 @@ py_test(
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
......
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training and evaluation using a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def evaluate(estimator, eval_args):
"""Runs evaluation on the latest model checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
Returns:
global_step: The global step of the checkpoint evaluated.
values: A dict of metric values from the evaluation. May be empty, e.g. if
the training job has not yet saved a checkpoint or the checkpoint is
deleted by the time the TPU worker initializes.
"""
# Default return values if evaluation fails.
global_step = None
values = {}
latest_checkpoint = estimator.latest_checkpoint()
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try:
for eval_name, (input_fn, eval_steps) in eval_args.items():
values[eval_name] = estimator.evaluate(
input_fn, steps=eval_steps, name=eval_name)
if global_step is None:
global_step = values[eval_name].get("global_step")
except (tf.errors.NotFoundError, ValueError):
# Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation.",
latest_checkpoint)
return global_step, values
def continuous_eval(estimator,
eval_args,
train_steps=None,
timeout_secs=None,
timeout_fn=None):
"""Runs evaluation whenever there's a new checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training.
timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
indefinitely.
timeout_fn: Optional function to call after timeout. The iterator will exit
if and only if the function returns True.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
for _ in tf.contrib.training.checkpoints_iterator(
estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
def continuous_train_and_eval(estimator,
train_input_fn,
eval_args,
local_eval_frequency=None,
train_hooks=None,
train_steps=None):
"""Alternates training and evaluation.
Args:
estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels).
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call.
train_steps: The total number of steps to train the model for.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
while True:
# We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted.
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
# Decide how many steps before the next evaluation.
steps = local_eval_frequency
if train_steps:
remaining_steps = train_steps - global_step
steps = min(steps, remaining_steps) if steps else remaining_steps
tf.logging.info("Starting training at global step %d", global_step)
estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training models with the TensorFlow Estimator API."""
"""Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
......@@ -202,129 +202,3 @@ def create_estimator(model_class,
params={"batch_size": hparams.batch_size})
return estimator
def evaluate(estimator, eval_args):
"""Runs evaluation on the latest model checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
Returns:
global_step: The global step of the checkpoint evaluated.
values: A dict of metric values from the evaluation. May be empty, e.g. if
the training job has not yet saved a checkpoint or the checkpoint is
deleted by the time the TPU worker initializes.
"""
# Default return values if evaluation fails.
global_step = None
values = {}
latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try:
for eval_name, (input_fn, eval_steps) in eval_args.items():
values[eval_name] = estimator.evaluate(
input_fn, steps=eval_steps, name=eval_name)
if global_step is None:
global_step = values[eval_name].get("global_step")
except tf.errors.NotFoundError:
# Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
latest_checkpoint)
return global_step, values
def continuous_eval(estimator,
eval_args,
train_steps=None,
timeout_secs=None,
timeout_fn=None):
"""Runs evaluation whenever there's a new checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training.
timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
indefinitely.
timeout_fn: Optional function to call after timeout. The iterator will exit
if and only if the function returns True.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
for _ in tf.contrib.training.checkpoints_iterator(
estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
def continuous_train_and_eval(estimator,
train_input_fn,
eval_args,
local_eval_frequency=None,
train_hooks=None,
train_steps=None):
"""Alternates training and evaluation.
Args:
estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels).
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
eval_name is the name of the evaluation set (e.g. "train" or "val"),
input_fn is an input function returning a tuple (features, labels), and
eval_steps is the number of steps for which to evaluate the model (if
None, evaluates until input_fn raises an end-of-input exception).
local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call.
train_steps: The total number of steps to train the model for.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
while True:
# We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted.
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
# Decide how many steps before the next evaluation.
steps = local_eval_frequency
if train_steps:
remaining_steps = train_steps - global_step
steps = min(steps, remaining_steps) if steps else remaining_steps
tf.logging.info("Starting training at global step %d", global_step)
estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment