Commit 1b77cd80 authored by Yeqing Li, committed by A. Unique TensorFlower

Enables timer callback and disables checkpoint saving in retinanet benchmark test.

PiperOrigin-RevId: 275080469
parent cb913691
@@ -95,10 +95,8 @@ class DetectionBenchmarkBase(tf.test.Benchmark):
     }]
     if self.timer_callback:
       metrics.append({
-          'name':
-              'exp_per_second',
-          'value':
-              self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
+          'name': 'exp_per_second',
+          'value': self.timer_callback.get_examples_per_sec(train_batch_size)
       })
     else:
       metrics.append({
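The hunk above collapses the `exp_per_second` metric entry onto single lines and computes it from the `train_batch_size` passed into the reporting method rather than `FLAGS.train_batch_size`. A minimal, self-contained sketch of the interface that metric relies on; the `SimpleTimerCallback` class and the batch size below are illustrative stand-ins, not the callback the benchmark actually uses:

# Illustrative stand-in for the benchmark's timer callback; it only shows the
# interface the metric code above relies on (get_examples_per_sec taking the
# training batch size). The real benchmark wires in its own Keras callback.
import time


class SimpleTimerCallback(object):
  """Tracks completed steps and elapsed time to derive examples/sec."""

  def __init__(self):
    self.steps = 0
    self.start = None

  def on_train_begin(self):
    self.start = time.time()

  def on_batch_end(self):
    self.steps += 1

  def get_examples_per_sec(self, train_batch_size):
    elapsed = time.time() - self.start
    return self.steps * train_batch_size / elapsed if elapsed else 0.0


timer = SimpleTimerCallback()
timer.on_train_begin()
for _ in range(10):
  time.sleep(0.01)  # stand-in for one training step
  timer.on_batch_end()

# Same shape as the metrics entry appended in the hunk above.
metrics = [{
    'name': 'exp_per_second',
    'value': timer.get_examples_per_sec(train_batch_size=64),
}]
print(metrics)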
@@ -134,7 +132,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
 
   def _run_detection_main(self):
     """Starts detection job."""
-    return detection.main('unused_argv')
+    return detection.run(callbacks=[self.timer_callback])
 
 
 class RetinanetAccuracy(RetinanetBenchmarkBase):
@@ -166,7 +164,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
         stats=summary,
         wall_time_sec=wall_time_sec,
         min_ap=min_ap,
-        max_ap=max_ap)
+        max_ap=max_ap,
+        train_batch_size=self.params_override['train']['batch_size'])
 
   def _setup(self):
     super(RetinanetAccuracy, self)._setup()
@@ -228,6 +227,8 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
     params['eval']['eval_samples'] = 8
     FLAGS.params_override = json.dumps(params)
     FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
+    # Use negative value to avoid saving checkpoints.
+    FLAGS.save_checkpoint_freq = -1
     if self.timer_callback is None:
       logging.error('Cannot measure performance without timer callback')
     else:
...
@@ -103,6 +103,8 @@ def initialize_common_flags():
   flags.DEFINE_integer(
       'task_index', 0,
       'If multi-worker training, the task_index of this worker.')
+  flags.DEFINE_integer('save_checkpoint_freq', None,
+                       'Number of steps to save checkpoint.')
 
 
 def strategy_flags_dict():
@@ -447,6 +449,12 @@ class DistributedExecutor(object):
     if save_config:
       self._save_config(model_dir)
 
+    if FLAGS.save_checkpoint_freq:
+      save_freq = FLAGS.save_checkpoint_freq
+    else:
+      save_freq = iterations_per_loop
+    last_save_checkpoint_step = 0
+
     params = self._params
     strategy = self._strategy
     # To reduce unnecessary send/receive input pipeline operation, we place
@@ -540,9 +548,11 @@ class DistributedExecutor(object):
         # iterations_per_loop steps.
         # To avoid repeated model saving, we do not save after the last
         # step of training.
-        if current_step < total_steps:
+        if save_freq > 0 and current_step < total_steps and (
+            current_step - last_save_checkpoint_step) >= save_freq:
           _save_checkpoint(checkpoint, model_dir,
                            checkpoint_name.format(step=current_step))
+          last_save_checkpoint_step = current_step
 
       if test_step:
         eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
...
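The executor changes above resolve the new `save_checkpoint_freq` flag (falling back to `iterations_per_loop` when unset) and only write a checkpoint when enough steps have passed since the last save, so a negative flag value disables checkpointing entirely. A standalone sketch of that gating logic, with the flag handling folded into a helper; the function names are ours, the executor inlines this around FLAGS and its training loop:

# Standalone sketch of the checkpoint gating introduced above.
def resolve_save_freq(save_checkpoint_freq, iterations_per_loop):
  """Mirrors the flag handling: an unset flag falls back to iterations_per_loop."""
  return save_checkpoint_freq if save_checkpoint_freq else iterations_per_loop


def should_save_checkpoint(save_freq, current_step, total_steps,
                           last_save_checkpoint_step):
  """True when the training loop should write a checkpoint at current_step."""
  return (save_freq > 0 and current_step < total_steps and
          current_step - last_save_checkpoint_step >= save_freq)


# Spot checks of the behaviour the diff introduces.
assert resolve_save_freq(None, 500) == 500      # flag unset -> once per loop
assert resolve_save_freq(-1, 500) == -1         # negative value passes through
assert should_save_checkpoint(100, 100, 1000, 0)          # periodic save
assert not should_save_checkpoint(100, 150, 1000, 100)    # too soon after last save
assert not should_save_checkpoint(100, 1000, 1000, 900)   # skip the final step
assert not should_save_checkpoint(-1, 100, 1000, 0)       # -1 disables saving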
@@ -55,7 +55,10 @@ flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
 FLAGS = flags.FLAGS
 
 
-def run_executor(params, train_input_fn=None, eval_input_fn=None):
+def run_executor(params,
+                 train_input_fn=None,
+                 eval_input_fn=None,
+                 callbacks=None):
   """Runs Retinanet model on distribution strategy defined by the user."""
   model_builder = model_factory.model_generator(params)
@@ -92,6 +95,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
         iterations_per_loop=params.train.iterations_per_loop,
         total_steps=params.train.total_steps,
         init_checkpoint=model_builder.make_restore_checkpoint_fn(),
+        custom_callbacks=callbacks,
         save_config=True)
   elif FLAGS.mode == 'eval':
@@ -124,9 +128,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
     raise ValueError('Mode not found: %s.' % FLAGS.mode)
 
 
-def main(argv):
-  del argv  # Unused.
+def run(callbacks=None):
   params = config_factory.config_generator(FLAGS.model)
   params = params_dict.override_params_dict(
@@ -171,7 +173,16 @@ def main(argv):
         batch_size=params.eval.batch_size,
         num_examples=params.eval.eval_samples)
 
   return run_executor(
-      params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
+      params,
+      train_input_fn=train_input_fn,
+      eval_input_fn=eval_input_fn,
+      callbacks=callbacks)
+
+
+def main(argv):
+  del argv  # Unused.
+  return run()
 
 
 if __name__ == '__main__':
...
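The last file splits the old `main` entry point into a `run(callbacks=None)` function that the benchmark can call directly, threading its timer callback through `run_executor` into the executor's `custom_callbacks`, while `main` stays a thin wrapper for command-line use. A minimal sketch of that pattern with placeholder bodies; nothing here is the actual detection code:

# Minimal sketch of the entry-point split above: `run` does the work and
# accepts callbacks programmatically, while `main` remains a thin wrapper
# suitable for absl's app.run. Bodies are placeholders, not real logic.
def run_executor(params, train_input_fn=None, eval_input_fn=None,
                 callbacks=None):
  # Placeholder for building the DistributedExecutor and training/evaluating.
  return {'params': params, 'callbacks': callbacks or []}


def run(callbacks=None):
  params = {'model': 'retinanet'}  # stand-in for config_factory output
  return run_executor(params, callbacks=callbacks)


def main(argv):
  del argv  # Unused.
  return run()


# A benchmark can now inject its timer callback without going through main():
result = run(callbacks=['timer_callback_placeholder'])
print(result['callbacks'])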