ModelZoo / ResNet50_tensorflow / Commits / 965cc3ee

Unverified commit 965cc3ee, authored Apr 21, 2020 by Ayushman Kumar, committed by GitHub on Apr 21, 2020.

Merge pull request #7 from tensorflow/master

updated

Parents: 1f3247f4, 1f685c54
Changes: 222
Showing 20 changed files with 910 additions and 214 deletions (+910 −214).
official/benchmark/models/synthetic_util.py           +129  −0
official/benchmark/ncf_keras_benchmark.py             +1    −3
official/benchmark/owner_utils.py                     +67   −0
official/benchmark/owner_utils_test.py                +104  −0
official/benchmark/perfzero_benchmark.py              +0    −0
official/benchmark/resnet_ctl_imagenet_benchmark.py   +18   −14
official/benchmark/retinanet_benchmark.py             +1    −1
official/benchmark/shakespeare_benchmark.py           +3    −3
official/benchmark/squad_evaluate_v1_1.py             +0    −109
official/benchmark/tfhub_memory_usage_benchmark.py    +2    −2
official/benchmark/transformer_benchmark.py           +2    −4
official/benchmark/xlnet_benchmark.py                 +1    −1
official/colab/bert.ipynb                             +383  −0
official/modeling/hyperparams/base_config.py          +9    −4
official/modeling/training/distributed_executor.py    +81   −26
official/nlp/README.md                                +2    −4
official/nlp/albert/run_squad.py                      +33   −3
official/nlp/bert/bert_models.py                      +57   −35
official/nlp/bert/common_flags.py                     +4    −1
official/nlp/bert/export_tfhub.py                     +13   −4
official/benchmark/models/synthetic_util.py (new file, 0 → 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to generate data directly on devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import string

from absl import logging
import tensorflow as tf


# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, dataset, split_by=1):
    # dataset.take(1) doesn't have GPU kernel.
    with tf.device('device:CPU:0'):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    initializers = []
    for t in flat_tensor:
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=rebatched_t)
      variable_data.append(v)
      initializers.append(v.initializer)
    input_data = tf.nest.pack_sequence_as(tensor, variable_data)
    self._iterator = SyntheticIterator(input_data, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

  def __iter__(self):
    return self._iterator

  def make_one_shot_iterator(self):
    return self._iterator

  def make_initializable_iterator(self):
    return self._iterator


class SyntheticIterator(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, input_data, initializers):
    self._input_data = input_data
    self._initializers = initializers

  def get_next(self):
    return self._input_data

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers


def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s `make_dataset_iterator` method."""

  def make_dataset(self, dataset):
    logging.info('Using pure synthetic data.')
    with self.scope():
      if self.extended._global_batch_size:  # pylint: disable=protected-access
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      else:
        return SyntheticDataset(dataset)

  def make_iterator(self, dataset):
    dist_dataset = make_dataset(self, dataset)
    return iter(dist_dataset)

  strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
  strategy.make_dataset_iterator = make_iterator
  strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
  strategy.experimental_distribute_dataset = make_dataset


def _undo_monkey_patch_dataset_method(strategy):
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    strategy.make_dataset_iterator = strategy.orig_distribute_dataset


def set_up_synthetic_data():
  _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)


def undo_set_up_synthetic_data():
  _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _undo_monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)
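A minimal sketch of how this utility is typically wired into a benchmark run: patch the strategy classes before building the input pipeline, so every experimental_distribute_dataset call returns a SyntheticDataset, then undo the patch afterwards. The helper function name below is a hypothetical placeholder, not part of this commit.

  from official.benchmark.models import synthetic_util

  def distribute_with_synthetic_data(strategy, dataset):
    """Hypothetical helper: serve one cached batch from device variables."""
    synthetic_util.set_up_synthetic_data()
    try:
      # After patching, this returns a SyntheticDataset built from the first
      # batch of `dataset` instead of a real tf.distribute distributed dataset.
      return strategy.experimental_distribute_dataset(dataset)
    finally:
      # Restore the original strategy methods for subsequent runs.
      synthetic_util.undo_set_up_synthetic_data()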
official/benchmark/ncf_keras_benchmark.py

@@ -24,11 +24,10 @@ from absl import flags
 from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf
+from official.benchmark import benchmark_wrappers
 from official.recommendation import ncf_common
 from official.recommendation import ncf_keras_main
 from official.utils.flags import core
-from official.utils.testing import benchmark_wrappers

 FLAGS = flags.FLAGS
 NCF_DATA_DIR_NAME = 'movielens_data'
@@ -50,7 +49,6 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):

   def _setup(self):
     """Sets up and resets flags before each test."""
-    assert tf.version.VERSION.startswith('2.')
     logging.set_verbosity(logging.INFO)
     if NCFKerasBenchmarkBase.local_flags is None:
       ncf_common.define_ncf_flags()
official/benchmark/owner_utils.py (new file, 0 → 100644)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils to set Owner annotations on benchmarks.

@owner_utils.Owner('owner_team/user') can be set either at the benchmark class
level / benchmark method level or both.

Runner frameworks can use owner_utils.GetOwner(benchmark_method) to get the
actual owner. Python inheritance for the owner attribute is respected. (E.g
method level owner takes precedence over class level).

See owner_utils_test for associated tests and more examples.

The decorator can be applied both at the method level and at the class level.

Simple example:
===============
class MLBenchmark:

  @Owner('example_id')
  def benchmark_method_1_gpu(self):
    return True
"""


def Owner(owner_name):
  """Sets the owner attribute on a decorated method or class."""

  def _Wrapper(func_or_class):
    """Sets the benchmark owner attribute."""
    func_or_class.__benchmark__owner__ = owner_name
    return func_or_class

  return _Wrapper


def GetOwner(benchmark_method_or_class):
  """Gets the inherited owner attribute for this benchmark.

  Checks for existence of __benchmark__owner__. If it's not present, looks for
  it in the parent class's attribute list.

  Args:
    benchmark_method_or_class: A benchmark method or class.

  Returns:
    string - the associated owner if present / None.
  """
  if hasattr(benchmark_method_or_class, '__benchmark__owner__'):
    return benchmark_method_or_class.__benchmark__owner__
  elif hasattr(benchmark_method_or_class, '__self__'):
    if hasattr(benchmark_method_or_class.__self__, '__benchmark__owner__'):
      return benchmark_method_or_class.__self__.__benchmark__owner__
  return None
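A minimal sketch (assumed usage, not part of this commit) of how a runner framework might use GetOwner to attribute benchmark results: discover the benchmark_* methods on an instance and look up the owner for each.

  from official.benchmark import owner_utils

  def owners_for(benchmark_instance):
    """Hypothetical helper: map benchmark method names to their owners."""
    owners = {}
    for name in dir(benchmark_instance):
      if name.startswith('benchmark'):
        method = getattr(benchmark_instance, name)
        # Falls back to the class-level owner when the method has none.
        owners[name] = owner_utils.GetOwner(method)
    return owners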
official/benchmark/owner_utils_test.py (new file, 0 → 100644)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.benchmark.owner_utils."""

from absl.testing import absltest

from official.benchmark import owner_utils


@owner_utils.Owner('static_owner')
def static_function(foo=5):
  return foo


def static_function_without_owner(foo=5):
  return foo


class BenchmarkClassWithoutOwner:

  def method_without_owner(self):
    return 100

  @owner_utils.Owner('method_owner')
  def method_with_owner(self):
    return 200


@owner_utils.Owner('class_owner')
class SomeBenchmarkClass:

  def method_inherited_owner(self):
    return 123

  @owner_utils.Owner('method_owner')
  def method_override_owner(self):
    return 345


@owner_utils.Owner('new_class_owner')
class InheritedClass(SomeBenchmarkClass):

  def method_inherited_owner(self):
    return 456

  @owner_utils.Owner('new_method_owner')
  def method_override_owner(self):
    return 567


class OwnerUtilsTest(absltest.TestCase):
  """Tests to assert for owner decorator functionality."""

  def test_owner_tag_missing(self):
    self.assertEqual(None, owner_utils.GetOwner(static_function_without_owner))

    benchmark_class = BenchmarkClassWithoutOwner()
    self.assertEqual(
        None, owner_utils.GetOwner(benchmark_class.method_without_owner))
    self.assertEqual(100, benchmark_class.method_without_owner())

    self.assertEqual(
        'method_owner', owner_utils.GetOwner(benchmark_class.method_with_owner))
    self.assertEqual(200, benchmark_class.method_with_owner())

  def test_owner_attributes_static(self):
    self.assertEqual('static_owner', owner_utils.GetOwner(static_function))
    self.assertEqual(5, static_function(5))

  def test_owner_attributes_per_class(self):
    level1 = SomeBenchmarkClass()
    self.assertEqual('class_owner',
                     owner_utils.GetOwner(level1.method_inherited_owner))
    self.assertEqual(123, level1.method_inherited_owner())
    self.assertEqual('method_owner',
                     owner_utils.GetOwner(level1.method_override_owner))
    self.assertEqual(345, level1.method_override_owner())

  def test_owner_attributes_inherited_class(self):
    level2 = InheritedClass()
    self.assertEqual('new_class_owner',
                     owner_utils.GetOwner(level2.method_inherited_owner))
    self.assertEqual(456, level2.method_inherited_owner())
    self.assertEqual('new_method_owner',
                     owner_utils.GetOwner(level2.method_override_owner))
    self.assertEqual(567, level2.method_override_owner())


if __name__ == '__main__':
  absltest.main()
official/utils/testing/perfzero_benchmark.py → official/benchmark/perfzero_benchmark.py

File moved.
official/benchmark/resnet_ctl_imagenet_benchmark.py

@@ -13,19 +13,19 @@
 # limitations under the License.
 # ==============================================================================
 """Executes CTL benchmarks and accuracy tests."""
+# pylint: disable=line-too-long,g-bad-import-order
 from __future__ import print_function

 import os
 import time

-# pylint: disable=g-bad-import-order
 from absl import flags
 import tensorflow as tf

 from official.vision.image_classification.resnet import common
 from official.vision.image_classification.resnet import resnet_ctl_imagenet_main
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
-from official.utils.testing import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark import benchmark_wrappers
 from official.utils.flags import core as flags_core

 MIN_TOP_1_ACCURACY = 0.76
@@ -53,7 +53,8 @@ class CtlBenchmark(PerfZeroBenchmark):
                         top_1_min=None,
                         total_batch_size=None,
                         log_steps=None,
-                        warmup=1):
+                        warmup=1,
+                        start_time_sec=None):
     """Report benchmark results by writing to local protobuf file.

     Args:
@@ -64,6 +65,7 @@ class CtlBenchmark(PerfZeroBenchmark):
       total_batch_size: Global batch-size.
       log_steps: How often the log was created for stats['step_timestamp_log'].
       warmup: number of entries in stats['step_timestamp_log'] to ignore.
+      start_time_sec: the start time of the program in seconds since epoch.
     """
     metrics = []
@@ -98,6 +100,12 @@ class CtlBenchmark(PerfZeroBenchmark):
           'value': stats['avg_exp_per_second']
       })

+    if start_time_sec and 'step_timestamp_log' in stats:
+      time_log = stats['step_timestamp_log']
+      # time_log[0] is recorded at the beginning of the first step.
+      startup_time = time_log[0].timestamp - start_time_sec
+      metrics.append({'name': 'startup_time', 'value': startup_time})
+
     flags_str = flags_core.get_nondefault_flags_as_str()
     self.report_benchmark(
         iters=-1,
@@ -136,8 +144,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     FLAGS.epochs_between_evals = 10
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
     FLAGS.dtype = 'fp32'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()

   def benchmark_8_gpu_fp16(self):
@@ -150,8 +156,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     FLAGS.epochs_between_evals = 10
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
     FLAGS.dtype = 'fp16'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()

   def benchmark_8_gpu_amp(self):
@@ -165,8 +169,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
     FLAGS.dtype = 'fp16'
     FLAGS.fp16_implementation = 'graph_rewrite'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()

   @benchmark_wrappers.enable_runtime_flags
@@ -181,7 +183,8 @@ class Resnet50CtlAccuracy(CtlBenchmark):
         top_1_min=MIN_TOP_1_ACCURACY,
         top_1_max=MAX_TOP_1_ACCURACY,
         total_batch_size=FLAGS.batch_size,
-        log_steps=100)
+        log_steps=100,
+        start_time_sec=start_time_sec)

   def _get_model_dir(self, folder_name):
     return os.path.join(self.output_dir, folder_name)
@@ -213,7 +216,8 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
         wall_time_sec,
         total_batch_size=FLAGS.batch_size,
         log_steps=FLAGS.log_steps,
-        warmup=warmup)
+        warmup=warmup,
+        start_time_sec=start_time_sec)

   def benchmark_1_gpu_no_dist_strat(self):
     """Test Keras model with 1 GPU, no distribution strategy."""
@@ -278,7 +282,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'one_device'
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
-    FLAGS.batch_size = 128
+    FLAGS.batch_size = 120
     FLAGS.use_tf_function = False
     FLAGS.use_tf_while_loop = False
     FLAGS.single_l2_loss_op = True
@@ -291,7 +295,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'one_device'
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_eager')
-    FLAGS.batch_size = 250
+    FLAGS.batch_size = 240
     FLAGS.dtype = 'fp16'
     FLAGS.use_tf_function = False
     FLAGS.use_tf_while_loop = False
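For context, a minimal sketch of the startup-time calculation that the new start_time_sec argument enables (the BatchTimestamp tuple name is an assumption modeled on the timestamp entries in stats['step_timestamp_log']; the values are illustrative):

  import collections
  import time

  BatchTimestamp = collections.namedtuple('BatchTimestamp', ['batch_index', 'timestamp'])

  start_time_sec = time.time()           # recorded when the benchmark program starts
  # ... model building, input pipeline setup, tf.function tracing happen here ...
  step_timestamp_log = [BatchTimestamp(0, time.time())]  # first entry logged at step 0

  # time_log[0] is recorded at the beginning of the first step, so the difference
  # is the time spent before training begins.
  startup_time = step_timestamp_log[0].timestamp - start_time_sec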
official/benchmark/retinanet_benchmark.py

@@ -32,7 +32,7 @@ import tensorflow as tf
 from official.benchmark import bert_benchmark_utils as benchmark_utils
 from official.utils.flags import core as flags_core
-from official.utils.testing import benchmark_wrappers
+from official.benchmark import benchmark_wrappers
 from official.vision.detection import main as detection

 TMP_DIR = os.getenv('TMPDIR')
official/staging/shakespeare/shakespeare_benchmark.py → official/benchmark/shakespeare_benchmark.py

@@ -23,11 +23,11 @@ import time
 from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

-from official.staging.shakespeare import shakespeare_main
+from official.benchmark.models.shakespeare import shakespeare_main
 from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils
-from official.utils.testing import benchmark_wrappers
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark

 SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
 TMP_DIR = os.getenv('TMPDIR')
official/benchmark/squad_evaluate_v1_1.py (deleted, 100644 → 0)

# Copyright 2019 Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev and
# Percy Liang. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).

The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.

The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import string

# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order


def _normalize_answer(s):
  """Lowers text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    return re.sub(r"\b(a|an|the)\b", " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _f1_score(prediction, ground_truth):
  """Computes F1 score by comparing prediction to ground truth."""
  prediction_tokens = _normalize_answer(prediction).split()
  ground_truth_tokens = _normalize_answer(ground_truth).split()
  prediction_counter = collections.Counter(prediction_tokens)
  ground_truth_counter = collections.Counter(ground_truth_tokens)
  common = prediction_counter & ground_truth_counter
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _exact_match_score(prediction, ground_truth):
  """Checks if predicted answer exactly matches ground truth answer."""
  return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  """Computes the max over all metric scores."""
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
  """Evaluates predictions for a dataset."""
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article["paragraphs"]:
      for qa in paragraph["qas"]:
        total += 1
        if qa["id"] not in predictions:
          message = ("Unanswered question " + qa["id"] +
                     " will receive score 0.")
          logging.error(message)
          continue
        ground_truths = [entry["text"] for entry in qa["answers"]]
        prediction = predictions[qa["id"]]
        exact_match += _metric_max_over_ground_truths(_exact_match_score,
                                                      prediction, ground_truths)
        f1 += _metric_max_over_ground_truths(_f1_score, prediction,
                                             ground_truths)

  exact_match = exact_match / total
  f1 = f1 / total

  return {"exact_match": exact_match, "f1": f1}
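For reference, a minimal sketch of how the removed evaluate() helper was called; the article and prediction values below are made up for illustration.

  dataset = [{
      "paragraphs": [{
          "qas": [{
              "id": "q1",
              "answers": [{"text": "Denver Broncos"}, {"text": "The Denver Broncos"}],
          }]
      }]
  }]
  predictions = {"q1": "the denver broncos"}

  # Both metrics are averaged over all questions; normalization strips case,
  # punctuation and articles, so this prediction scores 1.0 on both.
  results = evaluate(dataset, predictions)  # {"exact_match": 1.0, "f1": 1.0}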
official/benchmark/tfhub_memory_usage_benchmark.py

@@ -20,10 +20,10 @@ import functools
 import time

 from absl import flags
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
 import tensorflow_hub as hub

-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark

 FLAGS = flags.FLAGS
official/benchmark/transformer_benchmark.py

@@ -22,12 +22,11 @@ import time
 from absl import flags
 import tensorflow as tf

+from official.benchmark import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
 from official.nlp.transformer import misc
 from official.nlp.transformer import transformer_main as transformer_main
 from official.utils.flags import core as flags_core
-from official.utils.testing import benchmark_wrappers
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

 TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
 EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
@@ -44,7 +43,6 @@ class TransformerBenchmark(PerfZeroBenchmark):

   def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
                flag_methods=None):
-    assert tf.version.VERSION.startswith('2.')
     root_data_dir = root_data_dir if root_data_dir else ''

     self.train_data_dir = os.path.join(root_data_dir,
official/benchmark/xlnet_benchmark.py

@@ -31,7 +31,7 @@ import tensorflow as tf
 from official.benchmark import bert_benchmark_utils as benchmark_utils
 from official.nlp.xlnet import run_classifier
 from official.nlp.xlnet import run_squad
-from official.utils.testing import benchmark_wrappers
+from official.benchmark import benchmark_wrappers

 # pylint: disable=line-too-long
official/colab/bert.ipynb (new file, 0 → 100644)

This diff is collapsed and not shown here.
official/modeling/hyperparams/base_config.py

@@ -257,10 +257,8 @@ class RuntimeConfig(Config):

   Attributes:
     distribution_strategy: e.g. 'mirrored', 'tpu', etc.
-    enable_eager: Whether or not to enable eager mode.
     enable_xla: Whether or not to enable XLA.
     per_gpu_thread_count: thread count per GPU.
-    gpu_threads_enabled: Whether or not GPU threads are enabled.
     gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
     dataset_num_private_threads: Number of threads for a private threadpool
       created for all datasets computation.
@@ -272,11 +270,13 @@ class RuntimeConfig(Config):
     all_reduce_alg: Defines the algorithm for performing all-reduce.
     num_packs: Sets `num_packs` in the cross device ops used in
       MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
+    loss_scale: The type of loss scale. This is used when setting the mixed
+      precision policy.
+    run_eagerly: Whether or not to run the experiment eagerly.
   """
   distribution_strategy: str = 'mirrored'
-  enable_eager: bool = False
   enable_xla: bool = False
-  gpu_threads_enabled: bool = False
   gpu_thread_mode: Optional[str] = None
   dataset_num_private_threads: Optional[int] = None
   per_gpu_thread_count: int = 0
@@ -286,6 +286,8 @@ class RuntimeConfig(Config):
   task_index: int = -1
   all_reduce_alg: Optional[str] = None
   num_packs: int = 1
+  loss_scale: Optional[str] = None
+  run_eagerly: bool = False


 @dataclasses.dataclass
@@ -312,7 +314,10 @@ class CallbacksConfig(Config):
       Callback. Defaults to True.
     enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
       Defaults to True.
+    enable_time_history: Whether or not to enable TimeHistory Callbacks.
+      Defaults to True.
   """
   enable_checkpoint_and_export: bool = True
   enable_tensorboard: bool = True
+  enable_time_history: bool = True
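Since RuntimeConfig is a dataclass-based Config, the new fields can be set like any other attribute. A minimal sketch, assuming plain dataclass-style construction (the field values below are illustrative, not taken from this diff):

  from official.modeling.hyperparams import base_config

  # Mixed precision with a dynamic loss scale, graph execution (run_eagerly=False).
  runtime = base_config.RuntimeConfig(
      distribution_strategy='mirrored',
      loss_scale='dynamic',
      run_eagerly=False)

  # Downstream code would branch on these fields, e.g. when building a
  # tf.keras mixed-precision policy or deciding whether to compile to a graph.
  print(runtime.loss_scale, runtime.run_eagerly)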
official/modeling/training/distributed_executor.py

@@ -19,7 +19,6 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

-import json
 import os

 from absl import flags
@@ -31,8 +30,9 @@ import tensorflow as tf
 # pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
 from official.modeling.hyperparams import params_dict
-from official.utils.misc import distribution_utils
 from official.utils import hyperparams_flags
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils

 FLAGS = flags.FLAGS
@@ -59,6 +59,45 @@ def _no_metric():
   return None


+def metrics_as_dict(metric):
+  """Puts input metric(s) into a list.
+
+  Args:
+    metric: metric(s) to be put into the list. `metric` could be a object, a
+      list or a dict of tf.keras.metrics.Metric or has the `required_method`.
+
+  Returns:
+    A dictionary of valid metrics.
+  """
+  if isinstance(metric, tf.keras.metrics.Metric):
+    metrics = {metric.name: metric}
+  elif isinstance(metric, list):
+    metrics = {m.name: m for m in metric}
+  elif isinstance(metric, dict):
+    metrics = metric
+  elif not metric:
+    return {}
+  else:
+    metrics = {'metric': metric}
+  return metrics
+
+
+def metric_results(metric):
+  """Collects results from the given metric(s)."""
+  metrics = metrics_as_dict(metric)
+  metric_result = {
+      name: m.result().numpy().astype(float) for name, m in metrics.items()
+  }
+  return metric_result
+
+
+def reset_states(metric):
+  """Resets states of the given metric(s)."""
+  metrics = metrics_as_dict(metric)
+  for m in metrics.values():
+    m.reset_states()
+
+
 class SummaryWriter(object):
   """Simple SummaryWriter for writing dictionary of metrics.
@@ -185,6 +224,7 @@ class DistributedExecutor(object):
                          loss_fn,
                          optimizer,
                          metric=None):
+    metrics = metrics_as_dict(metric)

     def _replicated_step(inputs):
       """Replicated training step."""
@@ -195,11 +235,8 @@ class DistributedExecutor(object):
         prediction_loss = loss_fn(labels, outputs)
         loss = tf.reduce_mean(prediction_loss)
         loss = loss / strategy.num_replicas_in_sync
-        if isinstance(metric, tf.keras.metrics.Metric):
-          metric.update_state(labels, outputs)
-        else:
-          logging.error('train metric is not an instance of '
-                        'tf.keras.metrics.Metric.')
+        for m in metrics.values():
+          m.update_state(labels, outputs)
       grads = tape.gradient(loss, model.trainable_variables)
       optimizer.apply_gradients(zip(grads, model.trainable_variables))
@@ -235,6 +272,7 @@ class DistributedExecutor(object):

     Args:
       iterator: an iterator that yields input tensors.
+      num_steps: the number of steps in the loop.

     Returns:
       The loss tensor.
@@ -259,6 +297,7 @@ class DistributedExecutor(object):

   def _create_test_step(self, strategy, model, metric):
     """Creates a distributed test step."""
+    metrics = metrics_as_dict(metric)

     @tf.function
     def test_step(iterator):
@@ -266,22 +305,20 @@ class DistributedExecutor(object):
       if not metric:
         logging.info('Skip test_step because metric is None (%s)', metric)
         return None, None
-      if not isinstance(metric, tf.keras.metrics.Metric):
-        raise ValueError(
-            'Metric must be an instance of tf.keras.metrics.Metric '
-            'for running in test_step. Actual {}'.format(metric))

       def _test_step_fn(inputs):
         """Replicated accuracy calculation."""
         inputs, labels = inputs
         model_outputs = model(inputs, training=False)
-        metric.update_state(labels, model_outputs)
+        for m in metrics.values():
+          m.update_state(labels, model_outputs)
         return labels, model_outputs

       return strategy.run(_test_step_fn, args=(next(iterator),))

     return test_step

   def train(self,
             train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
             eval_input_fn: Callable[[params_dict.ParamsDict],
@@ -330,10 +367,12 @@ class DistributedExecutor(object):
     eval_metric_fn = eval_metric_fn or _no_metric

     if custom_callbacks and iterations_per_loop != 1:
-      logging.error(
+      logging.warning(
           'It is sematically wrong to run callbacks when '
           'iterations_per_loop is not one (%s)', iterations_per_loop)

+    custom_callbacks = custom_callbacks or []
+
     def _run_callbacks_on_batch_begin(batch):
       """Runs custom callbacks at the start of every step."""
       if not custom_callbacks:
@@ -402,6 +441,11 @@ class DistributedExecutor(object):
     test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
     self.eval_summary_writer = test_summary_writer.writer

+    # Use training summary writer in TimeHistory if it's in use
+    for cb in custom_callbacks:
+      if isinstance(cb, keras_utils.TimeHistory):
+        cb.summary_writer = self.train_summary_writer
+
     # Continue training loop.
     train_step = self._create_train_step(
         strategy=strategy,
@@ -414,6 +458,20 @@ class DistributedExecutor(object):
     self.global_train_step = model.optimizer.iterations
     test_step = self._create_test_step(strategy, model, metric=eval_metric)

+    # Step-0 operations
+    if current_step == 0 and not latest_checkpoint_file:
+      _save_checkpoint(checkpoint, model_dir,
+                       checkpoint_name.format(step=current_step))
+      if test_step:
+        eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+        eval_metric_result = self._run_evaluation(test_step, current_step,
+                                                  eval_metric, eval_iterator)
+        logging.info('Step: %s evalation metric = %s.', current_step,
+                     eval_metric_result)
+        test_summary_writer(metrics=eval_metric_result,
+                            step=optimizer.iterations)
+        reset_states(eval_metric)
+
     logging.info('Training started')
     last_save_checkpoint_step = current_step
     while current_step < total_steps:
@@ -422,23 +480,19 @@ class DistributedExecutor(object):
       _run_callbacks_on_batch_begin(current_step)
       train_loss = train_step(train_iterator,
                               tf.convert_to_tensor(num_steps, dtype=tf.int32))
-      _run_callbacks_on_batch_end(current_step)
       current_step += num_steps
-
       train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
                                          train_loss)
+      _run_callbacks_on_batch_end(current_step - 1)
       if not isinstance(train_loss, dict):
         train_loss = {'total_loss': train_loss}
       if np.isnan(train_loss['total_loss']):
         raise ValueError('total loss is NaN.')

       if train_metric:
-        train_metric_result = train_metric.result()
-        if isinstance(train_metric, tf.keras.metrics.Metric):
-          train_metric_result = tf.nest.map_structure(
-              lambda x: x.numpy().astype(float), train_metric_result)
-        if not isinstance(train_metric_result, dict):
-          train_metric_result = {'metric': train_metric_result}
+        train_metric_result = metric_results(train_metric)
         train_metric_result.update(train_loss)
       else:
         train_metric_result = train_loss
@@ -475,9 +529,9 @@ class DistributedExecutor(object):
       # Re-initialize evaluation metric, except the last step.
       if eval_metric and current_step < total_steps:
-        eval_metric.reset_states()
+        reset_states(eval_metric)
       if train_metric and current_step < total_steps:
-        train_metric.reset_states()
+        reset_states(train_metric)

     # Reaches the end of training and saves the last checkpoint.
     if last_save_checkpoint_step < total_steps:
@@ -493,6 +547,9 @@ class DistributedExecutor(object):
       test_summary_writer(
           metrics=eval_metric_result, step=optimizer.iterations)

+    self.train_summary_writer.close()
+    self.eval_summary_writer.close()
+
     return train_loss, eval_metric_result

   def _run_evaluation(self, test_step, current_training_step, metric,
@@ -510,9 +567,7 @@ class DistributedExecutor(object):
       except (StopIteration, tf.errors.OutOfRangeError):
         break

-    metric_result = metric.result()
-    if isinstance(metric, tf.keras.metrics.Metric):
-      metric_result = metric_result.numpy().astype(float)
+    metric_result = metric_results(metric)
     logging.info('Step: [%d] Validation metric = %f', current_training_step,
                  metric_result)
     return metric_result
@@ -629,7 +684,7 @@ class DistributedExecutor(object):
       logging.info('Step: %s evalation metric = %s.', current_step,
                    eval_metric_result)
       summary_writer(metrics=eval_metric_result, step=current_step)
-      eval_metric.reset_states()
+      reset_states(eval_metric)

     return eval_metric_result, current_step
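A minimal sketch (assumed usage, not part of the diff) of the new metric helpers: they let the executor treat a single metric, a list, or a dict of tf.keras metrics uniformly.

  import tensorflow as tf
  from official.modeling.training import distributed_executor as executor

  metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
             tf.keras.metrics.SparseCategoricalCrossentropy(name='ce')]

  # metrics_as_dict normalizes the list into {name: metric}.
  for m in executor.metrics_as_dict(metrics).values():
    m.update_state([1, 2], [[0.1, 0.8, 0.1], [0.05, 0.15, 0.8]])

  results = executor.metric_results(metrics)   # e.g. {'accuracy': 1.0, 'ce': ...}
  executor.reset_states(metrics)               # zero out both metrics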
official/nlp/README.md

@@ -7,8 +7,9 @@ state-of-the-art models.
 The repository contains the following models, with implementations, pre-trained
 model weights, usage scripts and conversion utilities:

-*   [Bert](bert)
 *   [Albert](albert)
+*   [Bert](bert)
+*   [NHNet](nhnet)
 *   [XLNet](xlnet)
 *   [Transformer for translation](transformer)
@@ -16,6 +17,3 @@ Addtional features:
 *   Distributed trainable on both multi-GPU and TPU
 *   e2e training for custom models, including both pretraining and finetuning.
official/nlp/albert/run_squad.py

@@ -19,9 +19,12 @@ from __future__ import division
 from __future__ import print_function

 import json
+import os
+import time

 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf

 from official.nlp.albert import configs as albert_configs
@@ -53,7 +56,7 @@ def train_squad(strategy,

 def predict_squad(strategy, input_meta_data):
-  """Makes predictions for a squad dataset."""
+  """Makes predictions for the squad dataset."""
   bert_config = albert_configs.AlbertConfig.from_json_file(
       FLAGS.bert_config_file)
   tokenizer = tokenization.FullSentencePieceTokenizer(
@@ -63,6 +66,18 @@ def predict_squad(strategy, input_meta_data):
                                 bert_config, squad_lib_sp)


+def eval_squad(strategy, input_meta_data):
+  """Evaluate on the squad dataset."""
+  bert_config = albert_configs.AlbertConfig.from_json_file(
+      FLAGS.bert_config_file)
+  tokenizer = tokenization.FullSentencePieceTokenizer(
+      sp_model_file=FLAGS.sp_model_file)
+
+  eval_metrics = run_squad_helper.eval_squad(strategy, input_meta_data,
+                                             tokenizer, bert_config,
+                                             squad_lib_sp)
+  return eval_metrics
+
+
 def export_squad(model_export_path, input_meta_data):
   """Exports a trained model as a `SavedModel` for inference.
@@ -97,10 +112,25 @@ def main(_):
       num_gpus=FLAGS.num_gpus,
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
-  if FLAGS.mode in ('train', 'train_and_predict'):
+  if 'train' in FLAGS.mode:
     train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
-  if FLAGS.mode in ('predict', 'train_and_predict'):
+  if 'predict' in FLAGS.mode:
     predict_squad(strategy, input_meta_data)
+  if 'eval' in FLAGS.mode:
+    eval_metrics = eval_squad(strategy, input_meta_data)
+    f1_score = eval_metrics['final_f1']
+    logging.info('SQuAD eval F1-score: %f', f1_score)
+    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
+    summary_writer = tf.summary.create_file_writer(summary_dir)
+    with summary_writer.as_default():
+      # TODO(lehou): write to the correct step number.
+      tf.summary.scalar('F1-score', f1_score, step=0)
+      summary_writer.flush()
+    # Also write eval_metrics to json file.
+    squad_lib_sp.write_to_json_files(
+        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
+    time.sleep(60)


 if __name__ == '__main__':
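The switch from explicit mode tuples to substring checks means any combined mode string now enables each phase it names. A tiny illustration (the helper and mode values below are examples, not part of the script):

  def phases_for(mode):
    """Hypothetical helper mirroring the new dispatch in main()."""
    return [phase for phase in ('train', 'predict', 'eval') if phase in mode]

  phases_for('train_and_predict')  # ['train', 'predict']
  phases_for('train_and_eval')     # ['train', 'eval']
  phases_for('predict')            # ['predict']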
official/nlp/bert/bert_models.py

@@ -54,29 +54,41 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(
         lm_example_loss, name='lm_example_loss', aggregation='mean')
-    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
-        sentence_labels, sentence_output)
-    self.add_metric(
-        next_sentence_accuracy,
-        name='next_sentence_accuracy',
-        aggregation='mean')
-    self.add_metric(
-        next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+    if sentence_labels is not None:
+      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+          sentence_labels, sentence_output)
+      self.add_metric(
+          next_sentence_accuracy,
+          name='next_sentence_accuracy',
+          aggregation='mean')
+
+    if next_sentence_loss is not None:
+      self.add_metric(
+          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

   def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+           sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
     lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
-    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=sentence_labels, predictions=sentence_output)
-    loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
+      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+          labels=sentence_labels, predictions=sentence_output)
+      loss = mask_label_loss + sentence_loss
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+
+    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
@@ -120,8 +132,12 @@ def get_transformer_encoder(bert_config,
       dropout_rate=bert_config.hidden_dropout_prob,
       attention_dropout_rate=bert_config.attention_probs_dropout_prob,
   )
-  kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg,
-                num_hidden_instances=bert_config.num_hidden_layers,)
+  kwargs = dict(
+      embedding_cfg=embedding_cfg,
+      hidden_cfg=hidden_cfg,
+      num_hidden_instances=bert_config.num_hidden_layers,
+      pooled_output_dim=bert_config.hidden_size,
+  )

   # Relies on gin configuration to define the Transformer encoder arguments.
   return transformer_encoder_cls(**kwargs)
@@ -151,7 +167,8 @@ def get_transformer_encoder(bert_config,
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
-                   initializer=None):
+                   initializer=None,
+                   use_next_sentence_label=True):
   """Returns model to be used for pre-training.

   Args:
@@ -160,6 +177,7 @@ def pretrain_model(bert_config,
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
+    use_next_sentence_label: Whether to use the next sentence label.

   Returns:
     Pretraining model as well as core BERT submodel from which to save
@@ -181,8 +199,12 @@ def pretrain_model(bert_config,
       shape=(max_predictions_per_seq,),
       name='masked_lm_weights',
       dtype=tf.int32)
-  next_sentence_labels = tf.keras.layers.Input(
-      shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+
+  if use_next_sentence_label:
+    next_sentence_labels = tf.keras.layers.Input(
+        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None

   transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
@@ -202,17 +224,18 @@ def pretrain_model(bert_config,
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                     masked_lm_weights, next_sentence_labels)
-  keras_model = tf.keras.Model(
-      inputs={
-          'input_word_ids': input_word_ids,
-          'input_mask': input_mask,
-          'input_type_ids': input_type_ids,
-          'masked_lm_positions': masked_lm_positions,
-          'masked_lm_ids': masked_lm_ids,
-          'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+
+  inputs = {
+      'input_word_ids': input_word_ids,
+      'input_mask': input_mask,
+      'input_type_ids': input_type_ids,
+      'masked_lm_positions': masked_lm_positions,
+      'masked_lm_ids': masked_lm_ids,
+      'masked_lm_weights': masked_lm_weights,
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
   return keras_model, transformer_encoder
@@ -309,8 +332,7 @@ def classifier_model(bert_config,
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
   bert_model = hub.KerasLayer(
       hub_module_url, trainable=hub_module_trainable)
   pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
   output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
       pooled_output)
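A minimal sketch (assumed wiring; the config values are illustrative) of the new use_next_sentence_label switch: with it disabled, the pretraining model takes no 'next_sentence_labels' input and the loss reduces to the masked-LM term.

  from official.nlp.bert import bert_models
  from official.nlp.bert import configs

  bert_config = configs.BertConfig(vocab_size=30522)  # library defaults for the rest
  pretrain_net, encoder = bert_models.pretrain_model(
      bert_config,
      seq_length=128,
      max_predictions_per_seq=20,
      use_next_sentence_label=False)

  # 'next_sentence_labels' is absent from the model's input dict in this mode.
  print(sorted(pretrain_net.input.keys()))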
official/nlp/bert/common_flags.py

@@ -39,7 +39,6 @@ def define_common_bert_flags():
       stop_threshold=False,
       batch_size=False,
       num_gpu=True,
-      hooks=False,
       export_dir=False,
       distribution_strategy=True,
       run_eagerly=True)
@@ -63,6 +62,10 @@ def define_common_bert_flags():
       'inside.')
   flags.DEFINE_float('learning_rate', 5e-5,
                      'The initial learning rate for Adam.')
+  flags.DEFINE_float('end_lr', 0.0,
+                     'The end learning rate for learning rate decay.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
   flags.DEFINE_boolean(
       'scale_loss', False,
       'Whether to divide the loss by number of replica inside the per-replica '
official/nlp/bert/export_tfhub.py

@@ -20,6 +20,7 @@ from __future__ import print_function

 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf
 from typing import Text

 from official.nlp.bert import bert_models
@@ -34,6 +35,9 @@ flags.DEFINE_string("model_checkpoint_path", None,
 flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
+flags.DEFINE_bool("do_lower_case", None, "Whether to lowercase. If None, "
+                  "do_lower_case will be enabled if 'uncased' appears in the "
+                  "name of --vocab_file")


 def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
@@ -65,21 +69,26 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:

 def export_bert_tfhub(bert_config: configs.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
-                      vocab_file: Text):
+                      vocab_file: Text, do_lower_case: bool = None):
   """Restores a tf.keras.Model and saves for TF-Hub."""
+  # If do_lower_case is not explicit, default to checking whether "uncased" is
+  # in the vocab file name
+  if do_lower_case is None:
+    do_lower_case = "uncased" in vocab_file
+    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
+                 do_lower_case, vocab_file)
   core_model, encoder = create_bert_model(bert_config)
   checkpoint = tf.train.Checkpoint(model=encoder)
   checkpoint.restore(model_checkpoint_path).assert_consumed()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(
-      "uncased" in vocab_file, trainable=False)
+      do_lower_case, trainable=False)
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")


 def main(_):
   bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
-                    FLAGS.vocab_file)
+                    FLAGS.vocab_file, FLAGS.do_lower_case)


 if __name__ == "__main__":
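A short sketch of how the new do_lower_case handling behaves (the vocab file names below are made-up examples): passing None keeps the old behaviour of inferring case folding from the vocab file name, while an explicit value overrides it.

  def resolve_do_lower_case(do_lower_case, vocab_file):
    """Hypothetical helper mirroring the defaulting logic in export_bert_tfhub."""
    if do_lower_case is None:
      do_lower_case = "uncased" in vocab_file
    return do_lower_case

  resolve_do_lower_case(None, "uncased_L-12_H-768_A-12/vocab.txt")   # True
  resolve_do_lower_case(None, "cased_L-12_H-768_A-12/vocab.txt")     # False
  resolve_do_lower_case(False, "uncased_L-12_H-768_A-12/vocab.txt")  # False (explicit wins)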
(Showing page 1 of 12 of this commit's diff.)