Commit 0d968ea2 authored by Rajagopal Ananthanarayanan's avatar Rajagopal Ananthanarayanan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307436543
parent 42919740
......@@ -69,10 +69,12 @@ FLAGS = flags.FLAGS
def run_executor(params,
mode,
checkpoint_path=None,
train_input_fn=None,
eval_input_fn=None,
callbacks=None,
strategy=None):
prebuilt_strategy=None):
"""Runs Retinanet model on distribution strategy defined by the user."""
if params.architecture.use_bfloat16:
......@@ -82,7 +84,9 @@ def run_executor(params,
model_builder = model_factory.model_generator(params)
if strategy is None:
if prebuilt_strategy is not None:
strategy = prebuilt_strategy
else:
strategy_config = params.strategy_config
distribution_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
......@@ -96,7 +100,7 @@ def run_executor(params,
num_workers = int(strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
if FLAGS.mode == 'train':
if mode == 'train':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.TRAIN)
......@@ -128,8 +132,7 @@ def run_executor(params,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks,
save_config=True)
elif FLAGS.mode == 'eval' or FLAGS.mode == 'eval_once':
elif mode == 'eval' or mode == 'eval_once':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
......@@ -152,7 +155,7 @@ def run_executor(params,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
if FLAGS.mode == 'eval':
if mode == 'eval':
results = dist_executor.evaluate_from_model_dir(
model_dir=params.model_dir,
eval_input_fn=eval_input_fn,
......@@ -162,9 +165,8 @@ def run_executor(params,
total_steps=params.train.total_steps)
else:
# Run evaluation once for a single checkpoint.
if not FLAGS.checkpoint_path:
raise ValueError('FLAGS.checkpoint_path cannot be empty.')
checkpoint_path = FLAGS.checkpoint_path
if not checkpoint_path:
raise ValueError('checkpoint_path cannot be empty.')
if tf.io.gfile.isdir(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
......@@ -177,7 +179,7 @@ def run_executor(params,
logging.info('Final eval metric %s: %f', k, v)
return results
else:
raise ValueError('Mode not found: %s.' % FLAGS.mode)
raise ValueError('Mode not found: %s.' % mode)
def run(callbacks=None):
......@@ -239,6 +241,8 @@ def run(callbacks=None):
return run_executor(
params,
FLAGS.mode,
checkpoint_path=FLAGS.checkpoint_path,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
callbacks=callbacks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment