Commit 9ca59f8a authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 275867562
parent befbe0f9
...@@ -253,8 +253,7 @@ class DistributedExecutor(object): ...@@ -253,8 +253,7 @@ class DistributedExecutor(object):
logging.warning('model_dir is empty, so skip the save config.') logging.warning('model_dir is empty, so skip the save config.')
def _get_input_iterator( def _get_input_iterator(
self, input_fn: Callable[[Optional[params_dict.ParamsDict]], self, input_fn: Callable[..., tf.data.Dataset],
tf.data.Dataset],
strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]: strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
"""Returns distributed dataset iterator. """Returns distributed dataset iterator.
...@@ -275,7 +274,7 @@ class DistributedExecutor(object): ...@@ -275,7 +274,7 @@ class DistributedExecutor(object):
return iter( return iter(
strategy.experimental_distribute_datasets_from_function(input_fn)) strategy.experimental_distribute_datasets_from_function(input_fn))
else: else:
input_data = input_fn(self._params) input_data = input_fn()
return iter(strategy.experimental_distribute_dataset(input_data)) return iter(strategy.experimental_distribute_dataset(input_data))
def _create_replicated_step(self, def _create_replicated_step(self,
......
...@@ -58,16 +58,12 @@ class InputFn(object): ...@@ -58,16 +58,12 @@ class InputFn(object):
self._parser_fn = factory.parser_generator(params, mode) self._parser_fn = factory.parser_generator(params, mode)
self._dataset_fn = tf.data.TFRecordDataset self._dataset_fn = tf.data.TFRecordDataset
def __call__(self, def __call__(self, ctx=None, batch_size: int = None):
params: params_dict.ParamsDict = None,
batch_size=None,
ctx=None):
"""Provides tf.data.Dataset object. """Provides tf.data.Dataset object.
Args: Args:
params: placeholder for model parameters.
batch_size: expected batch size input data.
ctx: context object. ctx: context object.
batch_size: expected batch size input data.
Returns: Returns:
tf.data.Dataset object. tf.data.Dataset object.
...@@ -96,6 +92,6 @@ class InputFn(object): ...@@ -96,6 +92,6 @@ class InputFn(object):
# Parses the fetched records to input tensors for model function. # Parses the fetched records to input tensors for model function.
dataset = dataset.map(self._parser_fn, num_parallel_calls=64) dataset = dataset.map(self._parser_fn, num_parallel_calls=64)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset return dataset
...@@ -51,6 +51,9 @@ flags.DEFINE_string('training_file_pattern', None, ...@@ -51,6 +51,9 @@ flags.DEFINE_string('training_file_pattern', None,
flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data') flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
flags.DEFINE_string(
'checkpoint_path', None,
'The checkpoint path to eval. Only used in eval_once mode.')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -71,8 +74,11 @@ def run_executor(params, ...@@ -71,8 +74,11 @@ def run_executor(params,
builder = executor.ExecutorBuilder( builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type, strategy_type=params.strategy_type,
strategy_config=params.strategy_config) strategy_config=params.strategy_config)
num_workers = (builder.strategy.num_replicas_in_sync + 7) / 8 num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (num_workers > 1) is_multi_host = (int(num_workers) >= 2)
logging.info(
'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)
if is_multi_host: if is_multi_host:
train_input_fn = functools.partial( train_input_fn = functools.partial(
train_input_fn, train_input_fn,
...@@ -97,7 +103,7 @@ def run_executor(params, ...@@ -97,7 +103,7 @@ def run_executor(params,
init_checkpoint=model_builder.make_restore_checkpoint_fn(), init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks, custom_callbacks=callbacks,
save_config=True) save_config=True)
elif FLAGS.mode == 'eval': elif FLAGS.mode == 'eval' or FLAGS.mode == 'eval_once':
def _model_fn(params): def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT) return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
...@@ -105,22 +111,47 @@ def run_executor(params, ...@@ -105,22 +111,47 @@ def run_executor(params,
builder = executor.ExecutorBuilder( builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type, strategy_type=params.strategy_type,
strategy_config=params.strategy_config) strategy_config=params.strategy_config)
num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
if is_multi_host:
eval_input_fn = functools.partial(
eval_input_fn,
batch_size=params.eval.batch_size //
builder.strategy.num_replicas_in_sync)
logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
builder.strategy.num_replicas_in_sync, num_workers,
is_multi_host)
dist_executor = builder.build_executor( dist_executor = builder.build_executor(
class_ctor=DetectionDistributedExecutor, class_ctor=DetectionDistributedExecutor,
params=params, params=params,
is_multi_host=is_multi_host,
model_fn=_model_fn, model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn, loss_fn=model_builder.build_loss_fn,
predict_post_process_fn=model_builder.post_processing, predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn()) .make_filter_trainable_variables_fn())
results = dist_executor.evaluate_from_model_dir( if FLAGS.mode == 'eval':
model_dir=params.model_dir, results = dist_executor.evaluate_from_model_dir(
eval_input_fn=eval_input_fn, model_dir=params.model_dir,
eval_metric_fn=model_builder.eval_metrics, eval_input_fn=eval_input_fn,
eval_timeout=params.eval.eval_timeout, eval_metric_fn=model_builder.eval_metrics,
min_eval_interval=params.eval.min_eval_interval, eval_timeout=params.eval.eval_timeout,
total_steps=params.train.total_steps) min_eval_interval=params.eval.min_eval_interval,
total_steps=params.train.total_steps)
else:
# Run evaluation once for a single checkpoint.
if not FLAGS.checkpoint_path:
raise ValueError('FLAGS.checkpoint_path cannot be empty.')
checkpoint_path = FLAGS.checkpoint_path
if tf.io.gfile.isdir(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
results, _ = dist_executor.evaluate_checkpoint(
checkpoint_path=checkpoint_path,
eval_input_fn=eval_input_fn,
eval_metric_fn=model_builder.eval_metrics,
summary_writer=summary_writer)
for k, v in results.items(): for k, v in results.items():
logging.info('Final eval metric %s: %f', k, v) logging.info('Final eval metric %s: %f', k, v)
return results return results
...@@ -182,7 +213,7 @@ def run(callbacks=None): ...@@ -182,7 +213,7 @@ def run(callbacks=None):
def main(argv): def main(argv):
del argv # Unused. del argv # Unused.
return run() run()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -60,8 +60,7 @@ class COCOMetrics(object): ...@@ -60,8 +60,7 @@ class COCOMetrics(object):
return self._evaluator.evaluate() return self._evaluator.evaluate()
def reset_states(self): def reset_states(self):
logging.info('State is reset on calling metric.result().') return self._evaluator.reset()
pass
class RetinanetModel(base_model.Model): class RetinanetModel(base_model.Model):
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment