Commit 2e96fbee authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323016002
parent 2659ca30
......@@ -44,11 +44,11 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi
# pylint: enable=line-too-long
class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
class BenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
"""Base class to hold methods common to test classes."""
def __init__(self, **kwargs):
super(DetectionBenchmarkBase, self).__init__(**kwargs)
super(BenchmarkBase, self).__init__(**kwargs)
self.timer_callback = None
def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap,
......@@ -99,7 +99,7 @@ class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
extras={'flags': flags_str})
class RetinanetBenchmarkBase(DetectionBenchmarkBase):
class DetectionBenchmarkBase(BenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, **kwargs):
......@@ -107,7 +107,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
self.eval_data_path = COCO_EVAL_DATA
self.eval_json_path = COCO_EVAL_JSON
self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
super(RetinanetBenchmarkBase, self).__init__(**kwargs)
super(DetectionBenchmarkBase, self).__init__(**kwargs)
def _run_detection_main(self):
"""Starts detection job."""
......@@ -118,7 +118,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
return detection.run()
class RetinanetAccuracy(RetinanetBenchmarkBase):
class DetectionAccuracy(DetectionBenchmarkBase):
"""Accuracy test for RetinaNet model.
Tests RetinaNet detection task model accuracy. The naming
......@@ -126,6 +126,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, model, **kwargs):
self.model = model
super(DetectionAccuracy, self).__init__(**kwargs)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
params,
......@@ -133,7 +137,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap=0.35,
do_eval=True,
warmup=1):
"""Starts RetinaNet accuracy benchmark test."""
"""Starts Detection accuracy benchmark test."""
FLAGS.params_override = json.dumps(params)
# Need timer callback to measure performance
self.timer_callback = keras_utils.TimeHistory(
......@@ -156,8 +160,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap, warmup)
def _setup(self):
super(RetinanetAccuracy, self)._setup()
FLAGS.model = 'retinanet'
super(DetectionAccuracy, self)._setup()
FLAGS.model = self.model
def _params(self):
return {
......@@ -195,22 +199,22 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
self._run_and_report_benchmark(params)
class RetinanetBenchmarkReal(RetinanetAccuracy):
"""Short benchmark performance tests for RetinaNet model.
class DetectionBenchmarkReal(DetectionAccuracy):
"""Short benchmark performance tests for a detection model.
Tests RetinaNet performance in different GPU configurations.
Tests detection performance in different accelerator configurations.
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format.
"""
def _setup(self):
super(RetinanetBenchmarkReal, self)._setup()
super(DetectionBenchmarkReal, self)._setup()
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
@flagsaver.flagsaver
def benchmark_8_gpu_coco(self):
"""Run RetinaNet model accuracy test with 8 GPUs."""
"""Run detection model accuracy test with 8 GPUs."""
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
......@@ -230,7 +234,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU."""
"""Run detection model accuracy test with 1 GPU."""
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
......@@ -245,7 +249,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_xla_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
"""Run detection model accuracy test with 1 GPU and XLA enabled."""
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
......@@ -261,7 +265,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
......@@ -273,7 +277,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
......@@ -285,7 +289,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
......@@ -311,7 +315,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self):
"""Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
"""Run detection model with SpineNet backbone accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['architecture']['backbone'] = 'spinenet'
......@@ -327,5 +331,32 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
class RetinanetBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Retinanet model."""
def __init__(self, **kwargs):
super(RetinanetBenchmarkReal, self).__init__(
model='retinanet',
**kwargs)
class MaskRCNNBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Mask RCNN model."""
def __init__(self, **kwargs):
super(MaskRCNNBenchmarkReal, self).__init__(
model='mask_rcnn',
**kwargs)
class ShapeMaskBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for ShapeMask model."""
def __init__(self, **kwargs):
super(ShapeMaskBenchmarkReal, self).__init__(
model='shapemask',
**kwargs)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment