Commit 2e96fbee authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323016002
parent 2659ca30
...@@ -44,11 +44,11 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi ...@@ -44,11 +44,11 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi
# pylint: enable=line-too-long # pylint: enable=line-too-long
class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark): class BenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
"""Base class to hold methods common to test classes.""" """Base class to hold methods common to test classes."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DetectionBenchmarkBase, self).__init__(**kwargs) super(BenchmarkBase, self).__init__(**kwargs)
self.timer_callback = None self.timer_callback = None
def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap, def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap,
...@@ -99,7 +99,7 @@ class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark): ...@@ -99,7 +99,7 @@ class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
extras={'flags': flags_str}) extras={'flags': flags_str})
class RetinanetBenchmarkBase(DetectionBenchmarkBase): class DetectionBenchmarkBase(BenchmarkBase):
"""Base class to hold methods common to test classes in the module.""" """Base class to hold methods common to test classes in the module."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -107,7 +107,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase): ...@@ -107,7 +107,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
self.eval_data_path = COCO_EVAL_DATA self.eval_data_path = COCO_EVAL_DATA
self.eval_json_path = COCO_EVAL_JSON self.eval_json_path = COCO_EVAL_JSON
self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
super(RetinanetBenchmarkBase, self).__init__(**kwargs) super(DetectionBenchmarkBase, self).__init__(**kwargs)
def _run_detection_main(self): def _run_detection_main(self):
"""Starts detection job.""" """Starts detection job."""
...@@ -118,7 +118,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase): ...@@ -118,7 +118,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
return detection.run() return detection.run()
class RetinanetAccuracy(RetinanetBenchmarkBase): class DetectionAccuracy(DetectionBenchmarkBase):
"""Accuracy test for RetinaNet model. """Accuracy test for RetinaNet model.
Tests RetinaNet detection task model accuracy. The naming Tests RetinaNet detection task model accuracy. The naming
...@@ -126,6 +126,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -126,6 +126,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format. `benchmark_(number of gpus)_gpu_(dataset type)` format.
""" """
def __init__(self, model, **kwargs):
self.model = model
super(DetectionAccuracy, self).__init__(**kwargs)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
params, params,
...@@ -133,7 +137,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -133,7 +137,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap=0.35, max_ap=0.35,
do_eval=True, do_eval=True,
warmup=1): warmup=1):
"""Starts RetinaNet accuracy benchmark test.""" """Starts Detection accuracy benchmark test."""
FLAGS.params_override = json.dumps(params) FLAGS.params_override = json.dumps(params)
# Need timer callback to measure performance # Need timer callback to measure performance
self.timer_callback = keras_utils.TimeHistory( self.timer_callback = keras_utils.TimeHistory(
...@@ -156,8 +160,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -156,8 +160,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap, warmup) max_ap, warmup)
def _setup(self): def _setup(self):
super(RetinanetAccuracy, self)._setup() super(DetectionAccuracy, self)._setup()
FLAGS.model = 'retinanet' FLAGS.model = self.model
def _params(self): def _params(self):
return { return {
...@@ -195,22 +199,22 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -195,22 +199,22 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
self._run_and_report_benchmark(params) self._run_and_report_benchmark(params)
class RetinanetBenchmarkReal(RetinanetAccuracy): class DetectionBenchmarkReal(DetectionAccuracy):
"""Short benchmark performance tests for RetinaNet model. """Short benchmark performance tests for a detection model.
Tests RetinaNet performance in different GPU configurations. Tests detection performance in different accelerator configurations.
The naming convention of below test cases follow The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format. `benchmark_(number of gpus)_gpu` format.
""" """
def _setup(self): def _setup(self):
super(RetinanetBenchmarkReal, self)._setup() super(DetectionBenchmarkReal, self)._setup()
# Use negative value to avoid saving checkpoints. # Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1 FLAGS.save_checkpoint_freq = -1
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_8_gpu_coco(self): def benchmark_8_gpu_coco(self):
"""Run RetinaNet model accuracy test with 8 GPUs.""" """Run detection model accuracy test with 8 GPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False params['architecture']['use_bfloat16'] = False
...@@ -230,7 +234,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -230,7 +234,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_1_gpu_coco(self): def benchmark_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU.""" """Run detection model accuracy test with 1 GPU."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False params['architecture']['use_bfloat16'] = False
...@@ -245,7 +249,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -245,7 +249,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_xla_1_gpu_coco(self): def benchmark_xla_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled.""" """Run detection model accuracy test with 1 GPU and XLA enabled."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False params['architecture']['use_bfloat16'] = False
...@@ -261,7 +265,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -261,7 +265,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_2x2_tpu_coco(self): def benchmark_2x2_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs.""" """Run detection model accuracy test with 4 TPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['train']['batch_size'] = 64 params['train']['batch_size'] = 64
...@@ -273,7 +277,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -273,7 +277,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self): def benchmark_4x4_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs.""" """Run detection model accuracy test with 4 TPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['train']['batch_size'] = 256 params['train']['batch_size'] = 256
...@@ -285,7 +289,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -285,7 +289,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self): def benchmark_2x2_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs.""" """Run detection model accuracy test with 4 TPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['train']['batch_size'] = 64 params['train']['batch_size'] = 64
...@@ -311,7 +315,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -311,7 +315,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self): def benchmark_2x2_tpu_spinenet_coco(self):
"""Run SpineNet with RetinaNet model accuracy test with 4 TPUs.""" """Run detection model with SpineNet backbone accuracy test with 4 TPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['backbone'] = 'spinenet' params['architecture']['backbone'] = 'spinenet'
...@@ -327,5 +331,32 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -327,5 +331,32 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
self._run_and_report_benchmark(params, do_eval=False, warmup=0) self._run_and_report_benchmark(params, do_eval=False, warmup=0)
class RetinanetBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Retinanet model."""
def __init__(self, **kwargs):
super(RetinanetBenchmarkReal, self).__init__(
model='retinanet',
**kwargs)
class MaskRCNNBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Mask RCNN model."""
def __init__(self, **kwargs):
super(MaskRCNNBenchmarkReal, self).__init__(
model='mask_rcnn',
**kwargs)
class ShapeMaskBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for ShapeMask model."""
def __init__(self, **kwargs):
super(ShapeMaskBenchmarkReal, self).__init__(
model='shapemask',
**kwargs)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment