"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "7af4a1c1f8f53d8652dc594586f070bf5181ab94"
Commit 05feb2be authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307873929
parent 0e4029f0
......@@ -67,8 +67,8 @@ class BertBenchmarkBase(PerfZeroBenchmark):
"""Base class to hold methods common to test classes."""
local_flags = None
def __init__(self, output_dir=None):
super(BertBenchmarkBase, self).__init__(output_dir=output_dir)
def __init__(self, output_dir=None, tpu=None):
super(BertBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
self.num_gpus = 8
self.timer_callback = None
......
......@@ -55,8 +55,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None, tpu=None):
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir)
self.tpu = tpu
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
def _read_training_summary_from_file(self):
"""Reads the training summary from a file."""
......@@ -80,9 +79,9 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
Returns:
A `tf.distribute.DistibutionStrategy` object.
"""
if self.tpu or ds_type == 'tpu':
if FLAGS.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
distribution_strategy='tpu', tpu_address=FLAGS.tpu)
elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
......@@ -387,7 +386,13 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.train_batch_size = 48
FLAGS.predict_batch_size = 48
FLAGS.mode = 'train'
FLAGS.learning_rate = 8e-5
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100
FLAGS.do_lower_case = True
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
self._run_and_report_benchmark()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment