ModelZoo / ResNet50_tensorflow / Commits

Commit 965cc3ee (unverified), authored Apr 21, 2020 by Ayushman Kumar, committed by GitHub on Apr 21, 2020

Merge pull request #7 from tensorflow/master

updated

Parents: 1f3247f4, 1f685c54
Changes: 222

Showing 20 changed files with 910 additions and 214 deletions (+910, −214)
official/benchmark/models/synthetic_util.py  +129 −0
official/benchmark/ncf_keras_benchmark.py  +1 −3
official/benchmark/owner_utils.py  +67 −0
official/benchmark/owner_utils_test.py  +104 −0
official/benchmark/perfzero_benchmark.py  +0 −0
official/benchmark/resnet_ctl_imagenet_benchmark.py  +18 −14
official/benchmark/retinanet_benchmark.py  +1 −1
official/benchmark/shakespeare_benchmark.py  +3 −3
official/benchmark/squad_evaluate_v1_1.py  +0 −109
official/benchmark/tfhub_memory_usage_benchmark.py  +2 −2
official/benchmark/transformer_benchmark.py  +2 −4
official/benchmark/xlnet_benchmark.py  +1 −1
official/colab/bert.ipynb  +383 −0
official/modeling/hyperparams/base_config.py  +9 −4
official/modeling/training/distributed_executor.py  +81 −26
official/nlp/README.md  +2 −4
official/nlp/albert/run_squad.py  +33 −3
official/nlp/bert/bert_models.py  +57 −35
official/nlp/bert/common_flags.py  +4 −1
official/nlp/bert/export_tfhub.py  +13 −4
official/benchmark/models/synthetic_util.py
new file, mode 100644
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to generate data directly on devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import string

from absl import logging
import tensorflow as tf
# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, dataset, split_by=1):
    # dataset.take(1) doesn't have GPU kernel.
    with tf.device('device:CPU:0'):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    initializers = []
    for t in flat_tensor:
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=rebatched_t)
      variable_data.append(v)
      initializers.append(v.initializer)
    input_data = tf.nest.pack_sequence_as(tensor, variable_data)
    self._iterator = SyntheticIterator(input_data, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

  def __iter__(self):
    return self._iterator

  def make_one_shot_iterator(self):
    return self._iterator

  def make_initializable_iterator(self):
    return self._iterator


class SyntheticIterator(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, input_data, initializers):
    self._input_data = input_data
    self._initializers = initializers

  def get_next(self):
    return self._input_data

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers


def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s `make_dataset_iterator` method."""

  def make_dataset(self, dataset):
    logging.info('Using pure synthetic data.')
    with self.scope():
      if self.extended._global_batch_size:  # pylint: disable=protected-access
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      else:
        return SyntheticDataset(dataset)

  def make_iterator(self, dataset):
    dist_dataset = make_dataset(self, dataset)
    return iter(dist_dataset)

  strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
  strategy.make_dataset_iterator = make_iterator
  strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
  strategy.experimental_distribute_dataset = make_dataset


def _undo_monkey_patch_dataset_method(strategy):
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    strategy.make_dataset_iterator = strategy.orig_distribute_dataset


def set_up_synthetic_data():
  _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)


def undo_set_up_synthetic_data():
  _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _undo_monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)
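For context, a minimal usage sketch of the new module (not part of this commit; the model-building step is only a placeholder):

# Hedged sketch: patch the distribution strategies so any dataset they receive
# is replaced by an in-memory synthetic copy, then restore the original
# behaviour afterwards.
import tensorflow as tf
from official.benchmark.models import synthetic_util

synthetic_util.set_up_synthetic_data()
try:
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    pass  # build and fit the Keras model as usual
finally:
  synthetic_util.undo_set_up_synthetic_data()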
official/benchmark/ncf_keras_benchmark.py
...
...
@@ -24,11 +24,10 @@ from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
+from official.benchmark import benchmark_wrappers
from official.recommendation import ncf_common
from official.recommendation import ncf_keras_main
from official.utils.flags import core
-from official.utils.testing import benchmark_wrappers

FLAGS = flags.FLAGS

NCF_DATA_DIR_NAME = 'movielens_data'
...
...
@@ -50,7 +49,6 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
  def _setup(self):
    """Sets up and resets flags before each test."""
    assert tf.version.VERSION.startswith('2.')
    logging.set_verbosity(logging.INFO)
    if NCFKerasBenchmarkBase.local_flags is None:
      ncf_common.define_ncf_flags()
...
...
official/benchmark/owner_utils.py
new file, mode 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils to set Owner annotations on benchmarks.
@owner_utils.Owner('owner_team/user') can be set either at the benchmark class
level / benchmark method level or both.
Runner frameworks can use owner_utils.GetOwner(benchmark_method) to get the
actual owner. Python inheritance for the owner attribute is respected. (E.g
method level owner takes precedence over class level).
See owner_utils_test for associated tests and more examples.
The decorator can be applied both at the method level and at the class level.
Simple example:
===============
class MLBenchmark:
@Owner('example_id')
def benchmark_method_1_gpu(self):
return True
"""
def Owner(owner_name):
  """Sets the owner attribute on a decorated method or class."""

  def _Wrapper(func_or_class):
    """Sets the benchmark owner attribute."""
    func_or_class.__benchmark__owner__ = owner_name
    return func_or_class

  return _Wrapper


def GetOwner(benchmark_method_or_class):
  """Gets the inherited owner attribute for this benchmark.

  Checks for existence of __benchmark__owner__. If it's not present, looks for
  it in the parent class's attribute list.

  Args:
    benchmark_method_or_class: A benchmark method or class.

  Returns:
    string - the associated owner if present / None.
  """
  if hasattr(benchmark_method_or_class, '__benchmark__owner__'):
    return benchmark_method_or_class.__benchmark__owner__
  elif hasattr(benchmark_method_or_class, '__self__'):
    if hasattr(benchmark_method_or_class.__self__, '__benchmark__owner__'):
      return benchmark_method_or_class.__self__.__benchmark__owner__
  return None
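For illustration, a small sketch of how a runner might consume these annotations; the owner strings and class below are made up and not part of this commit:

from official.benchmark import owner_utils

@owner_utils.Owner('keras-team')           # hypothetical class-level owner
class ExampleBenchmark:

  @owner_utils.Owner('model-garden-team')  # hypothetical method-level owner
  def benchmark_1_gpu(self):
    return True

  def benchmark_8_gpu(self):
    return True

bench = ExampleBenchmark()
print(owner_utils.GetOwner(bench.benchmark_1_gpu))  # 'model-garden-team' (method wins)
print(owner_utils.GetOwner(bench.benchmark_8_gpu))  # 'keras-team' (inherited from class)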
official/benchmark/owner_utils_test.py
new file, mode 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.benchmark.owner_utils."""
from absl.testing import absltest
from official.benchmark import owner_utils


@owner_utils.Owner('static_owner')
def static_function(foo=5):
  return foo


def static_function_without_owner(foo=5):
  return foo


class BenchmarkClassWithoutOwner:

  def method_without_owner(self):
    return 100

  @owner_utils.Owner('method_owner')
  def method_with_owner(self):
    return 200


@owner_utils.Owner('class_owner')
class SomeBenchmarkClass:

  def method_inherited_owner(self):
    return 123

  @owner_utils.Owner('method_owner')
  def method_override_owner(self):
    return 345


@owner_utils.Owner('new_class_owner')
class InheritedClass(SomeBenchmarkClass):

  def method_inherited_owner(self):
    return 456

  @owner_utils.Owner('new_method_owner')
  def method_override_owner(self):
    return 567


class OwnerUtilsTest(absltest.TestCase):
  """Tests to assert for owner decorator functionality."""

  def test_owner_tag_missing(self):
    self.assertEqual(None, owner_utils.GetOwner(static_function_without_owner))

    benchmark_class = BenchmarkClassWithoutOwner()
    self.assertEqual(None,
                     owner_utils.GetOwner(benchmark_class.method_without_owner))
    self.assertEqual(100, benchmark_class.method_without_owner())

    self.assertEqual('method_owner',
                     owner_utils.GetOwner(benchmark_class.method_with_owner))
    self.assertEqual(200, benchmark_class.method_with_owner())

  def test_owner_attributes_static(self):
    self.assertEqual('static_owner', owner_utils.GetOwner(static_function))
    self.assertEqual(5, static_function(5))

  def test_owner_attributes_per_class(self):
    level1 = SomeBenchmarkClass()
    self.assertEqual('class_owner',
                     owner_utils.GetOwner(level1.method_inherited_owner))
    self.assertEqual(123, level1.method_inherited_owner())
    self.assertEqual('method_owner',
                     owner_utils.GetOwner(level1.method_override_owner))
    self.assertEqual(345, level1.method_override_owner())

  def test_owner_attributes_inherited_class(self):
    level2 = InheritedClass()
    self.assertEqual('new_class_owner',
                     owner_utils.GetOwner(level2.method_inherited_owner))
    self.assertEqual(456, level2.method_inherited_owner())
    self.assertEqual('new_method_owner',
                     owner_utils.GetOwner(level2.method_override_owner))
    self.assertEqual(567, level2.method_override_owner())


if __name__ == '__main__':
  absltest.main()
official/utils/testing/perfzero_benchmark.py → official/benchmark/perfzero_benchmark.py
File moved
official/benchmark/resnet_ctl_imagenet_benchmark.py
...
...
@@ -13,19 +13,19 @@
# limitations under the License.
# ==============================================================================
"""Executes CTL benchmarks and accuracy tests."""
# pylint: disable=line-too-long,g-bad-import-order
from __future__ import print_function

import os
import time

# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf

from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_ctl_imagenet_main
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
-from official.utils.testing import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark import benchmark_wrappers
from official.utils.flags import core as flags_core

MIN_TOP_1_ACCURACY = 0.76
...
...
@@ -53,7 +53,8 @@ class CtlBenchmark(PerfZeroBenchmark):
                               top_1_min=None,
                               total_batch_size=None,
                               log_steps=None,
-                               warmup=1):
+                               warmup=1,
+                               start_time_sec=None):
    """Report benchmark results by writing to local protobuf file.

    Args:
...
...
@@ -64,6 +65,7 @@ class CtlBenchmark(PerfZeroBenchmark):
total_batch_size: Global batch-size.
log_steps: How often the log was created for stats['step_timestamp_log'].
warmup: number of entries in stats['step_timestamp_log'] to ignore.
start_time_sec: the start time of the program in seconds since epoch.
"""
    metrics = []
...
...
@@ -98,6 +100,12 @@ class CtlBenchmark(PerfZeroBenchmark):
          'value': stats['avg_exp_per_second']
      })

+    if start_time_sec and 'step_timestamp_log' in stats:
+      time_log = stats['step_timestamp_log']
+      # time_log[0] is recorded at the beginning of the first step.
+      startup_time = time_log[0].timestamp - start_time_sec
+      metrics.append({'name': 'startup_time', 'value': startup_time})

    flags_str = flags_core.get_nondefault_flags_as_str()
    self.report_benchmark(
        iters=-1,
...
...
@@ -136,8 +144,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
    FLAGS.dtype = 'fp32'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

  def benchmark_8_gpu_fp16(self):
...
...
@@ -150,8 +156,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
    FLAGS.dtype = 'fp16'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
...
...
@@ -165,8 +169,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()

  @benchmark_wrappers.enable_runtime_flags
...
...
@@ -181,7 +183,8 @@ class Resnet50CtlAccuracy(CtlBenchmark):
        top_1_min=MIN_TOP_1_ACCURACY,
        top_1_max=MAX_TOP_1_ACCURACY,
        total_batch_size=FLAGS.batch_size,
-        log_steps=100)
+        log_steps=100,
+        start_time_sec=start_time_sec)

  def _get_model_dir(self, folder_name):
    return os.path.join(self.output_dir, folder_name)
...
...
@@ -213,7 +216,8 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
        wall_time_sec,
        total_batch_size=FLAGS.batch_size,
        log_steps=FLAGS.log_steps,
-        warmup=warmup)
+        warmup=warmup,
+        start_time_sec=start_time_sec)

  def benchmark_1_gpu_no_dist_strat(self):
    """Test Keras model with 1 GPU, no distribution strategy."""
...
...
@@ -278,7 +282,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'one_device'
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
-    FLAGS.batch_size = 128
+    FLAGS.batch_size = 120
    FLAGS.use_tf_function = False
    FLAGS.use_tf_while_loop = False
    FLAGS.single_l2_loss_op = True
...
...
@@ -291,7 +295,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'one_device'
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_eager')
-    FLAGS.batch_size = 250
+    FLAGS.batch_size = 240
    FLAGS.dtype = 'fp16'
    FLAGS.use_tf_function = False
    FLAGS.use_tf_while_loop = False
...
...
official/benchmark/retinanet_benchmark.py
...
...
@@ -32,7 +32,7 @@ import tensorflow as tf
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.utils.flags import core as flags_core
-from official.utils.testing import benchmark_wrappers
+from official.benchmark import benchmark_wrappers
from official.vision.detection import main as detection

TMP_DIR = os.getenv('TMPDIR')
...
...
official/staging/shakespeare/shakespeare_benchmark.py → official/benchmark/shakespeare_benchmark.py
...
...
@@ -23,11 +23,11 @@ import time
from absl import flags
import tensorflow as tf

# pylint: disable=g-bad-import-order
-from official.staging.shakespeare import shakespeare_main
+from official.benchmark.models.shakespeare import shakespeare_main
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
-from official.utils.testing import benchmark_wrappers
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark

SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
TMP_DIR = os.getenv('TMPDIR')
...
...
official/benchmark/squad_evaluate_v1_1.py
deleted file, mode 100644 → 0
# Copyright 2019 Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev and
# Percy Liang. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).
The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import string

# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order
def _normalize_answer(s):
  """Lowers text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    return re.sub(r"\b(a|an|the)\b", " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _f1_score(prediction, ground_truth):
  """Computes F1 score by comparing prediction to ground truth."""
  prediction_tokens = _normalize_answer(prediction).split()
  ground_truth_tokens = _normalize_answer(ground_truth).split()
  prediction_counter = collections.Counter(prediction_tokens)
  ground_truth_counter = collections.Counter(ground_truth_tokens)
  common = prediction_counter & ground_truth_counter
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _exact_match_score(prediction, ground_truth):
  """Checks if predicted answer exactly matches ground truth answer."""
  return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  """Computes the max over all metric scores."""
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
  """Evaluates predictions for a dataset."""
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article["paragraphs"]:
      for qa in paragraph["qas"]:
        total += 1
        if qa["id"] not in predictions:
          message = ("Unanswered question " + qa["id"] +
                     " will receive score 0.")
          logging.error(message)
          continue
        ground_truths = [entry["text"] for entry in qa["answers"]]
        prediction = predictions[qa["id"]]
        exact_match += _metric_max_over_ground_truths(_exact_match_score,
                                                      prediction,
                                                      ground_truths)
        f1 += _metric_max_over_ground_truths(_f1_score, prediction,
                                             ground_truths)
  exact_match = exact_match / total
  f1 = f1 / total

  return {"exact_match": exact_match, "f1": f1}
official/benchmark/tfhub_memory_usage_benchmark.py
...
...
@@ -20,10 +20,10 @@ import functools
import time

from absl import flags
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
import tensorflow_hub as hub

-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark

FLAGS = flags.FLAGS
...
...
official/benchmark/transformer_benchmark.py
...
...
@@ -22,12 +22,11 @@ import time
from absl import flags
import tensorflow as tf
+from official.benchmark import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core
-from official.utils.testing import benchmark_wrappers
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
...
...
@@ -44,7 +43,6 @@ class TransformerBenchmark(PerfZeroBenchmark):
  def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
               flag_methods=None):
    assert tf.version.VERSION.startswith('2.')
    root_data_dir = root_data_dir if root_data_dir else ''

    self.train_data_dir = os.path.join(root_data_dir,
...
...
official/benchmark/xlnet_benchmark.py
...
...
@@ -31,7 +31,7 @@ import tensorflow as tf
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.xlnet import run_classifier
from official.nlp.xlnet import run_squad
-from official.utils.testing import benchmark_wrappers
+from official.benchmark import benchmark_wrappers

# pylint: disable=line-too-long
...
...
official/colab/bert.ipynb
new file, mode 100644
This diff is collapsed.
official/modeling/hyperparams/base_config.py
...
...
@@ -257,10 +257,8 @@ class RuntimeConfig(Config):
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_eager: Whether or not to enable eager mode.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_threads_enabled: Whether or not GPU threads are enabled.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
...
...
@@ -272,11 +270,13 @@ class RuntimeConfig(Config):
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
loss_scale: The type of loss scale. This is used when setting the mixed
precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
"""
  distribution_strategy: str = 'mirrored'
  enable_eager: bool = False
  enable_xla: bool = False
  gpu_threads_enabled: bool = False
  gpu_thread_mode: Optional[str] = None
  dataset_num_private_threads: Optional[int] = None
  per_gpu_thread_count: int = 0
...
...
@@ -286,6 +286,8 @@ class RuntimeConfig(Config):
  task_index: int = -1
  all_reduce_alg: Optional[str] = None
  num_packs: int = 1
+  loss_scale: Optional[str] = None
+  run_eagerly: bool = False


@dataclasses.dataclass
...
...
@@ -312,7 +314,10 @@ class CallbacksConfig(Config):
Callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
  enable_checkpoint_and_export: bool = True
  enable_tensorboard: bool = True
+  enable_time_history: bool = True
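To illustrate the new fields, a hedged sketch of the kinds of override values they take (values below are illustrative only; how these dicts are fed into RuntimeConfig / CallbacksConfig depends on the experiment's params plumbing, which is not part of this diff):

# Illustrative only: typical values for the fields added in this change.
runtime_overrides = {
    'distribution_strategy': 'mirrored',
    'loss_scale': 'dynamic',   # consumed when building the mixed-precision policy
    'run_eagerly': False,      # keep tf.function compilation enabled
}
callbacks_overrides = {
    'enable_time_history': True,  # attach the TimeHistory callback
}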
official/modeling/training/distributed_executor.py
...
...
@@ -19,7 +19,6 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import json
import os

from absl import flags
...
...
@@ -31,8 +30,9 @@ import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any

from official.modeling.hyperparams import params_dict
-from official.utils.misc import distribution_utils
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils

FLAGS = flags.FLAGS
...
...
@@ -59,6 +59,45 @@ def _no_metric():
  return None


def metrics_as_dict(metric):
  """Puts input metric(s) into a list.

  Args:
    metric: metric(s) to be put into the list. `metric` could be a object, a
      list or a dict of tf.keras.metrics.Metric or has the `required_method`.

  Returns:
    A dictionary of valid metrics.
  """
  if isinstance(metric, tf.keras.metrics.Metric):
    metrics = {metric.name: metric}
  elif isinstance(metric, list):
    metrics = {m.name: m for m in metric}
  elif isinstance(metric, dict):
    metrics = metric
  elif not metric:
    return {}
  else:
    metrics = {'metric': metric}
  return metrics


def metric_results(metric):
  """Collects results from the given metric(s)."""
  metrics = metrics_as_dict(metric)
  metric_result = {
      name: m.result().numpy().astype(float) for name, m in metrics.items()
  }
  return metric_result


def reset_states(metric):
  """Resets states of the given metric(s)."""
  metrics = metrics_as_dict(metric)
  for m in metrics.values():
    m.reset_states()


class SummaryWriter(object):
  """Simple SummaryWriter for writing dictionary of metrics.
...
...
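As a quick illustration (not part of the diff), the new metrics_as_dict helper normalizes the different metric containers the executor now accepts:

# Illustrative sketch of the input shapes metrics_as_dict() normalizes.
import tensorflow as tf
from official.modeling.training.distributed_executor import metrics_as_dict

acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
print(metrics_as_dict(acc))            # {'accuracy': acc}
print(metrics_as_dict([acc]))          # {'accuracy': acc}
print(metrics_as_dict({'top1': acc}))  # dict is passed through unchanged
print(metrics_as_dict(None))           # {}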
@@ -185,6 +224,7 @@ class DistributedExecutor(object):
                         loss_fn,
                         optimizer,
                         metric=None):
+    metrics = metrics_as_dict(metric)

    def _replicated_step(inputs):
      """Replicated training step."""
...
...
@@ -195,11 +235,8 @@ class DistributedExecutor(object):
        prediction_loss = loss_fn(labels, outputs)
        loss = tf.reduce_mean(prediction_loss)
        loss = loss / strategy.num_replicas_in_sync
-        if isinstance(metric, tf.keras.metrics.Metric):
-          metric.update_state(labels, outputs)
-        else:
-          logging.error('train metric is not an instance of '
-                        'tf.keras.metrics.Metric.')
+        for m in metrics.values():
+          m.update_state(labels, outputs)
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
...
...
@@ -235,6 +272,7 @@ class DistributedExecutor(object):
Args:
iterator: an iterator that yields input tensors.
num_steps: the number of steps in the loop.
Returns:
The loss tensor.
...
...
@@ -259,6 +297,7 @@ class DistributedExecutor(object):
  def _create_test_step(self, strategy, model, metric):
    """Creates a distributed test step."""
+    metrics = metrics_as_dict(metric)

    @tf.function
    def test_step(iterator):
...
...
@@ -266,22 +305,20 @@ class DistributedExecutor(object):
      if not metric:
        logging.info('Skip test_step because metric is None (%s)', metric)
        return None, None
-      if not isinstance(metric, tf.keras.metrics.Metric):
-        raise ValueError(
-            'Metric must be an instance of tf.keras.metrics.Metric '
-            'for running in test_step. Actual {}'.format(metric))

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""
        inputs, labels = inputs
        model_outputs = model(inputs, training=False)
-        metric.update_state(labels, model_outputs)
+        for m in metrics.values():
+          m.update_state(labels, model_outputs)
        return labels, model_outputs

      return strategy.run(_test_step_fn, args=(next(iterator),))

    return test_step

  def train(self,
            train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
            eval_input_fn: Callable[[params_dict.ParamsDict],
...
...
@@ -330,10 +367,12 @@ class DistributedExecutor(object):
    eval_metric_fn = eval_metric_fn or _no_metric
    if custom_callbacks and iterations_per_loop != 1:
-      logging.error(
+      logging.warning(
          'It is sematically wrong to run callbacks when '
          'iterations_per_loop is not one (%s)', iterations_per_loop)

    custom_callbacks = custom_callbacks or []

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      if not custom_callbacks:
...
...
@@ -402,6 +441,11 @@ class DistributedExecutor(object):
      test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
      self.eval_summary_writer = test_summary_writer.writer

+    # Use training summary writer in TimeHistory if it's in use
+    for cb in custom_callbacks:
+      if isinstance(cb, keras_utils.TimeHistory):
+        cb.summary_writer = self.train_summary_writer

    # Continue training loop.
    train_step = self._create_train_step(
        strategy=strategy,
...
...
@@ -414,6 +458,20 @@ class DistributedExecutor(object):
    self.global_train_step = model.optimizer.iterations
    test_step = self._create_test_step(strategy, model, metric=eval_metric)

+    # Step-0 operations
+    if current_step == 0 and not latest_checkpoint_file:
+      _save_checkpoint(checkpoint, model_dir,
+                       checkpoint_name.format(step=current_step))
+      if test_step:
+        eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+        eval_metric_result = self._run_evaluation(test_step, current_step,
+                                                  eval_metric, eval_iterator)
+        logging.info('Step: %s evalation metric = %s.', current_step,
+                     eval_metric_result)
+        test_summary_writer(metrics=eval_metric_result,
+                            step=optimizer.iterations)
+        reset_states(eval_metric)

    logging.info('Training started')
    last_save_checkpoint_step = current_step
    while current_step < total_steps:
...
...
@@ -422,23 +480,19 @@ class DistributedExecutor(object):
      _run_callbacks_on_batch_begin(current_step)
      train_loss = train_step(train_iterator,
                              tf.convert_to_tensor(num_steps, dtype=tf.int32))
-      _run_callbacks_on_batch_end(current_step)
      current_step += num_steps

      train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
                                         train_loss)
+      _run_callbacks_on_batch_end(current_step - 1)
      if not isinstance(train_loss, dict):
        train_loss = {'total_loss': train_loss}
      if np.isnan(train_loss['total_loss']):
        raise ValueError('total loss is NaN.')

      if train_metric:
-        train_metric_result = train_metric.result()
-        if isinstance(train_metric, tf.keras.metrics.Metric):
-          train_metric_result = tf.nest.map_structure(
-              lambda x: x.numpy().astype(float), train_metric_result)
-        if not isinstance(train_metric_result, dict):
-          train_metric_result = {'metric': train_metric_result}
+        train_metric_result = metric_results(train_metric)
        train_metric_result.update(train_loss)
      else:
        train_metric_result = train_loss
...
...
@@ -475,9 +529,9 @@ class DistributedExecutor(object):
      # Re-initialize evaluation metric, except the last step.
      if eval_metric and current_step < total_steps:
-        eval_metric.reset_states()
+        reset_states(eval_metric)
      if train_metric and current_step < total_steps:
-        train_metric.reset_states()
+        reset_states(train_metric)

    # Reaches the end of training and saves the last checkpoint.
    if last_save_checkpoint_step < total_steps:
...
...
@@ -493,6 +547,9 @@ class DistributedExecutor(object):
      test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)

+    self.train_summary_writer.close()
+    self.eval_summary_writer.close()

    return train_loss, eval_metric_result

  def _run_evaluation(self, test_step, current_training_step, metric,
...
...
@@ -510,9 +567,7 @@ class DistributedExecutor(object):
      except (StopIteration, tf.errors.OutOfRangeError):
        break

-    metric_result = metric.result()
-    if isinstance(metric, tf.keras.metrics.Metric):
-      metric_result = metric_result.numpy().astype(float)
+    metric_result = metric_results(metric)
    logging.info('Step: [%d] Validation metric = %f', current_training_step,
                 metric_result)
    return metric_result
...
...
@@ -629,7 +684,7 @@ class DistributedExecutor(object):
      logging.info('Step: %s evalation metric = %s.', current_step,
                   eval_metric_result)
      summary_writer(metrics=eval_metric_result, step=current_step)
-      eval_metric.reset_states()
+      reset_states(eval_metric)

    return eval_metric_result, current_step
...
...
official/nlp/README.md
...
...
@@ -7,8 +7,9 @@ state-of-the-art models.
The repository contains the following models, with implementations, pre-trained
model weights, usage scripts and conversion utilities:
-* [Bert](bert)
* [Albert](albert)
+* [Bert](bert)
+* [NHNet](nhnet)
* [XLNet](xlnet)
* [Transformer for translation](transformer)
...
...
@@ -16,6 +17,3 @@ Addtional features:
* Distributed trainable on both multi-GPU and TPU
* e2e training for custom models, including both pretraining and finetuning.
official/nlp/albert/run_squad.py
...
...
@@ -19,9 +19,12 @@ from __future__ import division
from __future__ import print_function

import json
import os
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from official.nlp.albert import configs as albert_configs
...
...
@@ -53,7 +56,7 @@ def train_squad(strategy,
def predict_squad(strategy, input_meta_data):
-  """Makes predictions for a squad dataset."""
+  """Makes predictions for the squad dataset."""
  bert_config = albert_configs.AlbertConfig.from_json_file(
      FLAGS.bert_config_file)
  tokenizer = tokenization.FullSentencePieceTokenizer(
...
...
@@ -63,6 +66,18 @@ def predict_squad(strategy, input_meta_data):
                              bert_config, squad_lib_sp)


def eval_squad(strategy, input_meta_data):
  """Evaluate on the squad dataset."""
  bert_config = albert_configs.AlbertConfig.from_json_file(
      FLAGS.bert_config_file)
  tokenizer = tokenization.FullSentencePieceTokenizer(
      sp_model_file=FLAGS.sp_model_file)

  eval_metrics = run_squad_helper.eval_squad(strategy, input_meta_data,
                                             tokenizer, bert_config,
                                             squad_lib_sp)
  return eval_metrics


def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.
...
@@ -97,10 +112,25 @@ def main(_):
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
-  if FLAGS.mode in ('train', 'train_and_predict'):
+  if 'train' in FLAGS.mode:
    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
-  if FLAGS.mode in ('predict', 'train_and_predict'):
+  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
+  if 'eval' in FLAGS.mode:
+    eval_metrics = eval_squad(strategy, input_meta_data)
+    f1_score = eval_metrics['final_f1']
+    logging.info('SQuAD eval F1-score: %f', f1_score)
+    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
+    summary_writer = tf.summary.create_file_writer(summary_dir)
+    with summary_writer.as_default():
+      # TODO(lehou): write to the correct step number.
+      tf.summary.scalar('F1-score', f1_score, step=0)
+      summary_writer.flush()
+    # Also write eval_metrics to json file.
+    squad_lib_sp.write_to_json_files(
+        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
+    time.sleep(60)


if __name__ == '__main__':
...
...
official/nlp/bert/bert_models.py
...
...
@@ -54,29 +54,41 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
    self.add_metric(
        lm_example_loss, name='lm_example_loss', aggregation='mean')
-    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
-        sentence_labels, sentence_output)
-    self.add_metric(
-        next_sentence_accuracy,
-        name='next_sentence_accuracy',
-        aggregation='mean')
-    self.add_metric(
-        next_sentence_loss, name='next_sentence_loss', aggregation='mean')
-
-  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+    if sentence_labels is not None:
+      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+          sentence_labels, sentence_output)
+      self.add_metric(
+          next_sentence_accuracy,
+          name='next_sentence_accuracy',
+          aggregation='mean')
+    if next_sentence_loss is not None:
+      self.add_metric(
+          next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+
+  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
+           sentence_labels=None):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
-    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=sentence_labels, predictions=sentence_output)
-    loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
+      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+          labels=sentence_labels, predictions=sentence_output)
+      loss = mask_label_loss + sentence_loss
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
    # TODO(hongkuny): Avoids the hack and switches add_loss.
    final_loss = tf.fill(batch_shape, loss)
...
...
@@ -120,8 +132,12 @@ def get_transformer_encoder(bert_config,
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,)
-  kwargs = dict(
-      embedding_cfg=embedding_cfg,
-      hidden_cfg=hidden_cfg,
-      num_hidden_instances=bert_config.num_hidden_layers,)
+  kwargs = dict(
+      embedding_cfg=embedding_cfg,
+      hidden_cfg=hidden_cfg,
+      num_hidden_instances=bert_config.num_hidden_layers,
+      pooled_output_dim=bert_config.hidden_size,)
  # Relies on gin configuration to define the Transformer encoder arguments.
  return transformer_encoder_cls(**kwargs)
...
...
@@ -151,7 +167,8 @@ def get_transformer_encoder(bert_config,
def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
-                  initializer=None):
+                  initializer=None,
+                  use_next_sentence_label=True):
  """Returns model to be used for pre-training.

  Args:
...
...
@@ -160,6 +177,7 @@ def pretrain_model(bert_config,
max_predictions_per_seq: Maximum number of tokens in sequence to mask out
and use for pretraining.
initializer: Initializer for weights in BertPretrainer.
use_next_sentence_label: Whether to use the next sentence label.
Returns:
Pretraining model as well as core BERT submodel from which to save
...
...
@@ -181,8 +199,12 @@ def pretrain_model(bert_config,
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)
-  next_sentence_labels = tf.keras.layers.Input(
-      shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  if use_next_sentence_label:
+    next_sentence_labels = tf.keras.layers.Input(
+        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
...
...
@@ -202,17 +224,18 @@ def pretrain_model(bert_config,
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
-  keras_model = tf.keras.Model(
-      inputs={
-          'input_word_ids': input_word_ids,
-          'input_mask': input_mask,
-          'input_type_ids': input_type_ids,
-          'masked_lm_positions': masked_lm_positions,
-          'masked_lm_ids': masked_lm_ids,
-          'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+  inputs = {
+      'input_word_ids': input_word_ids,
+      'input_mask': input_mask,
+      'input_type_ids': input_type_ids,
+      'masked_lm_positions': masked_lm_positions,
+      'masked_lm_ids': masked_lm_ids,
+      'masked_lm_weights': masked_lm_weights,
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
  return keras_model, transformer_encoder
...
...
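A hedged sketch of how the new argument changes the pretraining graph's inputs (not part of this commit; the BertConfig values are illustrative and the call is assumed to work with config defaults):

# Illustrative only: build the pretraining model without the NSP input.
from official.nlp.bert import bert_models, configs

bert_config = configs.BertConfig(vocab_size=30522)  # other fields left at defaults
pretrain_net, encoder = bert_models.pretrain_model(
    bert_config,
    seq_length=128,
    max_predictions_per_seq=20,
    use_next_sentence_label=False)  # no 'next_sentence_labels' input is created
print([t.name for t in pretrain_net.inputs])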
@@ -309,8 +332,7 @@ def classifier_model(bert_config,
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-  bert_model = hub.KerasLayer(hub_module_url,
-                              trainable=hub_module_trainable)
+  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)
...
...
official/nlp/bert/common_flags.py
...
...
@@ -39,7 +39,6 @@ def define_common_bert_flags():
      stop_threshold=False,
      batch_size=False,
      num_gpu=True,
      hooks=False,
      export_dir=False,
      distribution_strategy=True,
      run_eagerly=True)
...
...
@@ -63,6 +62,10 @@ def define_common_bert_flags():
      'inside.')
  flags.DEFINE_float('learning_rate', 5e-5,
                     'The initial learning rate for Adam.')
+  flags.DEFINE_float('end_lr', 0.0,
+                     'The end learning rate for learning rate decay.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
  flags.DEFINE_boolean(
      'scale_loss', False,
      'Whether to divide the loss by number of replica inside the per-replica '
...
...
official/nlp/bert/export_tfhub.py
...
...
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl import app
from absl import flags
+from absl import logging
import tensorflow as tf
from typing import Text

from official.nlp.bert import bert_models
...
...
@@ -34,6 +35,9 @@ flags.DEFINE_string("model_checkpoint_path", None,
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
+flags.DEFINE_bool("do_lower_case", None, "Whether to lowercase. If None, "
+                  "do_lower_case will be enabled if 'uncased' appears in the "
+                  "name of --vocab_file")


def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
...
...
@@ -65,21 +69,26 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
def export_bert_tfhub(bert_config: configs.BertConfig,
                      model_checkpoint_path: Text,
                      hub_destination: Text,
-                      vocab_file: Text):
+                      vocab_file: Text,
+                      do_lower_case: bool = None):
  """Restores a tf.keras.Model and saves for TF-Hub."""
+  # If do_lower_case is not explicit, default to checking whether "uncased" is
+  # in the vocab file name
+  if do_lower_case is None:
+    do_lower_case = "uncased" in vocab_file
+    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
+                 do_lower_case, vocab_file)
  core_model, encoder = create_bert_model(bert_config)
  checkpoint = tf.train.Checkpoint(model=encoder)
  checkpoint.restore(model_checkpoint_path).assert_consumed()
  core_model.vocab_file = tf.saved_model.Asset(vocab_file)
-  core_model.do_lower_case = tf.Variable("uncased" in vocab_file, trainable=False)
+  core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  core_model.save(hub_destination, include_optimizer=False, save_format="tf")


def main(_):
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
-                    FLAGS.vocab_file)
+                    FLAGS.vocab_file, FLAGS.do_lower_case)


if __name__ == "__main__":
...
...
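A hedged illustration of the new override (paths below are made up; only the signature shown in this diff is used):

# Illustrative only: force cased behaviour even though the vocab path
# contains "uncased", by passing the new do_lower_case argument explicitly.
from official.nlp.bert import configs
from official.nlp.bert.export_tfhub import export_bert_tfhub

bert_config = configs.BertConfig.from_json_file("/path/to/bert_config.json")
export_bert_tfhub(bert_config,
                  model_checkpoint_path="/path/to/bert_model.ckpt",
                  hub_destination="/tmp/bert_hub_module",
                  vocab_file="/data/uncased_dir/cased_vocab.txt",
                  do_lower_case=False)  # overrides the filename heuristic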