Commit a60dd985 authored by Zhichao Lu, committed by pkulzc

Refactor model_tpu_main.py and move the continuous eval loop into model_lib.py

PiperOrigin-RevId: 192512429
parent f98f000a
model_lib.py

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import functools
+import os

 import tensorflow as tf
@@ -574,6 +575,48 @@ def create_train_and_eval_specs(train_input_fn,
   return train_spec, eval_specs


+def continuous_eval(estimator, model_dir, input_fn, eval_steps, train_steps,
+                    name):
+  """Perform continuous evaluation on checkpoints written to a model directory.
+
+  Args:
+    estimator: Estimator object to use for evaluation.
+    model_dir: Model directory to read checkpoints for continuous evaluation.
+    input_fn: Input function to use for evaluation.
+    eval_steps: Number of steps to run during each evaluation.
+    train_steps: Number of training steps. This is used to infer the last
+      checkpoint and stop evaluation loop.
+    name: Namescope for eval summary.
+  """
+
+  def terminate_eval():
+    tf.logging.info('Terminating eval after 180 seconds of no checkpoints')
+    return True
+
+  for ckpt in tf.contrib.training.checkpoints_iterator(
+      model_dir, min_interval_secs=180, timeout=None,
+      timeout_fn=terminate_eval):
+
+    tf.logging.info('Starting Evaluation.')
+    try:
+      eval_results = estimator.evaluate(
+          input_fn=input_fn,
+          steps=eval_steps,
+          checkpoint_path=ckpt,
+          name=name)
+      tf.logging.info('Eval results: %s' % eval_results)
+
+      # Terminate eval job when final checkpoint is reached.
+      current_step = int(os.path.basename(ckpt).split('-')[1])
+      if current_step >= train_steps:
+        tf.logging.info(
+            'Evaluation finished after training step %d' % current_step)
+        break
+
+    except tf.errors.NotFoundError:
+      tf.logging.info(
+          'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
+
+
 def populate_experiment(run_config,
                         hparams,
                         pipeline_config_path,
...
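For orientation, here is a minimal sketch of how an eval-only job might call the new helper. Only model_lib.continuous_eval itself comes from this commit; the estimator, input function, paths, and step counts below are illustrative placeholders:

from object_detection import model_lib

# 'estimator' and 'eval_input_fn' are assumed to have been built elsewhere;
# the literal values are examples, not defaults from this commit.
model_lib.continuous_eval(
    estimator=estimator,
    model_dir='/tmp/model_dir',  # directory the training job writes checkpoints to
    input_fn=eval_input_fn,
    eval_steps=500,
    train_steps=100000,          # loop exits once a checkpoint at/after this step appears
    name='validation_data')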
model_tpu_main.py

@@ -22,12 +22,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 from absl import flags
 import tensorflow as tf

 from tensorflow.contrib.tpu.python.tpu import tpu_config
-from tensorflow.contrib.training.python.training import evaluation

 from object_detection import model_hparams
 from object_detection import model_lib
@@ -48,14 +46,7 @@ flags.DEFINE_string(
 flags.DEFINE_string(
     'tpu_name',
     default=None,
-    help='Name of the Cloud TPU for Cluster Resolvers. You must specify either '
-    'this flag or --master.')
-
-flags.DEFINE_string(
-    'master',
-    default=None,
-    help='GRPC URL of the master (e.g. grpc://ip.address.of.tpu:8470). You '
-    'must specify either this flag or --tpu_name.')
+    help='Name of the Cloud TPU for Cluster Resolvers.')

 flags.DEFINE_integer('num_shards', 8, 'Number of shards (TPU cores).')
 flags.DEFINE_integer('iterations_per_loop', 100,
@@ -63,16 +54,10 @@ flags.DEFINE_integer('iterations_per_loop', 100,
 # For mode=train_and_eval, evaluation occurs after training is finished.
 # Note: independently of steps_per_checkpoint, estimator will save the most
 # recent checkpoint every 10 minutes by default for train_and_eval
-flags.DEFINE_string('mode', 'train_and_eval',
-                    'Mode to run: train, eval, train_and_eval')
+flags.DEFINE_string('mode', 'train',
+                    'Mode to run: train, eval')
 flags.DEFINE_integer('train_batch_size', 32 * 8, 'Batch size for training.')

-# For EVAL.
-flags.DEFINE_integer('min_eval_interval_secs', 180,
-                     'Minimum seconds between evaluations.')
-flags.DEFINE_integer(
-    'eval_timeout_secs', None,
-    'Maximum seconds between checkpoints before evaluation terminates.')
 flags.DEFINE_string(
     'hparams_overrides', None, 'Comma-separated list of '
     'hyperparameters to override defaults.')
@@ -93,21 +78,12 @@ def main(unused_argv):
   flags.mark_flag_as_required('model_dir')
   flags.mark_flag_as_required('pipeline_config_path')

-  if FLAGS.master is None and FLAGS.tpu_name is None:
-    raise RuntimeError('You must specify either --master or --tpu_name.')
-
-  if FLAGS.master is not None:
-    if FLAGS.tpu_name is not None:
-      tf.logging.warn('Both --master and --tpu_name are set. Ignoring '
-                      '--tpu_name and using --master.')
-    tpu_grpc_url = FLAGS.master
-  else:
-    tpu_cluster_resolver = (
-        tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
-            tpu_names=[FLAGS.tpu_name],
-            zone=FLAGS.tpu_zone,
-            project=FLAGS.gcp_project))
-    tpu_grpc_url = tpu_cluster_resolver.get_master()
+  tpu_cluster_resolver = (
+      tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
+          tpu_names=[FLAGS.tpu_name],
+          zone=FLAGS.tpu_zone,
+          project=FLAGS.gcp_project))
+  tpu_grpc_url = tpu_cluster_resolver.get_master()

   config = tpu_config.RunConfig(
       master=tpu_grpc_url,
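With --master removed, the script always derives the TPU endpoint from --tpu_name via the cluster resolver. A hedged illustration of what that resolution does, using the same API the diff uses; the TPU name, zone, and project values are made-up examples:

import tensorflow as tf

# Resolve a Cloud TPU's gRPC master address from its name, as main() now
# does unconditionally. Example values only.
resolver = tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
    tpu_names=['my-tpu'],
    zone='us-central1-b',
    project='my-gcp-project')
tpu_grpc_url = resolver.get_master()  # e.g. 'grpc://10.240.1.2:8470'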
@@ -134,53 +110,19 @@ def main(unused_argv):
   train_steps = train_and_eval_dict['train_steps']
   eval_steps = train_and_eval_dict['eval_steps']

-  if FLAGS.mode in ['train', 'train_and_eval']:
+  if FLAGS.mode == 'train':
     estimator.train(input_fn=train_input_fn, max_steps=train_steps)
-  if FLAGS.mode == 'train_and_eval':
-    # Eval one time.
-    eval_results = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
-    tf.logging.info('Eval results: %s' % eval_results)

   # Continuously evaluating.
   if FLAGS.mode == 'eval':
-
-    def terminate_eval():
-      tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
-                      FLAGS.eval_timeout_secs)
-      return True
-
-    # Run evaluation when there's a new checkpoint.
-    for ckpt in evaluation.checkpoints_iterator(
-        FLAGS.model_dir,
-        min_interval_secs=FLAGS.min_eval_interval_secs,
-        timeout=FLAGS.eval_timeout_secs,
-        timeout_fn=terminate_eval):
-      tf.logging.info('Starting to evaluate.')
-      if FLAGS.eval_training_data:
-        name = 'training_data'
-        input_fn = eval_on_train_input_fn
-      else:
-        name = 'validation_data'
-        input_fn = eval_input_fn
-      try:
-        eval_results = estimator.evaluate(
-            input_fn=input_fn,
-            steps=eval_steps,
-            checkpoint_path=ckpt,
-            name=name)
-        tf.logging.info('Eval results: %s' % eval_results)
-
-        # Terminate eval job when final checkpoint is reached
-        current_step = int(os.path.basename(ckpt).split('-')[1])
-        if current_step >= train_steps:
-          tf.logging.info(
-              'Evaluation finished after training step %d' % current_step)
-          break
-      except tf.errors.NotFoundError:
-        tf.logging.info(
-            'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
+    if FLAGS.eval_training_data:
+      name = 'training_data'
+      input_fn = eval_on_train_input_fn
+    else:
+      name = 'validation_data'
+      input_fn = eval_input_fn
+    model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, eval_steps,
+                              train_steps, name)
 if __name__ == '__main__':
...
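A note on the stopping condition shared by both versions of the loop: tf.estimator checkpoints are named <prefix>-<global_step>, so the final checkpoint can be detected by parsing the step out of the path. A self-contained illustration (the path is an example):

import os

ckpt = '/tmp/model_dir/model.ckpt-100000'  # example checkpoint path
current_step = int(os.path.basename(ckpt).split('-')[1])
assert current_step == 100000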