Commit bfa9364a authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Factor out evaluate(), continuous_eval(), and continuous_train_and_eval() from...

Factor out evaluate(), continuous_eval(), and continuous_train_and_eval() from estimator_util.py into estimator_runner.py.

PiperOrigin-RevId: 212903406
parent 44f05013
...@@ -25,6 +25,7 @@ py_binary( ...@@ -25,6 +25,7 @@ py_binary(
":models", ":models",
"//astronet/util:config_util", "//astronet/util:config_util",
"//astronet/util:configdict", "//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util", "//astronet/util:estimator_util",
], ],
) )
...@@ -37,6 +38,7 @@ py_binary( ...@@ -37,6 +38,7 @@ py_binary(
":models", ":models",
"//astronet/util:config_util", "//astronet/util:config_util",
"//astronet/util:configdict", "//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util", "//astronet/util:estimator_util",
], ],
) )
......
...@@ -26,6 +26,7 @@ import tensorflow as tf ...@@ -26,6 +26,7 @@ import tensorflow as tf
from astronet import models from astronet import models
from astronet.util import config_util from astronet.util import config_util
from astronet.util import configdict from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util from astronet.util import estimator_util
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -86,7 +87,9 @@ def main(_): ...@@ -86,7 +87,9 @@ def main(_):
# Run evaluation. This will log the result to stderr and also write a summary # Run evaluation. This will log the result to stderr and also write a summary
# file in the model_dir. # file in the model_dir.
estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name) eval_steps = None # Evaluate over all examples in the file.
eval_args = {FLAGS.eval_name: (input_fn, eval_steps)}
estimator_runner.evaluate(estimator, eval_args)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -26,6 +26,7 @@ import tensorflow as tf ...@@ -26,6 +26,7 @@ import tensorflow as tf
from astronet import models from astronet import models
from astronet.util import config_util from astronet.util import config_util
from astronet.util import configdict from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util from astronet.util import estimator_util
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -116,7 +117,7 @@ def main(_): ...@@ -116,7 +117,7 @@ def main(_):
"val": (eval_input_fn, None) # eval_name: (input_fn, eval_steps) "val": (eval_input_fn, None) # eval_name: (input_fn, eval_steps)
} }
for _ in estimator_util.continuous_train_and_eval( for _ in estimator_runner.continuous_train_and_eval(
estimator=estimator, estimator=estimator,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
eval_args=eval_args, eval_args=eval_args,
......
...@@ -32,6 +32,12 @@ py_test( ...@@ -32,6 +32,12 @@ py_test(
deps = [":config_util"], deps = [":config_util"],
) )
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library( py_library(
name = "estimator_util", name = "estimator_util",
srcs = ["estimator_util.py"], srcs = ["estimator_util.py"],
......
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training and evaluation using a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def evaluate(estimator, eval_args):
  """Evaluates the latest checkpoint of a model on one or more datasets.

  Args:
    estimator: Instance of tf.Estimator.
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
      is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
      input function returning a tuple (features, labels), and eval_steps is the
      number of steps for which to evaluate the model (if None, evaluates until
      input_fn raises an end-of-input exception).

  Returns:
    global_step: The global step of the checkpoint evaluated, or None if no
      checkpoint was evaluated.
    values: A dict of {eval_name: metrics_dict} from the evaluation. May be
      empty, e.g. if the training job has not yet saved a checkpoint or the
      checkpoint is deleted by the time the TPU worker initializes.
  """
  step = None  # Global step of the evaluated checkpoint, once known.
  metrics = {}
  checkpoint_path = estimator.latest_checkpoint()
  if not checkpoint_path:
    # Expected if the training job has not yet saved a checkpoint; return the
    # empty defaults rather than raising.
    return step, metrics

  tf.logging.info("Starting evaluation on checkpoint %s", checkpoint_path)
  try:
    for name, (input_fn, num_steps) in eval_args.items():
      result = estimator.evaluate(input_fn, steps=num_steps, name=name)
      metrics[name] = result
      if step is None:
        # All evaluations in this loop use the same checkpoint, so the first
        # reported global_step applies to every entry.
        step = result.get("global_step")
  except (tf.errors.NotFoundError, ValueError):
    # Expected under some conditions, e.g. checkpoint is already deleted by the
    # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent
    # this in some cases.
    tf.logging.info("Checkpoint %s no longer exists, skipping evaluation.",
                    checkpoint_path)
  return step, metrics
def continuous_eval(estimator,
                    eval_args,
                    train_steps=None,
                    timeout_secs=None,
                    timeout_fn=None):
  """Runs evaluation whenever there's a new checkpoint.

  Args:
    estimator: Instance of tf.Estimator.
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
      is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
      input function returning a tuple (features, labels), and eval_steps is the
      number of steps for which to evaluate the model (if None, evaluates until
      input_fn raises an end-of-input exception).
    train_steps: The number of steps the model will train for. This function
      will terminate once the model has finished training.
    timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
      indefinitely.
    timeout_fn: Optional function to call after timeout. The iterator will exit
      if and only if the function returns True.

  Yields:
    A (global_step, values) tuple for each evaluation, where global_step is the
    global step of the checkpoint evaluated (None if no checkpoint was
    evaluated) and values is a dict of metric values. values may be empty, e.g.
    if the training job has not yet saved a checkpoint or the checkpoint is
    deleted by the time the TPU worker initializes.
  """
  for _ in tf.contrib.training.checkpoints_iterator(
      estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
    global_step, values = evaluate(estimator, eval_args)
    yield global_step, values

    # evaluate() returns global_step=None when no checkpoint was evaluated;
    # treat that as step 0 so the comparison below is well defined.
    global_step = global_step or 0  # Ensure global_step is not None.
    if train_steps and global_step >= train_steps:
      break
def continuous_train_and_eval(estimator,
                              train_input_fn,
                              eval_args,
                              local_eval_frequency=None,
                              train_hooks=None,
                              train_steps=None):
  """Alternates training and evaluation.

  Args:
    estimator: Instance of tf.Estimator.
    train_input_fn: Input function returning a tuple (features, labels).
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
      is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
      input function returning a tuple (features, labels), and eval_steps is the
      number of steps for which to evaluate the model (if None, evaluates until
      input_fn raises an end-of-input exception).
    local_eval_frequency: The number of training steps between evaluations. If
      None, trains until train_input_fn raises an end-of-input exception.
    train_hooks: List of SessionRunHook subclass instances. Used for callbacks
      inside the training call.
    train_steps: The total number of steps to train the model for.

  Yields:
    A (global_step, values) tuple for each evaluation, where global_step is the
    global step of the checkpoint evaluated (None if no checkpoint was
    evaluated) and values is a dict of metric values. values may be empty, e.g.
    if the training job has not yet saved a checkpoint or the checkpoint is
    deleted by the time the TPU worker initializes.
  """
  while True:
    # We run evaluation before training in this loop to prevent evaluation from
    # being skipped if the process is interrupted.
    global_step, values = evaluate(estimator, eval_args)
    yield global_step, values

    global_step = global_step or 0  # Ensure global_step is not None.
    if train_steps and global_step >= train_steps:
      break

    # Decide how many steps before the next evaluation.
    steps = local_eval_frequency
    if train_steps:
      remaining_steps = train_steps - global_step
      # Never train past train_steps; if steps is None (no eval frequency),
      # train for exactly the remaining number of steps.
      steps = min(steps, remaining_steps) if steps else remaining_steps

    tf.logging.info("Starting training at global step %d", global_step)
    estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Functions for training models with the TensorFlow Estimator API.""" """Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -202,129 +202,3 @@ def create_estimator(model_class, ...@@ -202,129 +202,3 @@ def create_estimator(model_class,
params={"batch_size": hparams.batch_size}) params={"batch_size": hparams.batch_size})
return estimator return estimator
def evaluate(estimator, eval_args):
  """Runs evaluation on the latest model checkpoint.

  Args:
    estimator: Instance of tf.Estimator.
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
      eval_name is the name of the evaluation set (e.g. "train" or "val"),
      input_fn is an input function returning a tuple (features, labels), and
      eval_steps is the number of steps for which to evaluate the model (if
      None, evaluates until input_fn raises an end-of-input exception).

  Returns:
    global_step: The global step of the checkpoint evaluated, or None if no
      checkpoint was evaluated.
    values: A dict of {eval_name: metrics_dict} from the evaluation. May be
      empty, e.g. if the training job has not yet saved a checkpoint or the
      checkpoint is deleted by the time the TPU worker initializes.
  """
  # Default return values if evaluation fails.
  global_step = None
  values = {}

  latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
  if not latest_checkpoint:
    # This is expected if the training job has not yet saved a checkpoint.
    return global_step, values

  tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
  try:
    for eval_name, (input_fn, eval_steps) in eval_args.items():
      values[eval_name] = estimator.evaluate(
          input_fn, steps=eval_steps, name=eval_name)
      if global_step is None:
        # Every evaluation in this loop uses the same checkpoint, so the
        # first reported global_step applies to all entries.
        global_step = values[eval_name].get("global_step")
  except tf.errors.NotFoundError:
    # Expected under some conditions, e.g. checkpoint is already deleted by the
    # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
    # in some cases.
    tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
                    latest_checkpoint)
  return global_step, values
def continuous_eval(estimator,
                    eval_args,
                    train_steps=None,
                    timeout_secs=None,
                    timeout_fn=None):
  """Runs evaluation whenever there's a new checkpoint.

  Args:
    estimator: Instance of tf.Estimator.
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
      eval_name is the name of the evaluation set (e.g. "train" or "val"),
      input_fn is an input function returning a tuple (features, labels), and
      eval_steps is the number of steps for which to evaluate the model (if
      None, evaluates until input_fn raises an end-of-input exception).
    train_steps: The number of steps the model will train for. This function
      will terminate once the model has finished training.
    timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
      indefinitely.
    timeout_fn: Optional function to call after timeout. The iterator will exit
      if and only if the function returns True.

  Yields:
    A (global_step, values) tuple for each evaluation, where global_step is the
    global step of the checkpoint evaluated (None if no checkpoint was
    evaluated) and values is a dict of metric values. values may be empty, e.g.
    if the training job has not yet saved a checkpoint or the checkpoint is
    deleted by the time the TPU worker initializes.
  """
  for _ in tf.contrib.training.checkpoints_iterator(
      estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
    global_step, values = evaluate(estimator, eval_args)
    yield global_step, values

    # evaluate() returns global_step=None when no checkpoint was evaluated;
    # treat that as step 0 so the comparison below is well defined.
    global_step = global_step or 0  # Ensure global_step is not None.
    if train_steps and global_step >= train_steps:
      break
def continuous_train_and_eval(estimator,
                              train_input_fn,
                              eval_args,
                              local_eval_frequency=None,
                              train_hooks=None,
                              train_steps=None):
  """Alternates training and evaluation.

  Args:
    estimator: Instance of tf.Estimator.
    train_input_fn: Input function returning a tuple (features, labels).
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
      eval_name is the name of the evaluation set (e.g. "train" or "val"),
      input_fn is an input function returning a tuple (features, labels), and
      eval_steps is the number of steps for which to evaluate the model (if
      None, evaluates until input_fn raises an end-of-input exception).
    local_eval_frequency: The number of training steps between evaluations. If
      None, trains until train_input_fn raises an end-of-input exception.
    train_hooks: List of SessionRunHook subclass instances. Used for callbacks
      inside the training call.
    train_steps: The total number of steps to train the model for.

  Yields:
    A (global_step, values) tuple for each evaluation, where global_step is the
    global step of the checkpoint evaluated (None if no checkpoint was
    evaluated) and values is a dict of metric values. values may be empty, e.g.
    if the training job has not yet saved a checkpoint or the checkpoint is
    deleted by the time the TPU worker initializes.
  """
  while True:
    # We run evaluation before training in this loop to prevent evaluation from
    # being skipped if the process is interrupted.
    global_step, values = evaluate(estimator, eval_args)
    yield global_step, values

    global_step = global_step or 0  # Ensure global_step is not None.
    if train_steps and global_step >= train_steps:
      break

    # Decide how many steps before the next evaluation.
    steps = local_eval_frequency
    if train_steps:
      remaining_steps = train_steps - global_step
      # Never train past train_steps; if steps is None (no eval frequency),
      # train for exactly the remaining number of steps.
      steps = min(steps, remaining_steps) if steps else remaining_steps

    tf.logging.info("Starting training at global step %d", global_step)
    estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment