".github/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "3665d65be72ab15b68ba5629c2af20703288d502"
Commit 0d968ea2 authored by Rajagopal Ananthanarayanan's avatar Rajagopal Ananthanarayanan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307436543
parent 42919740
...@@ -69,10 +69,12 @@ FLAGS = flags.FLAGS ...@@ -69,10 +69,12 @@ FLAGS = flags.FLAGS
def run_executor(params, def run_executor(params,
mode,
checkpoint_path=None,
train_input_fn=None, train_input_fn=None,
eval_input_fn=None, eval_input_fn=None,
callbacks=None, callbacks=None,
strategy=None): prebuilt_strategy=None):
"""Runs Retinanet model on distribution strategy defined by the user.""" """Runs Retinanet model on distribution strategy defined by the user."""
if params.architecture.use_bfloat16: if params.architecture.use_bfloat16:
...@@ -82,7 +84,9 @@ def run_executor(params, ...@@ -82,7 +84,9 @@ def run_executor(params,
model_builder = model_factory.model_generator(params) model_builder = model_factory.model_generator(params)
if strategy is None: if prebuilt_strategy is not None:
strategy = prebuilt_strategy
else:
strategy_config = params.strategy_config strategy_config = params.strategy_config
distribution_utils.configure_cluster(strategy_config.worker_hosts, distribution_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index) strategy_config.task_index)
...@@ -96,7 +100,7 @@ def run_executor(params, ...@@ -96,7 +100,7 @@ def run_executor(params,
num_workers = int(strategy.num_replicas_in_sync + 7) // 8 num_workers = int(strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2) is_multi_host = (int(num_workers) >= 2)
if FLAGS.mode == 'train': if mode == 'train':
def _model_fn(params): def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.TRAIN) return model_builder.build_model(params, mode=ModeKeys.TRAIN)
...@@ -128,8 +132,7 @@ def run_executor(params, ...@@ -128,8 +132,7 @@ def run_executor(params,
init_checkpoint=model_builder.make_restore_checkpoint_fn(), init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks, custom_callbacks=callbacks,
save_config=True) save_config=True)
elif mode == 'eval' or mode == 'eval_once':
elif FLAGS.mode == 'eval' or FLAGS.mode == 'eval_once':
def _model_fn(params): def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT) return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
...@@ -152,7 +155,7 @@ def run_executor(params, ...@@ -152,7 +155,7 @@ def run_executor(params,
trainable_variables_filter=model_builder trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn()) .make_filter_trainable_variables_fn())
if FLAGS.mode == 'eval': if mode == 'eval':
results = dist_executor.evaluate_from_model_dir( results = dist_executor.evaluate_from_model_dir(
model_dir=params.model_dir, model_dir=params.model_dir,
eval_input_fn=eval_input_fn, eval_input_fn=eval_input_fn,
...@@ -162,9 +165,8 @@ def run_executor(params, ...@@ -162,9 +165,8 @@ def run_executor(params,
total_steps=params.train.total_steps) total_steps=params.train.total_steps)
else: else:
# Run evaluation once for a single checkpoint. # Run evaluation once for a single checkpoint.
if not FLAGS.checkpoint_path: if not checkpoint_path:
raise ValueError('FLAGS.checkpoint_path cannot be empty.') raise ValueError('checkpoint_path cannot be empty.')
checkpoint_path = FLAGS.checkpoint_path
if tf.io.gfile.isdir(checkpoint_path): if tf.io.gfile.isdir(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path) checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
summary_writer = executor.SummaryWriter(params.model_dir, 'eval') summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
...@@ -177,7 +179,7 @@ def run_executor(params, ...@@ -177,7 +179,7 @@ def run_executor(params,
logging.info('Final eval metric %s: %f', k, v) logging.info('Final eval metric %s: %f', k, v)
return results return results
else: else:
raise ValueError('Mode not found: %s.' % FLAGS.mode) raise ValueError('Mode not found: %s.' % mode)
def run(callbacks=None): def run(callbacks=None):
...@@ -239,6 +241,8 @@ def run(callbacks=None): ...@@ -239,6 +241,8 @@ def run(callbacks=None):
return run_executor( return run_executor(
params, params,
FLAGS.mode,
checkpoint_path=FLAGS.checkpoint_path,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn, eval_input_fn=eval_input_fn,
callbacks=callbacks) callbacks=callbacks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment