Commit 19c9252a authored by Sai Ganesh Bandiatmakuri's avatar Sai Ganesh Bandiatmakuri Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 310180583
parent 5a55f69d
......@@ -74,9 +74,9 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
warmup_steps = int(epochs * steps_per_epoch * 0.1)
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if self.default_flags['tpu']:
if self.tpu:
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.default_flags['tpu'])
distribution_strategy='tpu', tpu_address=self.tpu)
else:
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off',
......
......@@ -68,7 +68,7 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool):
"""Runs and reports the benchmark given the provided configuration."""
distribution = distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.default_flags['tpu'])
distribution_strategy='tpu', tpu_address=self.tpu)
logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
start_time_sec = time.time()
run_pretraining.run_bert_pretrain(
......
......@@ -80,9 +80,9 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
Returns:
A `tf.distribute.DistibutionStrategy` object.
"""
if self.default_flags['tpu'] or ds_type == 'tpu':
if self.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.default_flags['tpu'])
distribution_strategy='tpu', tpu_address=self.tpu)
elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
......
......@@ -72,11 +72,13 @@ class PerfZeroBenchmark(tf.test.Benchmark):
# TPU models are expected to accept a --tpu=name flag. PerfZero creates
# the TPU at runtime and passes the TPU's name to this flag.
self.default_flags['tpu'] = resolved_tpu
else:
self.default_flags['tpu'] = ''
logging.info('root_data_dir: %s', root_data_dir)
@property
def tpu(self):
return self.default_flags.get('tpu', None)
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment