Commit 94833324 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 278756245
parent 02cc984e
......@@ -134,7 +134,10 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
def _run_detection_main(self):
"""Starts detection job."""
return detection.run(callbacks=[self.timer_callback])
if self.timer_callback:
return detection.run(callbacks=[self.timer_callback])
else:
return detection.run()
class RetinanetAccuracy(RetinanetBenchmarkBase):
......@@ -180,6 +183,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
'iterations_per_loop': 100,
'total_steps': 22500,
'train_file_pattern': self.train_data_path,
'checkpoint': {
'path': self.resnet_checkpoint_path,
'prefix': 'resnet50/'
},
},
'eval': {
'batch_size': 8,
......
......@@ -426,14 +426,16 @@ class DistributedExecutor(object):
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_begin(batch)
if callback:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_end(batch)
if callback:
callback.on_batch_end(batch)
if save_config:
self._save_config(model_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment