"src/transform/vscode:/vscode.git/clone" did not exist on "bc2d5632b753b79564ad1194af4ed3d659e6446e"
Commit 68b697a1 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Add in per GPU batch size for detection models benchmarks.

PiperOrigin-RevId: 333135907
parent 39a5fe56
......@@ -126,8 +126,9 @@ class DetectionAccuracy(DetectionBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, model, **kwargs):
def __init__(self, model, per_gpu_batch_size=8, **kwargs):
self.model = model
self.per_gpu_batch_size = per_gpu_batch_size
super(DetectionAccuracy, self).__init__(**kwargs)
@benchmark_wrappers.enable_runtime_flags
......@@ -219,6 +220,7 @@ class DetectionBenchmarkReal(DetectionAccuracy):
params = self._params()
params['architecture']['use_bfloat16'] = False
params['train']['total_steps'] = 1875 # One epoch.
params['train']['batch_size'] = 8 * self.per_gpu_batch_size
# The iterations_per_loop must be one, otherwise the number of examples per
# second would be wrong. Currently only support calling callback per batch
# when each loop only runs on one batch, i.e. host loop for one step. The
......@@ -238,7 +240,7 @@ class DetectionBenchmarkReal(DetectionAccuracy):
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 8
params['train']['batch_size'] = 1 * self.per_gpu_batch_size
params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1
params['eval']['eval_samples'] = 8
......@@ -253,7 +255,7 @@ class DetectionBenchmarkReal(DetectionAccuracy):
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 8
params['train']['batch_size'] = 1 * self.per_gpu_batch_size
params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1
params['eval']['eval_samples'] = 8
......@@ -335,21 +337,27 @@ class RetinanetBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Retinanet model."""
def __init__(self, **kwargs):
super(RetinanetBenchmarkReal, self).__init__(model='retinanet', **kwargs)
super(RetinanetBenchmarkReal, self).__init__(model='retinanet',
per_gpu_batch_size=8,
**kwargs)
class MaskRCNNBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Mask RCNN model."""
def __init__(self, **kwargs):
super(MaskRCNNBenchmarkReal, self).__init__(model='mask_rcnn', **kwargs)
super(MaskRCNNBenchmarkReal, self).__init__(model='mask_rcnn',
per_gpu_batch_size=4,
**kwargs)
class ShapeMaskBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for ShapeMask model."""
def __init__(self, **kwargs):
super(ShapeMaskBenchmarkReal, self).__init__(model='shapemask', **kwargs)
super(ShapeMaskBenchmarkReal, self).__init__(model='shapemask',
per_gpu_batch_size=4,
**kwargs)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment