Commit 4972c084 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 279361336
parent 426b2c6e
...@@ -243,6 +243,25 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -243,6 +243,25 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
else: else:
self._run_and_report_benchmark() self._run_and_report_benchmark()
@flagsaver.flagsaver
def benchmark_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU."""
self.num_gpus = 1
self._setup()
params = copy.deepcopy(self.params_override)
params['train']['batch_size'] = 8
params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1
params['eval']['eval_samples'] = 8
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_1_gpu_coco')
FLAGS.strategy_type = 'one_device_gpu'
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
self._run_and_report_benchmark()
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -781,7 +781,7 @@ class ExecutorBuilder(object): ...@@ -781,7 +781,7 @@ class ExecutorBuilder(object):
"""Builds tf.distribute.Strategy instance. """Builds tf.distribute.Strategy instance.
Args: Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. strategy_type: string. One of 'tpu', 'one_device_gpu', 'mirrored', 'multi_worker_mirrored'.
Returns: Returns:
An tf.distribute.Strategy object. Returns None if strategy_type is None. An tf.distribute.Strategy object. Returns None if strategy_type is None.
...@@ -791,6 +791,8 @@ class ExecutorBuilder(object): ...@@ -791,6 +791,8 @@ class ExecutorBuilder(object):
if strategy_type == 'tpu': if strategy_type == 'tpu':
return self._build_tpu_strategy() return self._build_tpu_strategy()
elif strategy_type == 'one_device_gpu':
return tf.distribute.OneDeviceStrategy("device:GPU:0")
elif strategy_type == 'mirrored': elif strategy_type == 'mirrored':
return self._build_mirrored_strategy() return self._build_mirrored_strategy()
elif strategy_type == 'multi_worker_mirrored': elif strategy_type == 'multi_worker_mirrored':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment