Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cc7495e4
Commit
cc7495e4
authored
Aug 28, 2020
by
Zongwei Zhou
Committed by
A. Unique TensorFlower
Aug 28, 2020
Browse files
Internal change
PiperOrigin-RevId: 328970352
parent
184c5586
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
8 deletions
+11
-8
official/benchmark/bert_pretrain_benchmark.py
official/benchmark/bert_pretrain_benchmark.py
+2
-2
official/benchmark/bert_squad_benchmark.py
official/benchmark/bert_squad_benchmark.py
+9
-6
No files found.
official/benchmark/bert_pretrain_benchmark.py
View file @
cc7495e4
...
@@ -364,9 +364,9 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
...
@@ -364,9 +364,9 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
class
BertPretrainMultiWorkerBenchmark
(
BertPretrainAccuracyBenchmark
):
class
BertPretrainMultiWorkerBenchmark
(
BertPretrainAccuracyBenchmark
):
"""Bert pretrain distributed benchmark tests with multiple workers."""
"""Bert pretrain distributed benchmark tests with multiple workers."""
def
__init__
(
self
,
output_dir
=
None
,
default_flags
=
None
):
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
,
**
kwargs
):
super
(
BertPretrainMultiWorkerBenchmark
,
self
).
__init__
(
super
(
BertPretrainMultiWorkerBenchmark
,
self
).
__init__
(
output_dir
=
output_dir
,
default_flags
=
default_fla
gs
)
output_dir
=
output_dir
,
tpu
=
tpu
,
**
kwar
gs
)
def
_specify_gpu_mwms_flags
(
self
):
def
_specify_gpu_mwms_flags
(
self
):
FLAGS
.
distribution_strategy
=
'multi_worker_mirrored'
FLAGS
.
distribution_strategy
=
'multi_worker_mirrored'
...
...
official/benchmark/bert_squad_benchmark.py
View file @
cc7495e4
...
@@ -56,8 +56,9 @@ FLAGS = flags.FLAGS
...
@@ -56,8 +56,9 @@ FLAGS = flags.FLAGS
class
BertSquadBenchmarkBase
(
benchmark_utils
.
BertBenchmarkBase
):
class
BertSquadBenchmarkBase
(
benchmark_utils
.
BertBenchmarkBase
):
"""Base class to hold methods common to test classes in the module."""
"""Base class to hold methods common to test classes in the module."""
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
):
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
,
**
kwargs
):
super
(
BertSquadBenchmarkBase
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
)
super
(
BertSquadBenchmarkBase
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
,
**
kwargs
)
def
_read_training_summary_from_file
(
self
):
def
_read_training_summary_from_file
(
self
):
"""Reads the training summary from a file."""
"""Reads the training summary from a file."""
...
@@ -140,7 +141,8 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
...
@@ -140,7 +141,8 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
"""
"""
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
tpu
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
tpu
=
None
,
**
kwargs
):
super
(
BertSquadBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
)
super
(
BertSquadBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
,
**
kwargs
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up the benchmark and SQuAD flags."""
"""Sets up the benchmark and SQuAD flags."""
...
@@ -351,7 +353,8 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
...
@@ -351,7 +353,8 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
"""
"""
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
,
**
kwargs
):
super
(
BertSquadAccuracy
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
)
super
(
BertSquadAccuracy
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
,
**
kwargs
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up the benchmark and SQuAD flags."""
"""Sets up the benchmark and SQuAD flags."""
...
@@ -446,7 +449,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
...
@@ -446,7 +449,7 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
None
,
tpu
=
None
,
**
kwargs
):
super
(
BertSquadMultiWorkerAccuracy
,
self
).
__init__
(
super
(
BertSquadMultiWorkerAccuracy
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
)
output_dir
=
output_dir
,
tpu
=
tpu
,
**
kwargs
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up the benchmark and SQuAD flags."""
"""Sets up the benchmark and SQuAD flags."""
...
@@ -518,7 +521,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
...
@@ -518,7 +521,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
tpu
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
tpu
=
None
,
**
kwargs
):
super
(
BertSquadMultiWorkerBenchmark
,
self
).
__init__
(
super
(
BertSquadMultiWorkerBenchmark
,
self
).
__init__
(
output_dir
=
output_dir
,
tpu
=
tpu
)
output_dir
=
output_dir
,
tpu
=
tpu
,
**
kwargs
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up the benchmark and SQuAD flags."""
"""Sets up the benchmark and SQuAD flags."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment