"scripts/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "04812de2170336a7423987fadb1d923923aaa9f4"
Commit 1b77cd80 authored by Yeqing Li, committed by A. Unique TensorFlower

Enables timer callback and disables checkpoint saving in retinanet benchmark test.

PiperOrigin-RevId: 275080469
parent cb913691
@@ -95,10 +95,8 @@ class DetectionBenchmarkBase(tf.test.Benchmark):
       }]
     if self.timer_callback:
       metrics.append({
-          'name':
-              'exp_per_second',
-          'value':
-              self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
+          'name': 'exp_per_second',
+          'value': self.timer_callback.get_examples_per_sec(train_batch_size)
       })
     else:
       metrics.append({
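For context, the timer callback consumed above is expected to expose a get_examples_per_sec(batch_size) method. A minimal sketch of such a callback (hypothetical; the actual callback class wired in by this commit may differ in name and implementation):

import time

import tensorflow as tf


class TimerCallback(tf.keras.callbacks.Callback):
  """Hypothetical sketch of a throughput-measuring callback.

  Records wall time per training batch so a benchmark can report
  examples per second, as in the metrics entry above.
  """

  def __init__(self):
    super(TimerCallback, self).__init__()
    self._batch_start = None
    self._batch_times = []

  def on_batch_begin(self, batch, logs=None):
    self._batch_start = time.time()

  def on_batch_end(self, batch, logs=None):
    self._batch_times.append(time.time() - self._batch_start)

  def get_examples_per_sec(self, batch_size):
    # Average throughput across all recorded batches.
    return batch_size * len(self._batch_times) / sum(self._batch_times)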
@@ -134,7 +132,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):

   def _run_detection_main(self):
     """Starts detection job."""
-    return detection.main('unused_argv')
+    return detection.run(callbacks=[self.timer_callback])


 class RetinanetAccuracy(RetinanetBenchmarkBase):
@@ -166,7 +164,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
         stats=summary,
         wall_time_sec=wall_time_sec,
         min_ap=min_ap,
-        max_ap=max_ap)
+        max_ap=max_ap,
+        train_batch_size=self.params_override['train']['batch_size'])

   def _setup(self):
     super(RetinanetAccuracy, self)._setup()
@@ -228,6 +227,8 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
     params['eval']['eval_samples'] = 8
     FLAGS.params_override = json.dumps(params)
     FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
+    # Use negative value to avoid saving checkpoints.
+    FLAGS.save_checkpoint_freq = -1
     if self.timer_callback is None:
       logging.error('Cannot measure performance without timer callback')
     else:
...
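A negative value works because of the save_freq > 0 guard added to the executor in the hunks below. As a sketch, the flag's semantics introduced by this commit reduce to:

# Sketch of the flag semantics as wired in this commit:
#   FLAGS.save_checkpoint_freq = None    -> save every iterations_per_loop steps (default)
#   FLAGS.save_checkpoint_freq = N (>0)  -> save every N steps
#   FLAGS.save_checkpoint_freq <= 0      -> never save; used by this benchmark
FLAGS.save_checkpoint_freq = -1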
@@ -103,6 +103,8 @@ def initialize_common_flags():
   flags.DEFINE_integer(
       'task_index', 0,
       'If multi-worker training, the task_index of this worker.')
+  flags.DEFINE_integer('save_checkpoint_freq', None,
+                       'Number of steps to save checkpoint.')


 def strategy_flags_dict():
@@ -447,6 +449,12 @@ class DistributedExecutor(object):
     if save_config:
       self._save_config(model_dir)

+    if FLAGS.save_checkpoint_freq:
+      save_freq = FLAGS.save_checkpoint_freq
+    else:
+      save_freq = iterations_per_loop
+    last_save_checkpoint_step = 0
+
     params = self._params
     strategy = self._strategy
     # To reduce unnecessary send/receive input pipeline operation, we place
@@ -540,9 +548,11 @@ class DistributedExecutor(object):
           # iterations_per_loop steps.
           # To avoid repeated model saving, we do not save after the last
           # step of training.
-          if current_step < total_steps:
+          if save_freq > 0 and current_step < total_steps and (
+              current_step - last_save_checkpoint_step) >= save_freq:
             _save_checkpoint(checkpoint, model_dir,
                              checkpoint_name.format(step=current_step))
+            last_save_checkpoint_step = current_step

         if test_step:
           eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
...
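Taken together, the two executor hunks above implement a simple throttle: save_freq defaults to iterations_per_loop when the flag is unset, a positive flag value overrides it, a non-positive value disables saving, and the final step is never checkpointed. A self-contained sketch of the resulting behavior (plain Python, with saving reduced to recording the step):

def checkpoint_steps(save_checkpoint_freq, iterations_per_loop, total_steps):
  """Sketch: steps at which the executor above would save a checkpoint."""
  if save_checkpoint_freq:
    save_freq = save_checkpoint_freq
  else:
    save_freq = iterations_per_loop
  last_save_checkpoint_step = 0
  saved = []
  current_step = 0
  while current_step < total_steps:
    current_step += iterations_per_loop  # one inner training loop
    if save_freq > 0 and current_step < total_steps and (
        current_step - last_save_checkpoint_step) >= save_freq:
      saved.append(current_step)
      last_save_checkpoint_step = current_step
  return saved


# Default behavior (flag unset): save at every loop boundary but the last.
assert checkpoint_steps(None, 100, 400) == [100, 200, 300]
# Benchmark behavior (negative flag value): never save.
assert checkpoint_steps(-1, 100, 400) == []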
@@ -55,7 +55,10 @@ flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
 FLAGS = flags.FLAGS


-def run_executor(params, train_input_fn=None, eval_input_fn=None):
+def run_executor(params,
+                 train_input_fn=None,
+                 eval_input_fn=None,
+                 callbacks=None):
   """Runs Retinanet model on distribution strategy defined by the user."""
   model_builder = model_factory.model_generator(params)
@@ -92,6 +95,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
         iterations_per_loop=params.train.iterations_per_loop,
         total_steps=params.train.total_steps,
         init_checkpoint=model_builder.make_restore_checkpoint_fn(),
+        custom_callbacks=callbacks,
         save_config=True)
   elif FLAGS.mode == 'eval':
@@ -124,9 +128,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
     raise ValueError('Mode not found: %s.' % FLAGS.mode)


-def main(argv):
-  del argv  # Unused.
-
+def run(callbacks=None):
   params = config_factory.config_generator(FLAGS.model)
   params = params_dict.override_params_dict(
@@ -171,7 +173,16 @@ def main(argv):
       batch_size=params.eval.batch_size,
       num_examples=params.eval.eval_samples)

   return run_executor(
-      params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
+      params,
+      train_input_fn=train_input_fn,
+      eval_input_fn=eval_input_fn,
+      callbacks=callbacks)
+
+
+def main(argv):
+  del argv  # Unused.
+  return run()


 if __name__ == '__main__':
...
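With run() split out from main(), a benchmark can drive the same training path programmatically and inject callbacks, while command-line use still goes through absl.app. An illustrative caller (flag setup elided; TimerCallback as sketched earlier, module name per the hunks above):

# Illustrative benchmark-side usage of the new entry point; assumes
# FLAGS (model_dir, params_override, ...) were set up beforehand.
timer_callback = TimerCallback()
detection.run(callbacks=[timer_callback])
print('examples/sec:',
      timer_callback.get_examples_per_sec(batch_size=64))  # example batch size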