"vscode:/vscode.git/clone" did not exist on "fda00bf7bfe0b6a3a67ae0274a892299fdbefc4f"
Commit 1b77cd80 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Enables timer callback and disables checkpoint saving in retinanet benchmark test.

PiperOrigin-RevId: 275080469
parent cb913691
......@@ -95,10 +95,8 @@ class DetectionBenchmarkBase(tf.test.Benchmark):
}]
if self.timer_callback:
metrics.append({
'name':
'exp_per_second',
'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
'name': 'exp_per_second',
'value': self.timer_callback.get_examples_per_sec(train_batch_size)
})
else:
metrics.append({
......@@ -134,7 +132,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
def _run_detection_main(self):
  """Starts the detection training job.

  Delegates to detection.run(), passing the benchmark's timer callback so
  that examples-per-second can be measured during training.

  Returns:
    Whatever detection.run() returns (the training/eval summary).
  """
  # The stale pre-diff line `return detection.main('unused_argv')` was
  # unreachable dead code left by the diff view and has been removed.
  return detection.run(callbacks=[self.timer_callback])
class RetinanetAccuracy(RetinanetBenchmarkBase):
......@@ -166,7 +164,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
stats=summary,
wall_time_sec=wall_time_sec,
min_ap=min_ap,
max_ap=max_ap)
max_ap=max_ap,
train_batch_size=self.params_override['train']['batch_size'])
def _setup(self):
super(RetinanetAccuracy, self)._setup()
......@@ -228,6 +227,8 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
params['eval']['eval_samples'] = 8
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
......
......@@ -103,6 +103,8 @@ def initialize_common_flags():
flags.DEFINE_integer(
'task_index', 0,
'If multi-worker training, the task_index of this worker.')
flags.DEFINE_integer('save_checkpoint_freq', None,
'Number of steps to save checkpoint.')
def strategy_flags_dict():
......@@ -447,6 +449,12 @@ class DistributedExecutor(object):
if save_config:
self._save_config(model_dir)
if FLAGS.save_checkpoint_freq:
save_freq = FLAGS.save_checkpoint_freq
else:
save_freq = iterations_per_loop
last_save_checkpoint_step = 0
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operation, we place
......@@ -540,9 +548,11 @@ class DistributedExecutor(object):
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_steps:
if save_freq > 0 and current_step < total_steps and (
current_step - last_save_checkpoint_step) >= save_freq:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
last_save_checkpoint_step = current_step
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
......
......@@ -55,7 +55,10 @@ flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
FLAGS = flags.FLAGS
def run_executor(params, train_input_fn=None, eval_input_fn=None):
def run_executor(params,
train_input_fn=None,
eval_input_fn=None,
callbacks=None):
"""Runs Retinanet model on distribution strategy defined by the user."""
model_builder = model_factory.model_generator(params)
......@@ -92,6 +95,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
iterations_per_loop=params.train.iterations_per_loop,
total_steps=params.train.total_steps,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks,
save_config=True)
elif FLAGS.mode == 'eval':
......@@ -124,9 +128,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
raise ValueError('Mode not found: %s.' % FLAGS.mode)
def main(argv):
del argv # Unused.
def run(callbacks=None):
params = config_factory.config_generator(FLAGS.model)
params = params_dict.override_params_dict(
......@@ -171,7 +173,16 @@ def main(argv):
batch_size=params.eval.batch_size,
num_examples=params.eval.eval_samples)
return run_executor(
params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
params,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
callbacks=callbacks)
def main(argv):
  """absl.app entry point; ignores CLI positional args and delegates to run().

  Args:
    argv: Positional command-line arguments supplied by absl.app; unused
      because all configuration comes from FLAGS.

  Returns:
    The result of run() (the executor's training/eval summary).
  """
  del argv  # Unused.
  return run()
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment