"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1ade42f72998ec47147cb35e10f5d2283737e420"
Commit 05feb2be authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307873929
parent 0e4029f0
...@@ -67,8 +67,8 @@ class BertBenchmarkBase(PerfZeroBenchmark): ...@@ -67,8 +67,8 @@ class BertBenchmarkBase(PerfZeroBenchmark):
"""Base class to hold methods common to test classes.""" """Base class to hold methods common to test classes."""
local_flags = None local_flags = None
def __init__(self, output_dir=None): def __init__(self, output_dir=None, tpu=None):
super(BertBenchmarkBase, self).__init__(output_dir=output_dir) super(BertBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
self.num_gpus = 8 self.num_gpus = 8
self.timer_callback = None self.timer_callback = None
......
...@@ -55,8 +55,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -55,8 +55,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module.""" """Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None, tpu=None): def __init__(self, output_dir=None, tpu=None):
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir) super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
self.tpu = tpu
def _read_training_summary_from_file(self): def _read_training_summary_from_file(self):
"""Reads the training summary from a file.""" """Reads the training summary from a file."""
...@@ -80,9 +79,9 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -80,9 +79,9 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
Returns: Returns:
A `tf.distribute.DistibutionStrategy` object. A `tf.distribute.DistibutionStrategy` object.
""" """
if self.tpu or ds_type == 'tpu': if FLAGS.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy( return distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu) distribution_strategy='tpu', tpu_address=FLAGS.tpu)
elif ds_type == 'multi_worker_mirrored': elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy. # Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts, _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
...@@ -387,7 +386,13 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -387,7 +386,13 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._setup() self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu') FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.train_batch_size = 48 FLAGS.train_batch_size = 48
FLAGS.predict_batch_size = 48
FLAGS.mode = 'train'
FLAGS.learning_rate = 8e-5
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100
FLAGS.do_lower_case = True
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
self._run_and_report_benchmark() self._run_and_report_benchmark()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment