Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6bbc45dd
Commit
6bbc45dd
authored
Sep 30, 2019
by
David Chen
Committed by
A. Unique TensorFlower
Sep 30, 2019
Browse files
Internal change
PiperOrigin-RevId: 272121528
parent
6d6ab9ca
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
official/benchmark/bert_benchmark.py
official/benchmark/bert_benchmark.py
+3
-2
No files found.
official/benchmark/bert_benchmark.py
View file @
6bbc45dd
...
@@ -42,6 +42,7 @@ CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrp
...
@@ -42,6 +42,7 @@ CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrp
MODEL_CONFIG_FILE_PATH
=
'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
MODEL_CONFIG_FILE_PATH
=
'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
# pylint: enable=line-too-long
# pylint: enable=line-too-long
TMP_DIR
=
os
.
getenv
(
'TMPDIR'
)
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -98,7 +99,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
...
@@ -98,7 +99,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format.
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
"""
def
__init__
(
self
,
output_dir
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
**
kwargs
):
super
(
BertClassifyBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
)
super
(
BertClassifyBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
)
self
.
train_data_path
=
CLASSIFIER_TRAIN_DATA_PATH
self
.
train_data_path
=
CLASSIFIER_TRAIN_DATA_PATH
...
@@ -273,7 +274,7 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
...
@@ -273,7 +274,7 @@ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format.
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
"""
def
__init__
(
self
,
output_dir
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
**
kwargs
):
self
.
train_data_path
=
CLASSIFIER_TRAIN_DATA_PATH
self
.
train_data_path
=
CLASSIFIER_TRAIN_DATA_PATH
self
.
eval_data_path
=
CLASSIFIER_EVAL_DATA_PATH
self
.
eval_data_path
=
CLASSIFIER_EVAL_DATA_PATH
self
.
bert_config_file
=
MODEL_CONFIG_FILE_PATH
self
.
bert_config_file
=
MODEL_CONFIG_FILE_PATH
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment