"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "dbfd392f0521faf1df2728a50475072997dcfcdf"
Commit bcce419a authored by Sai Ganesh Bandiatmakuri's avatar Sai Ganesh Bandiatmakuri Committed by A. Unique TensorFlower
Browse files

Inject enable_runtime_flags into benchmarks.

This will help general debugging by enabling custom execution with  --benchmark_method_steps.

E.g --benchmark_method_steps=train_steps=7 will run the benchmark for only 7 steps without modifying benchmark code.

PiperOrigin-RevId: 282396875
parent 9c50e961
......@@ -35,6 +35,7 @@ from official.nlp import bert_modeling as modeling
from official.nlp.bert import input_pipeline
from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
from official.utils.testing import benchmark_wrappers
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
......@@ -130,6 +131,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
self.num_steps_per_epoch = 110
self.num_epochs = 1
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0,
......@@ -308,6 +310,7 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0.84,
......
......@@ -32,6 +32,8 @@ from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.testing import benchmark_wrappers
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
......@@ -132,6 +134,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 1
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
......@@ -341,6 +344,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
FLAGS.num_train_epochs = 2
FLAGS.steps_per_loop = 1
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
......
......@@ -23,6 +23,7 @@ from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.benchmark import keras_benchmark
from official.utils.testing import benchmark_wrappers
from official.vision.image_classification import resnet_cifar_main
MIN_TOP_1_ACCURACY = 0.929
......@@ -197,6 +198,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_cifar_main.run(FLAGS)
......@@ -222,6 +224,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
flag_methods=flag_methods,
default_flags=default_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_cifar_main.run(FLAGS)
......
......@@ -22,6 +22,7 @@ from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.benchmark import keras_benchmark
from official.utils.testing import benchmark_wrappers
from official.vision.image_classification import resnet_imagenet_main
MIN_TOP_1_ACCURACY = 0.76
......@@ -171,6 +172,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark(top_1_min=0.736)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY):
......@@ -201,6 +203,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
flag_methods=flag_methods,
default_flags=default_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, skip_steps=None):
start_time_sec = time.time()
stats = resnet_imagenet_main.run(FLAGS)
......@@ -307,7 +310,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
FLAGS.batch_size = 96 # BatchNorm is less efficient in legacy graph mode
# due to its reliance on v1 cond.
# due to its reliance on v1 cond.
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
......@@ -863,6 +866,7 @@ class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
super(Resnet50KerasBenchmarkRemoteData, self).__init__(
output_dir=output_dir, default_flags=def_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
# skip the first epoch for performance measurement.
super(Resnet50KerasBenchmarkRemoteData,
......@@ -891,6 +895,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
flag_methods=flag_methods,
default_flags=def_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_imagenet_main.run(FLAGS)
......@@ -1023,6 +1028,7 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
self._run_and_report_benchmark()
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY):
......
......@@ -25,6 +25,7 @@ import tensorflow as tf
from official.vision.image_classification import common
from official.vision.image_classification import resnet_ctl_imagenet_main
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
from official.utils.testing import benchmark_wrappers
from official.utils.flags import core as flags_core
MIN_TOP_1_ACCURACY = 0.76
......@@ -169,6 +170,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
FLAGS.datasets_num_private_threads = 14
self._run_and_report_benchmark()
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_ctl_imagenet_main.run(flags.FLAGS)
......@@ -197,6 +199,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
flag_methods=flag_methods,
default_flags=default_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_ctl_imagenet_main.run(FLAGS)
......
......@@ -32,6 +32,7 @@ import tensorflow as tf
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.utils.flags import core as flags_core
from official.utils.testing import benchmark_wrappers
from official.vision.detection import main as detection
TMP_DIR = os.getenv('TMPDIR')
......@@ -151,6 +152,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
def __init__(self, output_dir=TMP_DIR, **kwargs):
super(RetinanetAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, min_ap=0.325, max_ap=0.35):
"""Starts RetinaNet accuracy benchmark test."""
......
......@@ -31,6 +31,8 @@ import tensorflow as tf
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.xlnet import run_classifier
from official.nlp.xlnet import run_squad
from official.utils.testing import benchmark_wrappers
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
......@@ -76,6 +78,7 @@ class XLNetClassifyAccuracy(XLNetBenchmarkBase):
super(XLNetClassifyAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0.95,
......@@ -149,6 +152,7 @@ class XLNetSquadAccuracy(XLNetBenchmarkBase):
super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=87.0,
......
......@@ -28,6 +28,7 @@ import tensorflow as tf
from official.recommendation import ncf_common
from official.recommendation import ncf_keras_main
from official.utils.flags import core
from official.utils.testing import benchmark_wrappers
FLAGS = flags.FLAGS
NCF_DATA_DIR_NAME = 'movielens_data'
......@@ -59,6 +60,7 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
else:
flagsaver.restore_flag_values(NCFKerasBenchmarkBase.local_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, hr_at_10_min=0, hr_at_10_max=0):
start_time_sec = time.time()
stats = ncf_keras_main.run_ncf(FLAGS)
......
......@@ -26,6 +26,7 @@ import tensorflow as tf
from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as transformer_main
from official.utils.flags import core as flags_core
from official.utils.testing import benchmark_wrappers
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
......@@ -71,6 +72,7 @@ class TransformerBenchmark(PerfZeroBenchmark):
default_flags=default_flags,
flag_methods=flag_methods)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
bleu_max=None,
bleu_min=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment