"vscode:/vscode.git/clone" did not exist on "a51a30082768c566a579c661aaa7229082f6ff49"
Commit bcce419a authored by Sai Ganesh Bandiatmakuri, committed by A. Unique TensorFlower

Inject enable_runtime_flags into benchmarks.

This will help general debugging by enabling custom execution with --benchmark_method_steps.

E.g., --benchmark_method_steps=train_steps=7 will run the benchmark for only 7 steps without modifying the benchmark code.

PiperOrigin-RevId: 282396875
parent 9c50e961
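
For context, the decorator injected below is defined in official/utils/testing/benchmark_wrappers.py, which is not part of this diff. The following is a minimal sketch of the mechanism the commit message describes, assuming an absl multi-string flag named after the one in the description and type coercion via Flag.parse; neither detail is confirmed by the hunks below.

```python
# Hypothetical sketch of enable_runtime_flags; the real implementation
# lives in official/utils/testing/benchmark_wrappers.py and is not
# shown in this commit. Flag name and parsing are assumptions.
import functools

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_multi_string(
    'benchmark_method_steps', None,
    'Runtime flag overrides of the form name=value, e.g. train_steps=7.')


def enable_runtime_flags(decorated_func):
  """Applies --benchmark_method_steps overrides, then runs the method."""

  @functools.wraps(decorated_func)
  def runner(*args, **kwargs):
    for override in FLAGS.benchmark_method_steps or []:
      flag_name, flag_value = override.split('=', 1)
      # Flag.parse() coerces the string to the flag's declared type,
      # so train_steps=7 becomes the integer 7.
      FLAGS[flag_name].parse(flag_value)
    return decorated_func(*args, **kwargs)

  return runner
```

With each _run_and_report_benchmark wrapped this way, passing --benchmark_method_steps=train_steps=7 on the command line would override train_steps just before the method body runs, without touching the benchmark code.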
@@ -35,6 +35,7 @@ from official.nlp import bert_modeling as modeling
from official.nlp.bert import input_pipeline
from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
+from official.utils.testing import benchmark_wrappers

# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -130,6 +131,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
    self.num_steps_per_epoch = 110
    self.num_epochs = 1

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0,
@@ -308,6 +310,7 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
    super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0.84,
...
@@ -32,6 +32,8 @@ from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
+from official.utils.testing import benchmark_wrappers

# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -132,6 +134,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
    FLAGS.num_train_epochs = 1
    FLAGS.steps_per_loop = 1

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                use_ds=True,
                                run_eagerly=False):
@@ -341,6 +344,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
    FLAGS.num_train_epochs = 2
    FLAGS.steps_per_loop = 1

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                use_ds=True,
                                run_eagerly=False):
...
@@ -23,6 +23,7 @@ from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.benchmark import keras_benchmark
+from official.utils.testing import benchmark_wrappers
from official.vision.image_classification import resnet_cifar_main

MIN_TOP_1_ACCURACY = 0.929
@@ -197,6 +198,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.dtype = 'fp32'
    self._run_and_report_benchmark()

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = resnet_cifar_main.run(FLAGS)
@@ -222,6 +224,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
        flag_methods=flag_methods,
        default_flags=default_flags)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = resnet_cifar_main.run(FLAGS)
...
@@ -22,6 +22,7 @@ from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.benchmark import keras_benchmark
+from official.utils.testing import benchmark_wrappers
from official.vision.image_classification import resnet_imagenet_main

MIN_TOP_1_ACCURACY = 0.76
@@ -171,6 +172,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark(top_1_min=0.736)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                top_1_min=MIN_TOP_1_ACCURACY,
                                top_1_max=MAX_TOP_1_ACCURACY):
@@ -201,6 +203,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
        flag_methods=flag_methods,
        default_flags=default_flags)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self, skip_steps=None):
    start_time_sec = time.time()
    stats = resnet_imagenet_main.run(FLAGS)
@@ -863,6 +866,7 @@ class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
    super(Resnet50KerasBenchmarkRemoteData, self).__init__(
        output_dir=output_dir, default_flags=def_flags)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self):
    # skip the first epoch for performance measurement.
    super(Resnet50KerasBenchmarkRemoteData,
@@ -891,6 +895,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
        flag_methods=flag_methods,
        default_flags=def_flags)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = resnet_imagenet_main.run(FLAGS)
@@ -1023,6 +1028,7 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
    self._run_and_report_benchmark()

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                top_1_min=MIN_TOP_1_ACCURACY,
                                top_1_max=MAX_TOP_1_ACCURACY):
...
@@ -25,6 +25,7 @@ import tensorflow as tf
from official.vision.image_classification import common
from official.vision.image_classification import resnet_ctl_imagenet_main
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
+from official.utils.testing import benchmark_wrappers
from official.utils.flags import core as flags_core

MIN_TOP_1_ACCURACY = 0.76
@@ -169,6 +170,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = resnet_ctl_imagenet_main.run(flags.FLAGS)
@@ -197,6 +199,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
        flag_methods=flag_methods,
        default_flags=default_flags)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = resnet_ctl_imagenet_main.run(FLAGS)
...
@@ -32,6 +32,7 @@ import tensorflow as tf
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.utils.flags import core as flags_core
+from official.utils.testing import benchmark_wrappers
from official.vision.detection import main as detection

TMP_DIR = os.getenv('TMPDIR')
@@ -151,6 +152,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
  def __init__(self, output_dir=TMP_DIR, **kwargs):
    super(RetinanetAccuracy, self).__init__(output_dir=output_dir)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self, min_ap=0.325, max_ap=0.35):
    """Starts RetinaNet accuracy benchmark test."""
...
@@ -31,6 +31,8 @@ import tensorflow as tf
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.xlnet import run_classifier
from official.nlp.xlnet import run_squad
+from official.utils.testing import benchmark_wrappers

# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
@@ -76,6 +78,7 @@ class XLNetClassifyAccuracy(XLNetBenchmarkBase):
    super(XLNetClassifyAccuracy, self).__init__(output_dir=output_dir)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0.95,
@@ -149,6 +152,7 @@ class XLNetSquadAccuracy(XLNetBenchmarkBase):
    super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=87.0,
...
@@ -28,6 +28,7 @@ import tensorflow as tf
from official.recommendation import ncf_common
from official.recommendation import ncf_keras_main
from official.utils.flags import core
+from official.utils.testing import benchmark_wrappers

FLAGS = flags.FLAGS
NCF_DATA_DIR_NAME = 'movielens_data'
@@ -59,6 +60,7 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
    else:
      flagsaver.restore_flag_values(NCFKerasBenchmarkBase.local_flags)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self, hr_at_10_min=0, hr_at_10_max=0):
    start_time_sec = time.time()
    stats = ncf_keras_main.run_ncf(FLAGS)
...
@@ -26,6 +26,7 @@ import tensorflow as tf
from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as transformer_main
from official.utils.flags import core as flags_core
+from official.utils.testing import benchmark_wrappers
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
@@ -71,6 +72,7 @@ class TransformerBenchmark(PerfZeroBenchmark):
        default_flags=default_flags,
        flag_methods=flag_methods)

+  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                bleu_max=None,
                                bleu_min=None,
...