Commit fad183f8 authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 313475975
parent 5b8d66b2
...@@ -176,6 +176,9 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -176,6 +176,9 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
def _params(self): def _params(self):
return { return {
'architecture': {
'use_bfloat16': True,
},
'train': { 'train': {
'batch_size': 64, 'batch_size': 64,
'iterations_per_loop': 100, 'iterations_per_loop': 100,
...@@ -225,6 +228,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -225,6 +228,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
"""Run RetinaNet model accuracy test with 8 GPUs.""" """Run RetinaNet model accuracy test with 8 GPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False
params['train']['total_steps'] = 1875 # One epoch. params['train']['total_steps'] = 1875 # One epoch.
# The iterations_per_loop must be one, otherwise the number of examples per # The iterations_per_loop must be one, otherwise the number of examples per
# second would be wrong. Currently only support calling callback per batch # second would be wrong. Currently only support calling callback per batch
...@@ -244,6 +248,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -244,6 +248,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
"""Run RetinaNet model accuracy test with 1 GPU.""" """Run RetinaNet model accuracy test with 1 GPU."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 8 params['train']['batch_size'] = 8
params['train']['total_steps'] = 200 params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1 params['train']['iterations_per_loop'] = 1
...@@ -258,6 +263,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -258,6 +263,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled.""" """Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 8 params['train']['batch_size'] = 8
params['train']['total_steps'] = 200 params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1 params['train']['iterations_per_loop'] = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment