Commit cc7495e4 authored by Zongwei Zhou's avatar Zongwei Zhou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 328970352
parent 184c5586
......@@ -364,9 +364,9 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Bert pretrain distributed benchmark tests with multiple workers."""
def __init__(self, output_dir=None, default_flags=None):
def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertPretrainMultiWorkerBenchmark, self).__init__(
output_dir=output_dir, default_flags=default_flags)
output_dir=output_dir, tpu=tpu, **kwargs)
def _specify_gpu_mwms_flags(self):
FLAGS.distribution_strategy = 'multi_worker_mirrored'
......
......@@ -56,8 +56,9 @@ FLAGS = flags.FLAGS
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None, tpu=None):
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadBenchmarkBase, self).__init__(
output_dir=output_dir, tpu=tpu, **kwargs)
def _read_training_summary_from_file(self):
"""Reads the training summary from a file."""
......@@ -140,7 +141,8 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
"""
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)
super(BertSquadBenchmarkReal, self).__init__(
output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
......@@ -351,7 +353,8 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
"""
def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
super(BertSquadAccuracy, self).__init__(
output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
......@@ -446,7 +449,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadMultiWorkerAccuracy, self).__init__(
output_dir=output_dir, tpu=tpu)
output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
......@@ -518,7 +521,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertSquadMultiWorkerBenchmark, self).__init__(
output_dir=output_dir, tpu=tpu)
output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment