Commit 9ca59f8a authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 275867562
parent befbe0f9
......@@ -253,8 +253,7 @@ class DistributedExecutor(object):
logging.warning('model_dir is empty, so skip the save config.')
def _get_input_iterator(
self, input_fn: Callable[[Optional[params_dict.ParamsDict]],
tf.data.Dataset],
self, input_fn: Callable[..., tf.data.Dataset],
strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
"""Returns distributed dataset iterator.
......@@ -275,7 +274,7 @@ class DistributedExecutor(object):
return iter(
strategy.experimental_distribute_datasets_from_function(input_fn))
else:
input_data = input_fn(self._params)
input_data = input_fn()
return iter(strategy.experimental_distribute_dataset(input_data))
def _create_replicated_step(self,
......
......@@ -58,16 +58,12 @@ class InputFn(object):
self._parser_fn = factory.parser_generator(params, mode)
self._dataset_fn = tf.data.TFRecordDataset
def __call__(self,
params: params_dict.ParamsDict = None,
batch_size=None,
ctx=None):
def __call__(self, ctx=None, batch_size: int = None):
"""Provides tf.data.Dataset object.
Args:
params: placeholder for model parameters.
batch_size: expected batch size input data.
ctx: context object.
batch_size: expected batch size input data.
Returns:
tf.data.Dataset object.
......@@ -96,6 +92,6 @@ class InputFn(object):
# Parses the fetched records to input tensors for model function.
dataset = dataset.map(self._parser_fn, num_parallel_calls=64)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
......@@ -51,6 +51,9 @@ flags.DEFINE_string('training_file_pattern', None,
flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data')
flags.DEFINE_string(
'checkpoint_path', None,
'The checkpoint path to eval. Only used in eval_once mode.')
FLAGS = flags.FLAGS
......@@ -71,8 +74,11 @@ def run_executor(params,
builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type,
strategy_config=params.strategy_config)
num_workers = (builder.strategy.num_replicas_in_sync + 7) / 8
is_multi_host = (num_workers > 1)
num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
logging.info(
'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)
if is_multi_host:
train_input_fn = functools.partial(
train_input_fn,
......@@ -97,7 +103,7 @@ def run_executor(params,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks,
save_config=True)
elif FLAGS.mode == 'eval':
elif FLAGS.mode == 'eval' or FLAGS.mode == 'eval_once':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
......@@ -105,22 +111,47 @@ def run_executor(params,
builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type,
strategy_config=params.strategy_config)
num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
if is_multi_host:
eval_input_fn = functools.partial(
eval_input_fn,
batch_size=params.eval.batch_size //
builder.strategy.num_replicas_in_sync)
logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
builder.strategy.num_replicas_in_sync, num_workers,
is_multi_host)
dist_executor = builder.build_executor(
class_ctor=DetectionDistributedExecutor,
params=params,
is_multi_host=is_multi_host,
model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn,
predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
results = dist_executor.evaluate_from_model_dir(
model_dir=params.model_dir,
eval_input_fn=eval_input_fn,
eval_metric_fn=model_builder.eval_metrics,
eval_timeout=params.eval.eval_timeout,
min_eval_interval=params.eval.min_eval_interval,
total_steps=params.train.total_steps)
if FLAGS.mode == 'eval':
results = dist_executor.evaluate_from_model_dir(
model_dir=params.model_dir,
eval_input_fn=eval_input_fn,
eval_metric_fn=model_builder.eval_metrics,
eval_timeout=params.eval.eval_timeout,
min_eval_interval=params.eval.min_eval_interval,
total_steps=params.train.total_steps)
else:
# Run evaluation once for a single checkpoint.
if not FLAGS.checkpoint_path:
raise ValueError('FLAGS.checkpoint_path cannot be empty.')
checkpoint_path = FLAGS.checkpoint_path
if tf.io.gfile.isdir(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
results, _ = dist_executor.evaluate_checkpoint(
checkpoint_path=checkpoint_path,
eval_input_fn=eval_input_fn,
eval_metric_fn=model_builder.eval_metrics,
summary_writer=summary_writer)
for k, v in results.items():
logging.info('Final eval metric %s: %f', k, v)
return results
......@@ -182,7 +213,7 @@ def run(callbacks=None):
def main(argv):
del argv # Unused.
return run()
run()
if __name__ == '__main__':
......
......@@ -60,8 +60,7 @@ class COCOMetrics(object):
return self._evaluator.evaluate()
def reset_states(self):
logging.info('State is reset on calling metric.result().')
pass
return self._evaluator.reset()
class RetinanetModel(base_model.Model):
......
......@@ -16,6 +16,7 @@
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import numpy as np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment