Commit 6a50c338 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 283648734
parent 1d3cd3cf
...@@ -265,5 +265,26 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -265,5 +265,26 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
else: else:
self._run_and_report_benchmark() self._run_and_report_benchmark()
@flagsaver.flagsaver
def benchmark_xla_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
self.num_gpus = 1
self._setup()
params = copy.deepcopy(self.params_override)
params['train']['batch_size'] = 8
params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1
params['eval']['eval_samples'] = 8
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_1_gpu_coco')
FLAGS.strategy_type = 'one_device_gpu'
FLAGS.enable_xla = True
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
self._run_and_report_benchmark()
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -35,9 +35,15 @@ from official.vision.detection.dataloader import input_reader ...@@ -35,9 +35,15 @@ from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory from official.vision.detection.modeling import factory as model_factory
from official.utils.misc import keras_utils
hyperparams_flags.initialize_common_flags() hyperparams_flags.initialize_common_flags()
flags.DEFINE_bool(
'enable_xla',
default=False,
help='Enable XLA for GPU')
flags.DEFINE_string( flags.DEFINE_string(
'mode', 'mode',
default='train', default='train',
...@@ -166,6 +172,8 @@ def run_executor(params, ...@@ -166,6 +172,8 @@ def run_executor(params,
def run(callbacks=None): def run(callbacks=None):
keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
params = config_factory.config_generator(FLAGS.model) params = config_factory.config_generator(FLAGS.model)
params = params_dict.override_params_dict( params = params_dict.override_params_dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment