"...resnet50_tensorflow.git" did not exist on "36d6f41f7584546100a002ad780a768dfc4569cf"
Commit cc7495e4 authored by Zongwei Zhou's avatar Zongwei Zhou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 328970352
parent 184c5586
...@@ -364,9 +364,9 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -364,9 +364,9 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark): class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Bert pretrain distributed benchmark tests with multiple workers.""" """Bert pretrain distributed benchmark tests with multiple workers."""
def __init__(self, output_dir=None, default_flags=None): def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertPretrainMultiWorkerBenchmark, self).__init__( super(BertPretrainMultiWorkerBenchmark, self).__init__(
output_dir=output_dir, default_flags=default_flags) output_dir=output_dir, tpu=tpu, **kwargs)
def _specify_gpu_mwms_flags(self): def _specify_gpu_mwms_flags(self):
FLAGS.distribution_strategy = 'multi_worker_mirrored' FLAGS.distribution_strategy = 'multi_worker_mirrored'
......
...@@ -56,8 +56,9 @@ FLAGS = flags.FLAGS ...@@ -56,8 +56,9 @@ FLAGS = flags.FLAGS
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module.""" """Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None, tpu=None): def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu) super(BertSquadBenchmarkBase, self).__init__(
output_dir=output_dir, tpu=tpu, **kwargs)
def _read_training_summary_from_file(self): def _read_training_summary_from_file(self):
"""Reads the training summary from a file.""" """Reads the training summary from a file."""
...@@ -140,7 +141,8 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -140,7 +141,8 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
""" """
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs): def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu) super(BertSquadBenchmarkReal, self).__init__(
output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self): def _setup(self):
"""Sets up the benchmark and SQuAD flags.""" """Sets up the benchmark and SQuAD flags."""
...@@ -351,7 +353,8 @@ class BertSquadAccuracy(BertSquadBenchmarkBase): ...@@ -351,7 +353,8 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
""" """
def __init__(self, output_dir=None, tpu=None, **kwargs): def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu) super(BertSquadAccuracy, self).__init__(
output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self): def _setup(self):
"""Sets up the benchmark and SQuAD flags.""" """Sets up the benchmark and SQuAD flags."""
...@@ -446,7 +449,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase): ...@@ -446,7 +449,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
def __init__(self, output_dir=None, tpu=None, **kwargs): def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadMultiWorkerAccuracy, self).__init__( super(BertSquadMultiWorkerAccuracy, self).__init__(
output_dir=output_dir, tpu=tpu) output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self): def _setup(self):
"""Sets up the benchmark and SQuAD flags.""" """Sets up the benchmark and SQuAD flags."""
...@@ -518,7 +521,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase): ...@@ -518,7 +521,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs): def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertSquadMultiWorkerBenchmark, self).__init__( super(BertSquadMultiWorkerBenchmark, self).__init__(
output_dir=output_dir, tpu=tpu) output_dir=output_dir, tpu=tpu, **kwargs)
def _setup(self): def _setup(self):
"""Sets up the benchmark and SQuAD flags.""" """Sets up the benchmark and SQuAD flags."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment