Commit 1b77cd80 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Enables timer callback and disables checkpoint saving in retinanet benchmark test.

PiperOrigin-RevId: 275080469
parent cb913691
......@@ -95,10 +95,8 @@ class DetectionBenchmarkBase(tf.test.Benchmark):
}]
if self.timer_callback:
metrics.append({
'name':
'exp_per_second',
'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
'name': 'exp_per_second',
'value': self.timer_callback.get_examples_per_sec(train_batch_size)
})
else:
metrics.append({
......@@ -134,7 +132,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
def _run_detection_main(self):
"""Starts detection job."""
return detection.main('unused_argv')
return detection.run(callbacks=[self.timer_callback])
class RetinanetAccuracy(RetinanetBenchmarkBase):
......@@ -166,7 +164,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
stats=summary,
wall_time_sec=wall_time_sec,
min_ap=min_ap,
max_ap=max_ap)
max_ap=max_ap,
train_batch_size=self.params_override['train']['batch_size'])
def _setup(self):
super(RetinanetAccuracy, self)._setup()
......@@ -228,6 +227,8 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
params['eval']['eval_samples'] = 8
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
......
......@@ -103,6 +103,8 @@ def initialize_common_flags():
flags.DEFINE_integer(
'task_index', 0,
'If multi-worker training, the task_index of this worker.')
flags.DEFINE_integer('save_checkpoint_freq', None,
'Number of steps to save checkpoint.')
def strategy_flags_dict():
......@@ -447,6 +449,12 @@ class DistributedExecutor(object):
if save_config:
self._save_config(model_dir)
if FLAGS.save_checkpoint_freq:
save_freq = FLAGS.save_checkpoint_freq
else:
save_freq = iterations_per_loop
last_save_checkpoint_step = 0
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operation, we place
......@@ -540,9 +548,11 @@ class DistributedExecutor(object):
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_steps:
if save_freq > 0 and current_step < total_steps and (
current_step - last_save_checkpoint_step) >= save_freq:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
last_save_checkpoint_step = current_step
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
......
......@@ -55,7 +55,10 @@ flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
FLAGS = flags.FLAGS
def run_executor(params, train_input_fn=None, eval_input_fn=None):
def run_executor(params,
train_input_fn=None,
eval_input_fn=None,
callbacks=None):
"""Runs Retinanet model on distribution strategy defined by the user."""
model_builder = model_factory.model_generator(params)
......@@ -92,6 +95,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
iterations_per_loop=params.train.iterations_per_loop,
total_steps=params.train.total_steps,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks,
save_config=True)
elif FLAGS.mode == 'eval':
......@@ -124,9 +128,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
raise ValueError('Mode not found: %s.' % FLAGS.mode)
def main(argv):
del argv # Unused.
def run(callbacks=None):
params = config_factory.config_generator(FLAGS.model)
params = params_dict.override_params_dict(
......@@ -171,7 +173,16 @@ def main(argv):
batch_size=params.eval.batch_size,
num_examples=params.eval.eval_samples)
return run_executor(
params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
params,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
callbacks=callbacks)
def main(argv):
del argv # Unused.
return run()
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment