ModelZoo / ResNet50_tensorflow · Commits · 32e4ca51

Commit 32e4ca51, authored Nov 28, 2023 by qianyj
Update code to v2.11.0
Parents: 9485aa1d, 71060f67
Changes: 772+ files in total (the rendered diff is truncated to preserve performance); showing 20 changed files with 159 additions and 1517 deletions (+159 / -1517).
official/nlp/bert/common_flags.py                          +0   -125
official/nlp/bert/configs.py                               +0   -104
official/nlp/bert/export_tfhub.py                          +0   -139
official/nlp/bert/export_tfhub_test.py                     +0   -108
official/nlp/bert/run_classifier.py                        +0   -515
official/nlp/bert/run_squad.py                             +0   -148
official/nlp/bert/tf1_checkpoint_converter_lib.py          +0   -201
official/nlp/bert/tf2_encoder_checkpoint_converter.py      +0   -160
official/nlp/configs/__init__.py                           +1   -1
official/nlp/configs/bert.py                               +3   -1
official/nlp/configs/electra.py                            +1   -1
official/nlp/configs/encoders.py                           +98  -4
official/nlp/configs/encoders_test.py                      +2   -2
official/nlp/configs/experiment_configs.py                 +1   -2
official/nlp/configs/experiments/wiki_books_pretrain.yaml  +48  -0
official/nlp/configs/finetuning_experiments.py             +1   -1
official/nlp/configs/pretraining_experiments.py            +1   -1
official/nlp/configs/wmt_transformer_experiments.py        +1   -2
official/nlp/continuous_finetune_lib.py                    +1   -1
official/nlp/continuous_finetune_lib_test.py               +1   -1
official/nlp/bert/common_flags.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defining common flags used across all BERT models/applications."""

from absl import flags
import tensorflow as tf

from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core


def define_common_bert_flags():
  """Define common flags for BERT tasks."""
  flags_core.define_base(
      data_dir=False,
      model_dir=True,
      clean=False,
      train_epochs=False,
      epochs_between_evals=False,
      stop_threshold=False,
      batch_size=False,
      num_gpu=True,
      export_dir=False,
      distribution_strategy=True,
      run_eagerly=True)
  flags_core.define_distribution()
  flags.DEFINE_string('bert_config_file', None,
                      'Bert configuration file to define core bert layers.')
  flags.DEFINE_string('model_export_path', None,
                      'Path to the directory, where trainined model will be '
                      'exported.')
  flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
  flags.DEFINE_string(
      'init_checkpoint', None,
      'Initial checkpoint (usually from a pre-trained BERT model).')
  flags.DEFINE_integer('num_train_epochs', 3,
                       'Total number of training epochs to perform.')
  flags.DEFINE_integer(
      'steps_per_loop', None,
      'Number of steps per graph-mode loop. Only training step '
      'happens inside the loop. Callbacks will not be called '
      'inside. If not set the value will be configured depending on the '
      'devices available.')
  flags.DEFINE_float('learning_rate', 5e-5,
                     'The initial learning rate for Adam.')
  flags.DEFINE_float('end_lr', 0.0,
                     'The end learning rate for learning rate decay.')
  flags.DEFINE_string('optimizer_type', 'adamw',
                      'The type of optimizer to use for training (adamw|lamb)')
  flags.DEFINE_boolean(
      'scale_loss', False,
      'Whether to divide the loss by number of replica inside the per-replica '
      'loss function.')
  flags.DEFINE_boolean(
      'use_keras_compile_fit', False,
      'If True, uses Keras compile/fit() API for training logic. Otherwise '
      'use custom training loop.')
  flags.DEFINE_string(
      'hub_module_url', None, 'TF-Hub path/url to Bert module. '
      'If specified, init_checkpoint flag should not be used.')
  flags.DEFINE_bool('hub_module_trainable', True,
                    'True to make keras layers in the hub module trainable.')
  flags.DEFINE_string(
      'sub_model_export_name', None,
      'If set, `sub_model` checkpoints are exported into '
      'FLAGS.model_dir/FLAGS.sub_model_export_name.')
  flags.DEFINE_bool(
      'explicit_allreduce', False,
      'True to use explicit allreduce instead of the implicit '
      'allreduce in optimizer.apply_gradients(). If fp16 mixed '
      'precision training is used, this also enables allreduce '
      'gradients in fp16.')
  flags.DEFINE_integer(
      'allreduce_bytes_per_pack', 0,
      'Number of bytes of a gradient pack for allreduce. '
      'Should be positive integer, if set to 0, all '
      'gradients are in one pack. Breaking gradient into '
      'packs could enable overlap between allreduce and '
      'backprop computation. This flag only takes effect '
      'when explicit_allreduce is set to True.')
  flags_core.define_log_steps()

  # Adds flags for mixed precision and multi-worker training.
  flags_core.define_performance(
      num_parallel_calls=False,
      inter_op=False,
      intra_op=False,
      synthetic_data=False,
      max_train_steps=False,
      dtype=True,
      loss_scale=True,
      all_reduce_alg=True,
      num_packs=False,
      tf_gpu_thread_mode=True,
      datasets_num_private_threads=True,
      enable_xla=True,
      fp16_implementation=True,
  )

  # Adds gin configuration flags.
  hyperparams_flags.define_gin_flags()


def dtype():
  return flags_core.get_tf_dtype(flags.FLAGS)


def use_float16():
  return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16


def get_loss_scale():
  return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
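For orientation, here is a minimal sketch of how a runner script consumed these flag definitions before this commit removed them. The wrapper script and the example flag values are hypothetical; the flag names and helper functions are exactly the ones defined above.

# Hypothetical runner sketch using the (now-removed) common BERT flags.
from absl import app
from absl import flags

from official.nlp.bert import common_flags  # module path as it existed before this commit

common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS


def main(_):
  # Typical invocation (values are illustrative):
  #   python my_runner.py --model_dir=/tmp/bert --bert_config_file=bert_config.json \
  #       --learning_rate=3e-5 --num_train_epochs=3
  print('model_dir:', FLAGS.model_dir)
  print('learning_rate:', FLAGS.learning_rate)
  print('training dtype:', common_flags.dtype())
  print('using float16:', common_flags.use_float16())


if __name__ == '__main__':
  app.run(main)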
official/nlp/bert/configs.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The main BERT model and related functions."""

import copy
import json

import six
import tensorflow as tf


class BertConfig(object):
  """Configuration for `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02,
               embedding_size=None,
               backward_compatible=True):
    """Constructs BertConfig.

    Args:
      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer in
        the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in the
        encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model might
        ever be used with. Typically set this to something large just in case
        (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
      embedding_size: (Optional) width of the factorized word embeddings.
      backward_compatible: Boolean, whether the variables shape are compatible
        with checkpoints converted from TF 1.x BERT.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range
    self.embedding_size = embedding_size
    self.backward_compatible = backward_compatible

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `BertConfig` from a Python dictionary of parameters."""
    config = BertConfig(vocab_size=None)
    for (key, value) in six.iteritems(json_object):
      config.__dict__[key] = value
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `BertConfig` from a json file of parameters."""
    with tf.io.gfile.GFile(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text))

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
official/nlp/bert/export_tfhub.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A script to export BERT as a TF-Hub SavedModel.

This script is **DEPRECATED** for exporting BERT encoder models;
see the error message in by main() for details.
"""

from typing import Text

# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.bert import bert_models
from official.nlp.bert import configs

FLAGS = flags.FLAGS

flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
                    "File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
                    "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", None,
                  "Whether to lowercase. If None, "
                  "do_lower_case will be enabled if 'uncased' appears in the "
                  "name of --vocab_file")
flags.DEFINE_enum("model_type", "encoder", ["encoder", "squad"],
                  "What kind of BERT model to export.")


def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
  """Creates a BERT keras core model from BERT configuration.

  Args:
    bert_config: A `BertConfig` to create the core model.

  Returns:
    A keras model.
  """
  # Adds input layers just as placeholders.
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_word_ids")
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_mask")
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_type_ids")
  transformer_encoder = bert_models.get_transformer_encoder(
      bert_config, sequence_length=None)
  sequence_output, pooled_output = transformer_encoder(
      [input_word_ids, input_mask, input_type_ids])
  # To keep consistent with legacy hub modules, the outputs are
  # "pooled_output" and "sequence_output".
  return tf.keras.Model(
      inputs=[input_word_ids, input_mask, input_type_ids],
      outputs=[pooled_output, sequence_output]), transformer_encoder


def export_bert_tfhub(bert_config: configs.BertConfig,
                      model_checkpoint_path: Text,
                      hub_destination: Text,
                      vocab_file: Text,
                      do_lower_case: bool = None):
  """Restores a tf.keras.Model and saves for TF-Hub."""
  # If do_lower_case is not explicit, default to checking whether "uncased" is
  # in the vocab file name
  if do_lower_case is None:
    do_lower_case = "uncased" in vocab_file
    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
                 do_lower_case, vocab_file)
  core_model, encoder = create_bert_model(bert_config)
  checkpoint = tf.train.Checkpoint(
      model=encoder,  # Legacy checkpoints.
      encoder=encoder)
  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
  core_model.vocab_file = tf.saved_model.Asset(vocab_file)
  core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  core_model.save(hub_destination, include_optimizer=False, save_format="tf")


def export_bert_squad_tfhub(bert_config: configs.BertConfig,
                            model_checkpoint_path: Text,
                            hub_destination: Text,
                            vocab_file: Text,
                            do_lower_case: bool = None):
  """Restores a tf.keras.Model for BERT with SQuAD and saves for TF-Hub."""
  # If do_lower_case is not explicit, default to checking whether "uncased" is
  # in the vocab file name
  if do_lower_case is None:
    do_lower_case = "uncased" in vocab_file
    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
                 do_lower_case, vocab_file)
  span_labeling, _ = bert_models.squad_model(bert_config, max_seq_length=None)
  checkpoint = tf.train.Checkpoint(model=span_labeling)
  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
  span_labeling.vocab_file = tf.saved_model.Asset(vocab_file)
  span_labeling.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  span_labeling.save(hub_destination, include_optimizer=False, save_format="tf")


def main(_):
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  if FLAGS.model_type == "encoder":
    deprecation_note = (
        "nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
        "models. Please switch to nlp/tools/export_tfhub for exporting BERT "
        "(and other) encoders with dict inputs/outputs conforming to "
        "https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
    )
    logging.error(deprecation_note)
    print("\n\nNOTICE:", deprecation_note, "\n")
    export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
                      FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
  elif FLAGS.model_type == "squad":
    export_bert_squad_tfhub(bert_config, FLAGS.model_checkpoint_path,
                            FLAGS.export_path, FLAGS.vocab_file,
                            FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported model_type %s." % FLAGS.model_type)


if __name__ == "__main__":
  app.run(main)
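Both the deleted export script and the deleted configs.py above revolve around BertConfig.from_json_file. For reference, a minimal sketch of that round trip, using the pre-deletion module path and a hypothetical config file location:

# Hypothetical: loading and re-serializing a BertConfig (API as deleted above).
from official.nlp.bert import configs  # pre-deletion module path

bert_config = configs.BertConfig.from_json_file('/path/to/bert_config.json')  # assumed path
print(bert_config.hidden_size, bert_config.num_hidden_layers)

# to_json_string() mirrors from_dict()/from_json_file(): dump, tweak, reload.
json_text = bert_config.to_json_string()
smaller = configs.BertConfig.from_dict({**bert_config.to_dict(), 'num_hidden_layers': 6})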
official/nlp/bert/export_tfhub_test.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests official.nlp.bert.export_tfhub."""

import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

from official.nlp.bert import configs
from official.nlp.bert import export_tfhub


class ExportTfhubTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters("model", "encoder")
  def test_export_tfhub(self, ckpt_key_name):
    # Exports a savedmodel for TF-Hub
    hidden_size = 16
    bert_config = configs.BertConfig(
        vocab_size=100,
        hidden_size=hidden_size,
        intermediate_size=32,
        max_position_embeddings=128,
        num_attention_heads=2,
        num_hidden_layers=1)
    bert_model, encoder = export_tfhub.create_bert_model(bert_config)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(**{ckpt_key_name: encoder})
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
    with tf.io.gfile.GFile(vocab_file, "w") as f:
      f.write("dummy content")

    hub_destination = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
                                   hub_destination, vocab_file)

    # Restores a hub KerasLayer.
    hub_layer = hub.KerasLayer(hub_destination, trainable=True)

    if hasattr(hub_layer, "resolved_object"):
      # Checks meta attributes.
      self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
      with tf.io.gfile.GFile(
          hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
        self.assertEqual("dummy content", f.read())
    # Checks the hub KerasLayer.
    for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                         hub_layer.trainable_weights):
      self.assertAllClose(source_weight.numpy(), hub_weight.numpy())

    seq_length = 10
    dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
    hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
    source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])

    # The outputs of hub module are "pooled_output" and "sequence_output",
    # while the outputs of encoder is in reversed order, i.e.,
    # "sequence_output" and "pooled_output".
    encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
    self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
    self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
    for source_output, hub_output, encoder_output in zip(
        source_outputs, hub_outputs, encoder_outputs):
      self.assertAllClose(source_output.numpy(), hub_output.numpy())
      self.assertAllClose(source_output.numpy(), encoder_output.numpy())

    # Test that training=True makes a difference (activates dropout).
    def _dropout_mean_stddev(training, num_runs=20):
      input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
      inputs = [input_ids, np.ones_like(input_ids), np.zeros_like(input_ids)]
      outputs = np.concatenate(
          [hub_layer(inputs, training=training)[0] for _ in range(num_runs)])
      return np.mean(np.std(outputs, axis=0))

    self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
    self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)

    # Test propagation of seq_length in shape inference.
    input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    pooled_output, sequence_output = hub_layer(
        [input_word_ids, input_mask, input_type_ids])
    self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
    self.assertEqual(sequence_output.shape.as_list(),
                     [None, seq_length, hidden_size])


if __name__ == "__main__":
  tf.test.main()
official/nlp/bert/run_classifier.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT classification or regression finetuning runner in TF 2.x."""

import functools
import json
import math
import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.utils.misc import keras_utils

flags.DEFINE_enum(
    'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
    'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
    'trains the model and evaluates in the meantime. '
    '`export_only`: will take the latest checkpoint inside '
    'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
    'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
                    'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
                    'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to file that contains meta data about input '
    'to be used for training and evaluation.')
flags.DEFINE_integer('train_data_size', None, 'Number of training samples '
                     'to use. If None, uses the full train data. '
                     '(default: None).')
flags.DEFINE_string('predict_checkpoint_path', None,
                    'Path to the checkpoint for predictions.')
flags.DEFINE_integer(
    'num_eval_per_epoch', 1,
    'Number of evaluations per epoch. The purpose of this flag is to provide '
    'more granular evaluation scores and checkpoints. For example, if original '
    'data has N samples and num_eval_per_epoch is n, then each epoch will be '
    'evaluated every N/n samples.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS

LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}


def get_loss_fn(num_classes):
  """Gets the classification loss function."""

  def classification_loss_fn(labels, logits):
    """Classification loss."""
    labels = tf.reshape(labels, [-1])
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(
        tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)

  return classification_loss_fn


def get_dataset_fn(input_file_pattern,
                   max_seq_length,
                   global_batch_size,
                   is_training,
                   label_type=tf.int64,
                   include_sample_weights=False,
                   num_samples=None):
  """Gets a closure to create a dataset."""

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_classifier_dataset(
        tf.io.gfile.glob(input_file_pattern),
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx,
        label_type=label_type,
        include_sample_weights=include_sample_weights,
        num_samples=num_samples)
    return dataset

  return _dataset_fn


def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        training_callbacks=True,
                        custom_callbacks=None,
                        custom_metrics=None):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data.get('num_labels', 1)
  is_regression = num_classes == 1

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer, use_float16=common_flags.use_float16())
    return classifier_model, core_model

  # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
  # from the dataset) to compute weighted loss, as used for the regression
  # tasks. The classification tasks, using the custom get_loss_fn don't accept
  # sample weights though.
  loss_fn = (tf.keras.losses.MeanSquaredError() if is_regression
             else get_loss_fn(num_classes))

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  if custom_metrics:
    metric_fn = custom_metrics
  elif is_regression:
    metric_fn = functools.partial(
        tf.keras.metrics.MeanSquaredError,
        'mean_squared_error',
        dtype=tf.float32)
  else:
    metric_fn = functools.partial(
        tf.keras.metrics.SparseCategoricalAccuracy,
        'accuracy',
        dtype=tf.float32)

  # Start training using Keras compile/fit API.
  logging.info('Training using TF 2.x Keras compile/fit API with '
               'distribution strategy.')
  return run_keras_compile_fit(
      model_dir,
      strategy,
      _get_classifier_model,
      train_input_fn,
      eval_input_fn,
      loss_fn,
      metric_fn,
      init_checkpoint,
      epochs,
      steps_per_epoch,
      steps_per_loop,
      eval_steps,
      training_callbacks=training_callbacks,
      custom_callbacks=custom_callbacks)


def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API."""

  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
      checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
      checkpoint.read(init_checkpoint).assert_existing_objects_matched()

    if not isinstance(metric_fn, (list, tuple)):
      metric_fn = [metric_fn]
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[fn() for fn in metric_fn],
        steps_per_execution=steps_per_loop)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    if training_callbacks:
      if custom_callbacks is not None:
        custom_callbacks += [summary_callback, checkpoint_callback]
      else:
        custom_callbacks = [summary_callback, checkpoint_callback]

    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=custom_callbacks)
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    if 'val_accuracy' in history.history:
      stats['eval_metrics'] = history.history['val_accuracy'][-1]
    return bert_model, stats


def get_predictions_and_labels(strategy,
                               trained_model,
                               eval_input_fn,
                               is_regression=False,
                               return_probs=False):
  """Obtains predictions of trained model on evaluation data.

  Note that list of labels is returned along with the predictions because the
  order changes on distributing dataset over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data.
    is_regression: Whether it is a regression task.
    return_probs: Whether to return probabilities of classes.

  Returns:
    predictions: List of predictions.
    labels: List of gold labels corresponding to predictions.
  """

  @tf.function
  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      logits = trained_model(inputs, training=False)
      if not is_regression:
        probabilities = tf.nn.softmax(logits)
        return probabilities, labels
      else:
        return logits, labels

    outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
    # outputs: current batch logits as a tuple of shard logits
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  def _run_evaluation(test_iterator):
    """Runs evaluation steps."""
    preds, golds = list(), list()
    try:
      with tf.experimental.async_scope():
        while True:
          probabilities, labels = test_step(test_iterator)
          for cur_probs, cur_labels in zip(probabilities, labels):
            if return_probs:
              preds.extend(cur_probs.numpy().tolist())
            else:
              preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
            golds.extend(cur_labels.numpy().tolist())
    except (StopIteration, tf.errors.OutOfRangeError):
      tf.experimental.async_clear_error()
    return preds, golds

  test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)

  return predictions, labels


def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.
    bert_config: Bert configuration file to define core bert layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.

  Raises:
    Export path is not specified, got an empty string or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    raise ValueError('Export path is not specified: %s' % model_dir)

  # Export uses float32 for now, even if training uses mixed precision.
  tf.keras.mixed_precision.set_global_policy('float32')
  classifier_model = bert_models.classifier_model(
      bert_config,
      input_meta_data.get('num_labels', 1),
      hub_module_url=FLAGS.hub_module_url,
      hub_module_trainable=False)[0]

  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)


def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None,
             custom_metrics=None):
  """Run BERT training."""
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
  train_data_size = (
      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
  if FLAGS.train_data_size:
    train_data_size = min(train_data_size, FLAGS.train_data_size)
    logging.info('Updated train_data_size: %s', train_data_size)
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  if not custom_callbacks:
    custom_callbacks = []

  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model


def custom_main(custom_callbacks=None, custom_metrics=None):
  """Run classification or regression.

  Args:
    custom_callbacks: list of tf.keras.Callbacks passed to training loop.
    custom_metrics: list of metrics passed to the training loop.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
  include_sample_weights = input_meta_data.get('has_sample_weights', False)

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
                      FLAGS.model_dir)
    return

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
      is_training=False,
      label_type=label_type,
      include_sample_weights=include_sample_weights)

  if FLAGS.mode == 'predict':
    num_labels = input_meta_data.get('num_labels', 1)
    with strategy.scope():
      classifier_model = bert_models.classifier_model(
          bert_config, num_labels)[0]
      checkpoint = tf.train.Checkpoint(model=classifier_model)
      latest_checkpoint_file = (
          FLAGS.predict_checkpoint_path or
          tf.train.latest_checkpoint(FLAGS.model_dir))
      assert latest_checkpoint_file
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()
      preds, _ = get_predictions_and_labels(
          strategy,
          classifier_model,
          eval_input_fn,
          is_regression=(num_labels == 1),
          return_probs=True)
    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
      logging.info('***** Predict results *****')
      for probabilities in preds:
        output_line = '\t'.join(
            str(class_probability)
            for class_probability in probabilities) + '\n'
        writer.write(output_line)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)

  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.train_batch_size,
      is_training=True,
      label_type=label_type,
      include_sample_weights=include_sample_weights,
      num_samples=FLAGS.train_data_size)
  run_bert(
      strategy,
      input_meta_data,
      bert_config,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)


def main(_):
  custom_main(custom_callbacks=None, custom_metrics=None)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)
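run_classifier deliberately exposes custom_main() so downstream runners can inject their own callbacks and metric factories. A hedged sketch of such a wrapper follows; the early-stopping choice is illustrative and not part of the original file, while the custom_main() signature is exactly the one deleted above.

# Hypothetical wrapper reusing run_classifier.custom_main (pre-deletion API).
from absl import app
from absl import flags
import tensorflow as tf

from official.nlp.bert import run_classifier


def main(_):
  callbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)]
  # custom_metrics is a list of zero-arg factories; they are instantiated
  # inside run_keras_compile_fit via `[fn() for fn in metric_fn]`.
  metrics = [
      lambda: tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)
  ]
  run_classifier.custom_main(custom_callbacks=callbacks, custom_metrics=metrics)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)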
official/nlp/bert/run_squad.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""

import json
import os
import time

# Import libraries
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.common import distribute_utils
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import keras_utils

flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')

# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()

FLAGS = flags.FLAGS


def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False,
                init_checkpoint=None,
                sub_model_export_name=None):
  """Run bert squad training."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
  run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
                               custom_callbacks, run_eagerly, init_checkpoint,
                               sub_model_export_name=sub_model_export_name)


def predict_squad(strategy, input_meta_data):
  """Makes predictions for the squad dataset."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
                                 bert_config, squad_lib_wp)


def eval_squad(strategy, input_meta_data):
  """Evaluate on the squad dataset."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  eval_metrics = run_squad_helper.eval_squad(strategy, input_meta_data,
                                             tokenizer, bert_config,
                                             squad_lib_wp)
  return eval_metrics


def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.

  Raises:
    Export path is not specified, got an empty string or None.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
                                           FLAGS.task_index)
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  if 'train' in FLAGS.mode:
    if FLAGS.log_steps:
      custom_callbacks = [
          keras_utils.TimeHistory(
              batch_size=FLAGS.train_batch_size,
              log_steps=FLAGS.log_steps,
              logdir=FLAGS.model_dir,
          )
      ]
    else:
      custom_callbacks = None

    train_squad(
        strategy,
        input_meta_data,
        custom_callbacks=custom_callbacks,
        run_eagerly=FLAGS.run_eagerly,
        sub_model_export_name=FLAGS.sub_model_export_name,
    )
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
  if 'eval' in FLAGS.mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)
    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    with summary_writer.as_default():
      # TODO(lehou): write to the correct step number.
      tf.summary.scalar('F1-score', f1_score, step=0)
      summary_writer.flush()
    # Also write eval_metrics to json file.
    squad_lib_wp.write_to_json_files(
        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
  time.sleep(60)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('model_dir')
  app.run(main)
official/nlp/bert/tf1_checkpoint_converter_lib.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""

import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x

# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
BERT_NAME_REPLACEMENTS = (
    ("bert", "bert_model"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings",
     "embedding_postprocessor/type_embeddings"),
    ("embeddings/position_embeddings",
     "embedding_postprocessor/position_embeddings"),
    ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
)

BERT_V2_NAME_REPLACEMENTS = (
    ("bert/", ""),
    ("encoder", "transformer"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
    ("embeddings/position_embeddings", "position_embedding/embeddings"),
    ("embeddings/LayerNorm", "embeddings/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention/attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
    ("cls/predictions", "bert/cls/predictions"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
    ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
    ("cls/seq_relationship/output_weights",
     "predictions/transform/logits/kernel"),
)

BERT_PERMUTATIONS = ()

BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)


def _bert_name_replacement(var_name, name_replacements):
  """Gets the variable name replacement."""
  for src_pattern, tgt_pattern in name_replacements:
    if src_pattern in var_name:
      old_var_name = var_name
      var_name = var_name.replace(src_pattern, tgt_pattern)
      tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
  return var_name


def _has_exclude_patterns(name, exclude_patterns):
  """Checks if a string contains substrings that match patterns to exclude."""
  for p in exclude_patterns:
    if p in name:
      return True
  return False


def _get_permutation(name, permutations):
  """Checks whether a variable requires transposition by pattern matching."""
  for src_pattern, permutation in permutations:
    if src_pattern in name:
      tf.logging.info("Permuted: %s --> %s", name, permutation)
      return permutation

  return None


def _get_new_shape(name, shape, num_heads):
  """Checks whether a variable requires reshape by pattern matching."""
  if "self_attention/attention_output/kernel" in name:
    return tuple([num_heads, shape[0] // num_heads, shape[1]])
  if "self_attention/attention_output/bias" in name:
    return shape

  patterns = [
      "self_attention/query", "self_attention/value", "self_attention/key"
  ]
  for pattern in patterns:
    if pattern in name:
      if "kernel" in name:
        return tuple([shape[0], num_heads, shape[1] // num_heads])
      if "bias" in name:
        return tuple([num_heads, shape[0] // num_heads])
  return None


def create_v2_checkpoint(model, src_checkpoint, output_path,
                         checkpoint_model_name="model"):
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
  # Uses streaming-restore in eager model to read V1 name-based checkpoints.
  model.load_weights(src_checkpoint).assert_existing_objects_matched()
  if hasattr(model, "checkpoint_items"):
    checkpoint_items = model.checkpoint_items
  else:
    checkpoint_items = {}
  checkpoint_items[checkpoint_model_name] = model
  checkpoint = tf.train.Checkpoint(**checkpoint_items)
  checkpoint.save(output_path)


def convert(checkpoint_from_path,
            checkpoint_to_path,
            num_heads,
            name_replacements,
            permutations,
            exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    num_heads: The number of heads of the model.
    name_replacements: A list of tuples of the form (match_str, replace_str)
      describing variable names to adjust.
    permutations: A list of tuples of the form (match_str, permutation)
      describing permutations to apply to given variables. Note that match_str
      should match the original variable name, not the replaced one.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.

  Returns:
    A dictionary that maps the new variable names to the Variable objects.
    A dictionary that maps the old variable names to the new variable names.
  """
  with tf.Graph().as_default():
    tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    new_variable_map = {}
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
        continue
      # Get the original tensor data.
      tensor = reader.get_tensor(var_name)

      # Look up the new variable name, if any.
      new_var_name = _bert_name_replacement(var_name, name_replacements)

      # See if we need to reshape the underlying tensor.
      new_shape = None
      if num_heads > 0:
        new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
      if new_shape:
        tf.logging.info("Veriable %s has a shape change from %s to %s",
                        var_name, tensor.shape, new_shape)
        tensor = np.reshape(tensor, new_shape)

      # See if we need to permute the underlying tensor.
      permutation = _get_permutation(var_name, permutations)
      if permutation:
        tensor = np.transpose(tensor, permutation)

      # Create a new variable with the possibly-reshaped or transposed tensor.
      var = tf.Variable(tensor, name=var_name)

      # Save the variable into the new variable map.
      new_variable_map[new_var_name] = var

      # Keep a list of converter variables for sanity checking.
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    saver = tf.train.Saver(new_variable_map)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
      saver.save(sess, checkpoint_to_path, write_meta_graph=False)

  tf.logging.info("Summary:")
  tf.logging.info("  Converted %d variable name(s).", len(new_variable_map))
  tf.logging.info("  Converted: %s", str(conversion_map))
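For reference, a minimal sketch of driving this library directly. The checkpoint paths and the 12-head setting are hypothetical; the constants and the convert() signature are exactly those defined above.

# Hypothetical invocation of the TF1 -> V2-name checkpoint migration above.
from official.nlp.bert import tf1_checkpoint_converter_lib as converter_lib

converter_lib.convert(
    checkpoint_from_path='/tmp/tf1_bert/bert_model.ckpt',    # assumed source
    checkpoint_to_path='/tmp/bert_renamed/ckpt',             # assumed destination
    num_heads=12,                                            # must match the bert_config
    name_replacements=converter_lib.BERT_V2_NAME_REPLACEMENTS,
    permutations=converter_lib.BERT_V2_PERMUTATIONS,
    exclude_patterns=['adam', 'Adam'])                       # skip optimizer slot variables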
official/nlp/bert/tf2_encoder_checkpoint_converter.py — deleted (100644 → 0), view file @ 9485aa1d

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.

The conversion will yield an object-oriented checkpoint that can be used
to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
FLAG below).
"""

import os

from absl import app
from absl import flags

import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import models
from official.nlp.modeling import networks

FLAGS = flags.FLAGS

flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
flags.DEFINE_string(
    "checkpoint_to_convert", None,
    "Initial checkpoint from a pretrained BERT model core (that is, only the "
    "BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
                    "Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "encoder",
                    "The name of the model when saving the checkpoint, i.e., "
                    "the checkpoint will be saved using: "
                    "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
    "converted_model", "encoder", ["encoder", "pretrainer"],
    "Whether to convert the checkpoint to a `BertEncoder` model or a "
    "`BertPretrainerV2` model (with mlm but without classification heads).")


def _create_bert_model(cfg):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    cfg: A `BertConfig` to create the core model.

  Returns:
    A BertEncoder network.
  """
  bert_encoder = networks.BertEncoder(
      vocab_size=cfg.vocab_size,
      hidden_size=cfg.hidden_size,
      num_layers=cfg.num_hidden_layers,
      num_attention_heads=cfg.num_attention_heads,
      intermediate_size=cfg.intermediate_size,
      activation=tf_utils.get_activation(cfg.hidden_act),
      dropout_rate=cfg.hidden_dropout_prob,
      attention_dropout_rate=cfg.attention_probs_dropout_prob,
      max_sequence_length=cfg.max_position_embeddings,
      type_vocab_size=cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range),
      embedding_width=cfg.embedding_size)

  return bert_encoder


def _create_bert_pretrainer_model(cfg):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    cfg: A `BertConfig` to create the core model.

  Returns:
    A BertPretrainerV2 model.
  """
  bert_encoder = _create_bert_model(cfg)
  pretrainer = models.BertPretrainerV2(
      encoder_network=bert_encoder,
      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range))
  # Makes sure the pretrainer variables are created.
  _ = pretrainer(pretrainer.inputs)
  return pretrainer


def convert_checkpoint(bert_config,
                       output_path,
                       v1_checkpoint,
                       checkpoint_model_name="model",
                       converted_model="encoder"):
  """Converts a V1 checkpoint into an OO V2 checkpoint."""
  output_dir, _ = os.path.split(output_path)
  tf.io.gfile.makedirs(output_dir)

  # Create a temporary V1 name-converted checkpoint in the output directory.
  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")

  tf1_checkpoint_converter_lib.convert(
      checkpoint_from_path=v1_checkpoint,
      checkpoint_to_path=temporary_checkpoint,
      num_heads=bert_config.num_attention_heads,
      name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
      permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
      exclude_patterns=["adam", "Adam"])

  if converted_model == "encoder":
    model = _create_bert_model(bert_config)
  elif converted_model == "pretrainer":
    model = _create_bert_pretrainer_model(bert_config)
  else:
    raise ValueError("Unsupported converted_model: %s" % converted_model)

  # Create a V2 checkpoint from the temporary checkpoint.
  tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
                                                    output_path,
                                                    checkpoint_model_name)

  # Clean up the temporary checkpoint, if it exists.
  try:
    tf.io.gfile.rmtree(temporary_checkpoint_dir)
  except tf.errors.OpError:
    # If it doesn't exist, we don't need to clean it up; continue.
    pass


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  output_path = FLAGS.converted_checkpoint_path
  v1_checkpoint = FLAGS.checkpoint_to_convert
  checkpoint_model_name = FLAGS.checkpoint_model_name
  converted_model = FLAGS.converted_model
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  convert_checkpoint(
      bert_config=bert_config,
      output_path=output_path,
      v1_checkpoint=v1_checkpoint,
      checkpoint_model_name=checkpoint_model_name,
      converted_model=converted_model)


if __name__ == "__main__":
  app.run(main)
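Once converted, the object-based checkpoint is restored under the key given by --checkpoint_model_name (default "encoder"). A hedged sketch of that consumption step, using the pre-deletion module paths and hypothetical file locations:

# Hypothetical: restoring the converted V2 checkpoint into a BertEncoder.
import tensorflow as tf
from official.nlp.bert import configs  # note: deleted by this same commit
from official.nlp.bert import tf2_encoder_checkpoint_converter as converter  # also deleted here

cfg = configs.BertConfig.from_json_file('/tmp/bert_config.json')   # assumed path
encoder = converter._create_bert_model(cfg)                        # helper defined above

# The checkpoint key must match the --checkpoint_model_name used during conversion.
checkpoint = tf.train.Checkpoint(encoder=encoder)
checkpoint.restore('/tmp/converted/bert_v2-1').assert_existing_objects_matched()  # assumed prefix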
official/nlp/configs/__init__.py — view file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/nlp/configs/bert.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -41,3 +41,5 @@ class PretrainerConfig(base_config.Config):
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
  mlm_activation: str = "gelu"
  mlm_initializer_range: float = 0.02
  # Currently only used for mobile bert.
  mlm_output_weights_use_proj: bool = False
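A minimal, hypothetical sketch (illustrative only, not from this commit) of setting the newly added MLM fields alongside a classification head; unset fields keep their defaults, and the ClsHeadConfig values shown are illustrative:

# Hypothetical sketch: pretrainer config with a next-sentence head and the new MLM fields.
pretrainer_cfg = PretrainerConfig(
    cls_heads=[ClsHeadConfig(inner_dim=768, num_classes=2,
                             dropout_rate=0.1, name="next_sentence")],
    mlm_activation="gelu",
    mlm_output_weights_use_proj=False)  # projection only used for MobileBERT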
official/nlp/configs/electra.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/configs/encoders.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -16,9 +16,9 @@
Includes configurations and factory methods.
"""
from typing import Optional
import dataclasses
from typing import Optional, Sequence

import gin
import tensorflow as tf
...
...
@@ -26,7 +26,7 @@ from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.projects.bigbird import encoder as bigbird_encoder
from official.projects.bigbird import encoder as bigbird_encoder


@dataclasses.dataclass
...
...
@@ -221,6 +221,50 @@ class XLNetEncoderConfig(hyperparams.Config):
  two_stream: bool = False


@dataclasses.dataclass
class QueryBertConfig(hyperparams.Config):
  """Query BERT encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False


@dataclasses.dataclass
class FNetEncoderConfig(hyperparams.Config):
  """FNet encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  inner_activation: str = "gelu"
  inner_dim: int = 3072
  output_dropout: float = 0.1
  attention_dropout: float = 0.1
  max_sequence_length: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_width: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False
  use_fft: bool = False
  attention_layers: Sequence[int] = ()


@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration."""
...
...
@@ -233,6 +277,8 @@ class EncoderConfig(hyperparams.OneOfConfig):
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
  reuse: ReuseEncoderConfig = ReuseEncoderConfig()
  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
  query_bert: QueryBertConfig = QueryBertConfig()
  fnet: FNetEncoderConfig = FNetEncoderConfig()
  # If `any` is used, the encoder building relies on any.BUILDER.
  any: hyperparams.Config = hyperparams.Config()
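A minimal, hypothetical sketch (illustrative only, not from this commit) of selecting one of the newly registered encoders through the one-of mechanism; the `type` field picks which sub-config is active, and the overridden values are illustrative:

# Hypothetical sketch: select the FNet encoder and override a couple of fields.
encoder_config = EncoderConfig(
    type="fnet",
    fnet=FNetEncoderConfig(num_layers=6, use_fft=True))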
...
...
@@ -513,6 +559,54 @@ def build_encoder(config: EncoderConfig,
        recursive=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "query_bert":
    embedding_layer = layers.FactorizedEmbedding(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_size,
        output_dim=encoder_cfg.hidden_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        name="word_embeddings")
    return networks.BertEncoderV2(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_layer=embedding_layer,
        return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True,
        norm_first=encoder_cfg.norm_first)

  if encoder_type == "fnet":
    return networks.FNet(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        inner_dim=encoder_cfg.inner_dim,
        inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
        output_dropout=encoder_cfg.output_dropout,
        attention_dropout=encoder_cfg.attention_dropout,
        max_sequence_length=encoder_cfg.max_sequence_length,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_width=encoder_cfg.embedding_width,
        embedding_layer=embedding_layer,
        norm_first=encoder_cfg.norm_first,
        use_fft=encoder_cfg.use_fft,
        attention_layers=encoder_cfg.attention_layers)

  bert_encoder_cls = networks.BertEncoder
  if encoder_type == "bert_v2":
    bert_encoder_cls = networks.BertEncoderV2
...
...
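A minimal, hypothetical sketch (illustrative only, not from this commit) of exercising the new query_bert branch end to end; the config values are illustrative, and embedding_size is set so the FactorizedEmbedding created in that branch has a concrete width:

# Hypothetical sketch: build a query_bert encoder from an EncoderConfig.
cfg = EncoderConfig(
    type="query_bert",
    query_bert=QueryBertConfig(hidden_size=768, embedding_size=128))
encoder = build_encoder(cfg)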
official/nlp/configs/encoders_test.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -20,7 +20,7 @@ import tensorflow as tf
from official.modeling import hyperparams
from official.nlp.configs import encoders
from official.nlp.modeling import networks
from official.nlp.projects.teams import teams
from official.projects.teams import teams


class EncodersTest(tf.test.TestCase):
...
...
official/nlp/configs/experiment_configs.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -17,4 +17,3 @@
from official.nlp.configs import finetuning_experiments
from official.nlp.configs import pretraining_experiments
from official.nlp.configs import wmt_transformer_experiments
from official.nlp.projects.teams import teams_experiments
official/nlp/configs/experiments/wiki_books_pretrain.yaml
0 → 100644
View file @ 32e4ca51
task:
  init_checkpoint: ''
  model:
    cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
  train_data:
    drop_remainder: true
    global_batch_size: 512
    input_path: '[Your processed wiki data path]*,[Your processed books data path]*'
    is_training: true
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
    use_v2_feature_names: true
  validation_data:
    drop_remainder: false
    global_batch_size: 512
    input_path: '[Your processed wiki data path]-00000-of-00500,[Your processed books data path]-00000-of-00500'
    is_training: false
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
    use_v2_feature_names: true
trainer:
  checkpoint_interval: 20000
  max_to_keep: 5
  optimizer_config:
    learning_rate:
      polynomial:
        cycle: false
        decay_steps: 1000000
        end_learning_rate: 0.0
        initial_learning_rate: 0.0001
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 10000
      type: polynomial
  steps_per_loop: 1000
  summary_interval: 1000
  train_steps: 1000000
  validation_interval: 1000
  validation_steps: 64
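For orientation: max_predictions_per_seq of 76 roughly matches the usual 15% masking rate at seq_length 512 (0.15 × 512 ≈ 76), and the trainer block describes a linear warmup to 1e-4 over 10k steps followed by a power-1 (linear) polynomial decay to 0 at 1M steps. A minimal, hypothetical Python sketch of that schedule (the exact composition of warmup and decay is handled by the optimization library; this is only an approximation):

def approx_learning_rate(step,
                         peak_lr=1e-4,
                         warmup_steps=10_000,
                         decay_steps=1_000_000,
                         end_lr=0.0,
                         power=1.0):
  """Approximates the warmup + polynomial-decay schedule in the YAML above."""
  if step < warmup_steps:
    # Linear warmup from 0 to the peak rate.
    return peak_lr * step / warmup_steps
  # Polynomial (here linear, power=1.0) decay from the peak rate to end_lr.
  progress = min(step, decay_steps) / decay_steps
  return end_lr + (peak_lr - end_lr) * (1.0 - progress) ** power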
official/nlp/configs/finetuning_experiments.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/configs/pretraining_experiments.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/configs/wmt_transformer_experiments.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations."""
...
...
official/nlp/continuous_finetune_lib.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/continuous_finetune_lib_test.py
View file @ 32e4ca51
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...