Commit 06412123 authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 275103426
parent 1b77cd80
......@@ -25,14 +25,16 @@ import os
import time
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.utils.flags import core as flags_core
from official.vision.detection import main as detection
TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS
# pylint: disable=line-too-long
......@@ -143,7 +145,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, output_dir=None, **kwargs):
def __init__(self, output_dir=TMP_DIR, **kwargs):
super(RetinanetAccuracy, self).__init__(output_dir=output_dir)
def _run_and_report_benchmark(self, min_ap=0.325, max_ap=0.35):
......@@ -208,7 +210,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
`benchmark_(number of gpus)_gpu` format.
"""
def __init__(self, output_dir=None, **kwargs):
def __init__(self, output_dir=TMP_DIR, **kwargs):
super(RetinanetBenchmarkReal, self).__init__(output_dir=output_dir)
@flagsaver.flagsaver
......@@ -216,7 +218,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
"""Run RetinaNet model accuracy test with 8 GPUs."""
self._setup()
params = copy.deepcopy(self.params_override)
params['train']['total_steps'] = 1875 # One epoch.
params['train']['total_steps'] = 1875 # One epoch.
# The iterations_per_loop must be one, otherwise the number of examples per
# second would be wrong. Currently only support calling callback per batch
# when each loop only runs on one batch, i.e. host loop for one step. The
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment