Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defining common flags used across all BERT models/applications."""
from absl import flags
import tensorflow as tf
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
def define_common_bert_flags():
"""Define common flags for BERT tasks."""
flags_core.define_base(
data_dir=False,
model_dir=True,
clean=False,
train_epochs=False,
epochs_between_evals=False,
stop_threshold=False,
batch_size=False,
num_gpu=True,
export_dir=False,
distribution_strategy=True,
run_eagerly=True)
flags_core.define_distribution()
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_export_path', None,
'Path to the directory where the trained model will be '
'exported.')
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer(
'steps_per_loop', None,
'Number of steps per graph-mode loop. Only the training step '
'happens inside the loop. Callbacks will not be called '
'inside. If not set, the value will be configured depending on the '
'devices available.')
flags.DEFINE_float('learning_rate', 5e-5,
'The initial learning rate for Adam.')
flags.DEFINE_float('end_lr', 0.0,
'The end learning rate for learning rate decay.')
flags.DEFINE_string('optimizer_type', 'adamw',
'The type of optimizer to use for training (adamw|lamb)')
flags.DEFINE_boolean(
'scale_loss', False,
'Whether to divide the loss by the number of replicas inside the per-replica '
'loss function.')
flags.DEFINE_boolean(
'use_keras_compile_fit', False,
'If True, uses Keras compile/fit() API for training logic. Otherwise '
'use custom training loop.')
flags.DEFINE_string(
'hub_module_url', None, 'TF-Hub path/url to Bert module. '
'If specified, init_checkpoint flag should not be used.')
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
flags.DEFINE_string(
'sub_model_export_name', None,
'If set, `sub_model` checkpoints are exported into '
'FLAGS.model_dir/FLAGS.sub_model_export_name.')
flags.DEFINE_bool('explicit_allreduce', False,
'True to use explicit allreduce instead of the implicit '
'allreduce in optimizer.apply_gradients(). If fp16 mixed '
'precision training is used, this also enables allreduce '
'gradients in fp16.')
flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
'Number of bytes of a gradient pack for allreduce. '
'Should be a positive integer; if set to 0, all '
'gradients are in one pack. Breaking gradients into '
'packs could enable overlap between allreduce and '
'backprop computation. This flag only takes effect '
'when explicit_allreduce is set to True.')
flags_core.define_log_steps()
# Adds flags for mixed precision and multi-worker training.
flags_core.define_performance(
num_parallel_calls=False,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=True,
loss_scale=True,
all_reduce_alg=True,
num_packs=False,
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
enable_xla=True,
fp16_implementation=True,
)
# Adds gin configuration flags.
hyperparams_flags.define_gin_flags()
def dtype():
return flags_core.get_tf_dtype(flags.FLAGS)
def use_float16():
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
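# --- Editor's note: illustrative sketch; not part of the original file. ---
# A minimal example of how a driver script could consume the flags defined
# above: define them once, then read the dtype/loss-scale helpers after absl
# has parsed the command line. The `_demo_main` name is hypothetical.
from absl import app  # example-only import


def _demo_main(_):
  # After flag parsing, the helpers above resolve the precision settings.
  print('model_dir:', flags.FLAGS.model_dir)
  print('dtype:', dtype())
  print('use_float16:', use_float16())
  print('loss_scale:', get_loss_scale())


if __name__ == '__main__':
  define_common_bert_flags()
  flags.mark_flag_as_required('model_dir')
  app.run(_demo_main)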
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The main BERT model and related functions."""
import copy
import json
import six
import tensorflow as tf
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
embedding_size=None,
backward_compatible=True):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
embedding_size: (Optional) width of the factorized word embeddings.
backward_compatible: Boolean, whether the variables shape are compatible
with checkpoints converted from TF 1.x BERT.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.embedding_size = embedding_size
self.backward_compatible = backward_compatible
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.io.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
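# --- Editor's note: illustrative sketch; not part of the original file. ---
# Round-tripping a config through the helpers above. vocab_size=30522 matches
# the standard uncased BERT-Base vocabulary; every other value keeps the
# BERT-Base defaults declared in __init__.
if __name__ == "__main__":
  base_config = BertConfig(vocab_size=30522)
  restored = BertConfig.from_dict(base_config.to_dict())
  assert restored.hidden_size == 768 and restored.num_hidden_layers == 12
  print(base_config.to_json_string())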
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export BERT as a TF-Hub SavedModel.
This script is **DEPRECATED** for exporting BERT encoder models;
see the error message printed by main() for details.
"""
from typing import Text
# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.bert import bert_models
from official.nlp.bert import configs
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", None, "Whether to lowercase. If None, "
"do_lower_case will be enabled if 'uncased' appears in the "
"name of --vocab_file")
flags.DEFINE_enum("model_type", "encoder", ["encoder", "squad"],
"What kind of BERT model to export.")
def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
"""Creates a BERT keras core model from BERT configuration.
Args:
bert_config: A `BertConfig` to create the core model.
Returns:
A keras model.
"""
# Adds input layers just as placeholders.
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_type_ids")
transformer_encoder = bert_models.get_transformer_encoder(
bert_config, sequence_length=None)
sequence_output, pooled_output = transformer_encoder(
[input_word_ids, input_mask, input_type_ids])
# To keep consistent with legacy hub modules, the outputs are
# "pooled_output" and "sequence_output".
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[pooled_output, sequence_output]), transformer_encoder
def export_bert_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text,
hub_destination: Text,
vocab_file: Text,
do_lower_case: bool = None):
"""Restores a tf.keras.Model and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if do_lower_case is None:
do_lower_case = "uncased" in vocab_file
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file)
core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(
model=encoder, # Legacy checkpoints.
encoder=encoder)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def export_bert_squad_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text,
hub_destination: Text,
vocab_file: Text,
do_lower_case: bool = None):
"""Restores a tf.keras.Model for BERT with SQuAD and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if do_lower_case is None:
do_lower_case = "uncased" in vocab_file
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
do_lower_case, vocab_file)
span_labeling, _ = bert_models.squad_model(bert_config, max_seq_length=None)
checkpoint = tf.train.Checkpoint(model=span_labeling)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
span_labeling.vocab_file = tf.saved_model.Asset(vocab_file)
span_labeling.do_lower_case = tf.Variable(do_lower_case, trainable=False)
span_labeling.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.model_type == "encoder":
deprecation_note = (
"nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
"models. Please switch to nlp/tools/export_tfhub for exporting BERT "
"(and other) encoders with dict inputs/outputs conforming to "
"https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
)
logging.error(deprecation_note)
print("\n\nNOTICE:", deprecation_note, "\n")
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
elif FLAGS.model_type == "squad":
export_bert_squad_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file,
FLAGS.do_lower_case)
else:
raise ValueError("Unsupported model_type %s." % FLAGS.model_type)
if __name__ == "__main__":
app.run(main)
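# --- Editor's note: illustrative sketch; not part of the original file. ---
# Programmatic use of export_bert_tfhub() defined above; all paths are
# placeholders. Because "uncased" appears in the vocab file name,
# do_lower_case is inferred as True. The deprecation notice in main() still
# applies to encoder exports produced this way.
#
#   bert_config = configs.BertConfig.from_json_file("/path/to/bert_config.json")
#   export_bert_tfhub(
#       bert_config,
#       model_checkpoint_path="/path/to/bert_model.ckpt",
#       hub_destination="/tmp/exported_bert_hub",
#       vocab_file="/path/to/uncased_vocab.txt")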
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests official.nlp.bert.export_tfhub."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
class ExportTfhubTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters("model", "encoder")
def test_export_tfhub(self, ckpt_key_name):
# Exports a savedmodel for TF-Hub
hidden_size = 16
bert_config = configs.BertConfig(
vocab_size=100,
hidden_size=hidden_size,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=1)
bert_model, encoder = export_tfhub.create_bert_model(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(**{ckpt_key_name: encoder})
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
with tf.io.gfile.GFile(vocab_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
hub_destination, vocab_file)
# Restores a hub KerasLayer.
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
if hasattr(hub_layer, "resolved_object"):
# Checks meta attributes.
self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
with tf.io.gfile.GFile(
hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
self.assertEqual("dummy content", f.read())
# Checks the hub KerasLayer.
for source_weight, hub_weight in zip(bert_model.trainable_weights,
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
seq_length = 10
dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
# The outputs of hub module are "pooled_output" and "sequence_output",
# while the outputs of the encoder are in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
# Test that training=True makes a difference (activates dropout).
def _dropout_mean_stddev(training, num_runs=20):
input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
inputs = [input_ids, np.ones_like(input_ids), np.zeros_like(input_ids)]
outputs = np.concatenate(
[hub_layer(inputs, training=training)[0] for _ in range(num_runs)])
return np.mean(np.std(outputs, axis=0))
self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
# Test propagation of seq_length in shape inference.
input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
self.assertEqual(sequence_output.shape.as_list(),
[None, seq_length, hidden_size])
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT classification or regression finetuning runner in TF 2.x."""
import functools
import json
import math
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.utils.misc import keras_utils
flags.DEFINE_enum(
'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
'trains the model and evaluates in the meantime. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
flags.DEFINE_integer('train_data_size', None, 'Number of training samples '
'to use. If None, uses the full train data. '
'(default: None).')
flags.DEFINE_string('predict_checkpoint_path', None,
'Path to the checkpoint for predictions.')
flags.DEFINE_integer(
'num_eval_per_epoch', 1,
'Number of evaluations per epoch. The purpose of this flag is to provide '
'more granular evaluation scores and checkpoints. For example, if the '
'original data has N samples and num_eval_per_epoch is n, then evaluation '
'runs every N/n samples within each epoch.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
def get_loss_fn(num_classes):
"""Gets the classification loss function."""
def classification_loss_fn(labels, logits):
"""Classification loss."""
labels = tf.reshape(labels, [-1])
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(
tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(
tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
return tf.reduce_mean(per_example_loss)
return classification_loss_fn
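# --- Editor's note: illustrative worked example; not part of the original
# file. For integer labels, the loss above is the mean per-example
# cross-entropy built from log-softmax and one-hot labels, i.e. it matches
# Keras' sparse categorical cross-entropy computed from logits:
#
#   loss_fn = get_loss_fn(num_classes=3)
#   labels = tf.constant([0, 2])
#   logits = tf.constant([[2.0, 1.0, 0.1], [0.5, 0.3, 2.2]])
#   reference = tf.reduce_mean(
#       tf.keras.losses.sparse_categorical_crossentropy(
#           labels, logits, from_logits=True))
#   # loss_fn(labels, logits) and `reference` agree to float precision.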
def get_dataset_fn(input_file_pattern,
max_seq_length,
global_batch_size,
is_training,
label_type=tf.int64,
include_sample_weights=False,
num_samples=None):
"""Gets a closure to create a dataset."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_classifier_dataset(
tf.io.gfile.glob(input_file_pattern),
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx,
label_type=label_type,
include_sample_weights=include_sample_weights,
num_samples=num_samples)
return dataset
return _dataset_fn
def run_bert_classifier(strategy,
bert_config,
input_meta_data,
model_dir,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
warmup_steps,
initial_lr,
init_checkpoint,
train_input_fn,
eval_input_fn,
training_callbacks=True,
custom_callbacks=None,
custom_metrics=None):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data.get('num_labels', 1)
is_regression = num_classes == 1
def _get_classifier_model():
"""Gets a classifier model."""
classifier_model, core_model = (
bert_models.classifier_model(
bert_config,
num_classes,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable))
optimizer = optimization.create_optimizer(initial_lr,
steps_per_epoch * epochs,
warmup_steps, FLAGS.end_lr,
FLAGS.optimizer_type)
classifier_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16())
return classifier_model, core_model
# tf.keras.losses objects accept optional sample_weight arguments (e.g. coming
# from the dataset) to compute a weighted loss, as used for the regression
# tasks. The classification tasks, which use the custom get_loss_fn, do not
# accept sample weights.
loss_fn = (tf.keras.losses.MeanSquaredError() if is_regression
else get_loss_fn(num_classes))
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
if custom_metrics:
metric_fn = custom_metrics
elif is_regression:
metric_fn = functools.partial(
tf.keras.metrics.MeanSquaredError,
'mean_squared_error',
dtype=tf.float32)
else:
metric_fn = functools.partial(
tf.keras.metrics.SparseCategoricalAccuracy,
'accuracy',
dtype=tf.float32)
# Start training using Keras compile/fit API.
logging.info('Training using TF 2.x Keras compile/fit API with '
'distribution strategy.')
return run_keras_compile_fit(
model_dir,
strategy,
_get_classifier_model,
train_input_fn,
eval_input_fn,
loss_fn,
metric_fn,
init_checkpoint,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
training_callbacks=training_callbacks,
custom_callbacks=custom_callbacks)
def run_keras_compile_fit(model_dir,
strategy,
model_fn,
train_input_fn,
eval_input_fn,
loss_fn,
metric_fn,
init_checkpoint,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
training_callbacks=True,
custom_callbacks=None):
"""Runs BERT classifier model using Keras compile/fit API."""
with strategy.scope():
training_dataset = train_input_fn()
evaluation_dataset = eval_input_fn() if eval_input_fn else None
bert_model, sub_model = model_fn()
optimizer = bert_model.optimizer
if init_checkpoint:
checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
checkpoint.read(init_checkpoint).assert_existing_objects_matched()
if not isinstance(metric_fn, (list, tuple)):
metric_fn = [metric_fn]
bert_model.compile(
optimizer=optimizer,
loss=loss_fn,
metrics=[fn() for fn in metric_fn],
steps_per_execution=steps_per_loop)
summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=model_dir,
max_to_keep=None,
step_counter=optimizer.iterations,
checkpoint_interval=0)
checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
if training_callbacks:
if custom_callbacks is not None:
custom_callbacks += [summary_callback, checkpoint_callback]
else:
custom_callbacks = [summary_callback, checkpoint_callback]
history = bert_model.fit(
x=training_dataset,
validation_data=evaluation_dataset,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
validation_steps=eval_steps,
callbacks=custom_callbacks)
stats = {'total_training_steps': steps_per_epoch * epochs}
if 'loss' in history.history:
stats['train_loss'] = history.history['loss'][-1]
if 'val_accuracy' in history.history:
stats['eval_metrics'] = history.history['val_accuracy'][-1]
return bert_model, stats
def get_predictions_and_labels(strategy,
trained_model,
eval_input_fn,
is_regression=False,
return_probs=False):
"""Obtains predictions of trained model on evaluation data.
Note that the list of labels is returned along with the predictions because
the order changes when distributing the dataset over TPU pods.
Args:
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
is_regression: Whether it is a regression task.
return_probs: Whether to return probabilities of classes.
Returns:
predictions: List of predictions.
labels: List of gold labels corresponding to predictions.
"""
@tf.function
def test_step(iterator):
"""Computes predictions on distributed devices."""
def _test_step_fn(inputs):
"""Replicated predictions."""
inputs, labels = inputs
logits = trained_model(inputs, training=False)
if not is_regression:
probabilities = tf.nn.softmax(logits)
return probabilities, labels
else:
return logits, labels
outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
return outputs, labels
def _run_evaluation(test_iterator):
"""Runs evaluation steps."""
preds, golds = list(), list()
try:
with tf.experimental.async_scope():
while True:
probabilities, labels = test_step(test_iterator)
for cur_probs, cur_labels in zip(probabilities, labels):
if return_probs:
preds.extend(cur_probs.numpy().tolist())
else:
preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
golds.extend(cur_labels.numpy().tolist())
except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
return preds, golds
test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
predictions, labels = _run_evaluation(test_iter)
return predictions, labels
def export_classifier(model_export_path, input_meta_data, bert_config,
model_dir):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
Raises:
ValueError: if the export path or the model directory is an empty string or
None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
if not model_dir:
raise ValueError('Model directory is not specified: %s' % model_dir)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.set_global_policy('float32')
classifier_model = bert_models.classifier_model(
bert_config,
input_meta_data.get('num_labels', 1),
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=False)[0]
model_saving_utils.export_bert_model(
model_export_path, model=classifier_model, checkpoint_dir=model_dir)
def run_bert(strategy,
input_meta_data,
model_config,
train_input_fn=None,
eval_input_fn=None,
init_checkpoint=None,
custom_callbacks=None,
custom_metrics=None):
"""Run BERT training."""
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_session_config(FLAGS.enable_xla)
performance.set_mixed_precision_policy(common_flags.dtype())
epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
train_data_size = (
input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
if FLAGS.train_data_size:
train_data_size = min(train_data_size, FLAGS.train_data_size)
logging.info('Updated train_data_size: %s', train_data_size)
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
if not custom_callbacks:
custom_callbacks = []
if FLAGS.log_steps:
custom_callbacks.append(
keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir))
trained_model, _ = run_bert_classifier(
strategy,
model_config,
input_meta_data,
FLAGS.model_dir,
epochs,
steps_per_epoch,
FLAGS.steps_per_loop,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
init_checkpoint or FLAGS.init_checkpoint,
train_input_fn,
eval_input_fn,
custom_callbacks=custom_callbacks,
custom_metrics=custom_metrics)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
return trained_model
def custom_main(custom_callbacks=None, custom_metrics=None):
"""Run classification or regression.
Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop.
custom_metrics: list of metrics passed to the training loop.
"""
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
include_sample_weights = input_meta_data.get('has_sample_weights', False)
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.mode == 'export_only':
export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
FLAGS.model_dir)
return
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
eval_input_fn = get_dataset_fn(
FLAGS.eval_data_path,
input_meta_data['max_seq_length'],
FLAGS.eval_batch_size,
is_training=False,
label_type=label_type,
include_sample_weights=include_sample_weights)
if FLAGS.mode == 'predict':
num_labels = input_meta_data.get('num_labels', 1)
with strategy.scope():
classifier_model = bert_models.classifier_model(
bert_config, num_labels)[0]
checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_checkpoint_file = (
FLAGS.predict_checkpoint_path or
tf.train.latest_checkpoint(FLAGS.model_dir))
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
preds, _ = get_predictions_and_labels(
strategy,
classifier_model,
eval_input_fn,
is_regression=(num_labels == 1),
return_probs=True)
output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
logging.info('***** Predict results *****')
for probabilities in preds:
output_line = '\t'.join(
str(class_probability)
for class_probability in probabilities) + '\n'
writer.write(output_line)
return
if FLAGS.mode != 'train_and_eval':
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
input_meta_data['max_seq_length'],
FLAGS.train_batch_size,
is_training=True,
label_type=label_type,
include_sample_weights=include_sample_weights,
num_samples=FLAGS.train_data_size)
run_bert(
strategy,
input_meta_data,
bert_config,
train_input_fn,
eval_input_fn,
custom_callbacks=custom_callbacks,
custom_metrics=custom_metrics)
def main(_):
custom_main(custom_callbacks=None, custom_metrics=None)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('input_meta_data_path')
flags.mark_flag_as_required('model_dir')
app.run(main)
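# --- Editor's note: illustrative sketch; not part of the original file. ---
# custom_main() is the extension hook for wrapper scripts that want extra
# callbacks or metric factories. A hypothetical wrapper would mirror main():
#
#   def main(_):
#     custom_main(
#         custom_callbacks=[tf.keras.callbacks.TerminateOnNaN()],
#         # Each entry must be a zero-argument factory, because the training
#         # loop calls `fn()` when compiling the model.
#         custom_metrics=[functools.partial(
#             tf.keras.metrics.SparseCategoricalAccuracy,
#             'accuracy', dtype=tf.float32)])
#
#   app.run(main)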
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
import json
import os
import time
# Import libraries
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.common import distribute_utils
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import keras_utils
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False,
init_checkpoint=None,
sub_model_export_name=None):
"""Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly, init_checkpoint,
sub_model_export_name=sub_model_export_name)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
run_squad_helper.predict_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
def eval_squad(strategy, input_meta_data):
"""Evaluate on the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_metrics = run_squad_helper.eval_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
return eval_metrics
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
ValueError: if the export path is an empty string or None.
"""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if 'train' in FLAGS.mode:
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
train_squad(
strategy,
input_meta_data,
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
sub_model_export_name=FLAGS.sub_model_export_name,
)
if 'predict' in FLAGS.mode:
predict_squad(strategy, input_meta_data)
if 'eval' in FLAGS.mode:
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
summary_writer = tf.summary.create_file_writer(summary_dir)
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()
# Also write eval_metrics to json file.
squad_lib_wp.write_to_json_files(
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
time.sleep(60)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
import numpy as np
import tensorflow.compat.v1 as tf # TF 1.x
# Mapping between old <=> new names. The source pattern in the original
# variable name will be replaced by the destination pattern.
BERT_NAME_REPLACEMENTS = (
("bert", "bert_model"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings",
"embedding_postprocessor/type_embeddings"),
("embeddings/position_embeddings",
"embedding_postprocessor/position_embeddings"),
("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
)
BERT_V2_NAME_REPLACEMENTS = (
("bert/", ""),
("encoder", "transformer"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
("embeddings/position_embeddings", "position_embedding/embeddings"),
("embeddings/LayerNorm", "embeddings/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention/attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
("cls/predictions", "bert/cls/predictions"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights",
"predictions/transform/logits/kernel"),
)
BERT_PERMUTATIONS = ()
BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
def _bert_name_replacement(var_name, name_replacements):
"""Gets the variable name replacement."""
for src_pattern, tgt_pattern in name_replacements:
if src_pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(src_pattern, tgt_pattern)
tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
return var_name
def _has_exclude_patterns(name, exclude_patterns):
"""Checks if a string contains substrings that match patterns to exclude."""
for p in exclude_patterns:
if p in name:
return True
return False
def _get_permutation(name, permutations):
"""Checks whether a variable requires transposition by pattern matching."""
for src_pattern, permutation in permutations:
if src_pattern in name:
tf.logging.info("Permuted: %s --> %s", name, permutation)
return permutation
return None
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "self_attention/attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "self_attention/attention_output/bias" in name:
return shape
patterns = [
"self_attention/query", "self_attention/value", "self_attention/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
def create_v2_checkpoint(model,
src_checkpoint,
output_path,
checkpoint_model_name="model"):
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager mode to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched()
if hasattr(model, "checkpoint_items"):
checkpoint_items = model.checkpoint_items
else:
checkpoint_items = {}
checkpoint_items[checkpoint_model_name] = model
checkpoint = tf.train.Checkpoint(**checkpoint_items)
checkpoint.save(output_path)
def convert(checkpoint_from_path,
checkpoint_to_path,
num_heads,
name_replacements,
permutations,
exclude_patterns=None):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
num_heads: The number of heads of the model.
name_replacements: A list of tuples of the form (match_str, replace_str)
describing variable names to adjust.
permutations: A list of tuples of the form (match_str, permutation)
describing permutations to apply to given variables. Note that match_str
should match the original variable name, not the replaced one.
exclude_patterns: A list of string patterns to exclude variables from
checkpoint conversion.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
with tf.Graph().as_default():
tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
new_variable_map = {}
conversion_map = {}
for var_name in name_shape_map:
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
continue
# Get the original tensor data.
tensor = reader.get_tensor(var_name)
# Look up the new variable name, if any.
new_var_name = _bert_name_replacement(var_name, name_replacements)
# See if we need to reshape the underlying tensor.
new_shape = None
if num_heads > 0:
new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
if new_shape:
tf.logging.info("Veriable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
# See if we need to permute the underlying tensor.
permutation = _get_permutation(var_name, permutations)
if permutation:
tensor = np.transpose(tensor, permutation)
# Create a new variable with the possibly-reshaped or transposed tensor.
var = tf.Variable(tensor, name=var_name)
# Save the variable into the new variable map.
new_variable_map[new_var_name] = var
# Keep a list of converted variables for sanity checking.
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
saver = tf.train.Saver(new_variable_map)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
saver.save(sess, checkpoint_to_path, write_meta_graph=False)
tf.logging.info("Summary:")
tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
tf.logging.info(" Converted: %s", str(conversion_map))
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
The conversion will yield an object-oriented checkpoint that can be used
to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
FLAG below).
"""
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import models
from official.nlp.modeling import networks
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string(
"checkpoint_to_convert", None,
"Initial checkpoint from a pretrained BERT model core (that is, only the "
"BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
"Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "encoder",
"The name of the model when saving the checkpoint, i.e., "
"the checkpoint will be saved using: "
"tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
"converted_model", "encoder", ["encoder", "pretrainer"],
"Whether to convert the checkpoint to a `BertEncoder` model or a "
"`BertPretrainerV2` model (with mlm but without classification heads).")
def _create_bert_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A BertEncoder network.
"""
bert_encoder = networks.BertEncoder(
vocab_size=cfg.vocab_size,
hidden_size=cfg.hidden_size,
num_layers=cfg.num_hidden_layers,
num_attention_heads=cfg.num_attention_heads,
intermediate_size=cfg.intermediate_size,
activation=tf_utils.get_activation(cfg.hidden_act),
dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob,
max_sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range),
embedding_width=cfg.embedding_size)
return bert_encoder
def _create_bert_pretrainer_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A BertPretrainerV2 model.
"""
bert_encoder = _create_bert_model(cfg)
pretrainer = models.BertPretrainerV2(
encoder_network=bert_encoder,
mlm_activation=tf_utils.get_activation(cfg.hidden_act),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range))
# Makes sure the pretrainer variables are created.
_ = pretrainer(pretrainer.inputs)
return pretrainer
def convert_checkpoint(bert_config,
output_path,
v1_checkpoint,
checkpoint_model_name="model",
converted_model="encoder"):
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path)
tf.io.gfile.makedirs(output_dir)
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads,
name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"])
if converted_model == "encoder":
model = _create_bert_model(bert_config)
elif converted_model == "pretrainer":
model = _create_bert_pretrainer_model(bert_config)
else:
raise ValueError("Unsupported converted_model: %s" % converted_model)
# Create a V2 checkpoint from the temporary checkpoint.
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path,
checkpoint_model_name)
# Clean up the temporary checkpoint, if it exists.
try:
tf.io.gfile.rmtree(temporary_checkpoint_dir)
except tf.errors.OpError:
# If it doesn't exist, we don't need to clean it up; continue.
pass
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
checkpoint_model_name = FLAGS.checkpoint_model_name
converted_model = FLAGS.converted_model
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
convert_checkpoint(
bert_config=bert_config,
output_path=output_path,
v1_checkpoint=v1_checkpoint,
checkpoint_model_name=checkpoint_model_name,
converted_model=converted_model)
if __name__ == "__main__":
app.run(main)
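# --- Editor's note: illustrative sketch; not part of the original file. ---
# Programmatic equivalent of the flags above; all paths are placeholders.
# convert_checkpoint() first writes a temporary name-converted V1 checkpoint
# and then re-saves it as an object-based V2 checkpoint keyed by
# checkpoint_model_name.
#
#   bert_config = configs.BertConfig.from_json_file("/path/to/bert_config.json")
#   convert_checkpoint(
#       bert_config=bert_config,
#       output_path="/path/to/converted/bert_encoder",
#       v1_checkpoint="/path/to/tf1/bert_model.ckpt",
#       checkpoint_model_name="encoder",
#       converted_model="encoder")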
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
@@ -41,3 +41,5 @@ class PretrainerConfig(base_config.Config):
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
mlm_activation: str = "gelu"
mlm_initializer_range: float = 0.02
# Currently only used for mobile bert.
mlm_output_weights_use_proj: bool = False
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
@@ -16,9 +16,9 @@
Includes configurations and factory methods.
"""
from typing import Optional
import dataclasses
from typing import Optional, Sequence
import gin
import tensorflow as tf
......
@@ -26,7 +26,7 @@ from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.projects.bigbird import encoder as bigbird_encoder
from official.projects.bigbird import encoder as bigbird_encoder
@dataclasses.dataclass
......
@@ -221,6 +221,50 @@ class XLNetEncoderConfig(hyperparams.Config):
two_stream: bool = False
@dataclasses.dataclass
class QueryBertConfig(hyperparams.Config):
"""Query BERT encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False
@dataclasses.dataclass
class FNetEncoderConfig(hyperparams.Config):
"""FNet encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
inner_activation: str = "gelu"
inner_dim: int = 3072
output_dropout: float = 0.1
attention_dropout: float = 0.1
max_sequence_length: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_width: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False
use_fft: bool = False
attention_layers: Sequence[int] = ()
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
"""Encoder configuration."""
......
@@ -233,6 +277,8 @@ class EncoderConfig(hyperparams.OneOfConfig):
mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
reuse: ReuseEncoderConfig = ReuseEncoderConfig()
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
query_bert: QueryBertConfig = QueryBertConfig()
fnet: FNetEncoderConfig = FNetEncoderConfig()
# If `any` is used, the encoder building relies on any.BUILDER.
any: hyperparams.Config = hyperparams.Config()
......
@@ -513,6 +559,54 @@ def build_encoder(config: EncoderConfig,
recursive=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "query_bert":
embedding_layer = layers.FactorizedEmbedding(
vocab_size=encoder_cfg.vocab_size,
embedding_width=encoder_cfg.embedding_size,
output_dim=encoder_cfg.hidden_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
name="word_embeddings")
return networks.BertEncoderV2(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_layer=embedding_layer,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True,
norm_first=encoder_cfg.norm_first)
if encoder_type == "fnet":
return networks.FNet(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.inner_dim,
inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
output_dropout=encoder_cfg.output_dropout,
attention_dropout=encoder_cfg.attention_dropout,
max_sequence_length=encoder_cfg.max_sequence_length,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_width,
embedding_layer=embedding_layer,
norm_first=encoder_cfg.norm_first,
use_fft=encoder_cfg.use_fft,
attention_layers=encoder_cfg.attention_layers)
bert_encoder_cls = networks.BertEncoder
if encoder_type == "bert_v2":
bert_encoder_cls = networks.BertEncoderV2
......
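# --- Editor's note: illustrative sketch; not part of the original diff. ---
# Building one of the newly added encoders through the OneOfConfig machinery
# shown above. Field values other than the declared defaults are illustrative.
#
#   from official.nlp.configs import encoders
#
#   config = encoders.EncoderConfig(
#       type="fnet",
#       fnet=encoders.FNetEncoderConfig(num_layers=2, use_fft=True))
#   fnet_encoder = encoders.build_encoder(config)  # returns a networks.FNet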
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
@@ -20,7 +20,7 @@ import tensorflow as tf
from official.modeling import hyperparams
from official.nlp.configs import encoders
from official.nlp.modeling import networks
from official.nlp.projects.teams import teams
from official.projects.teams import teams
class EncodersTest(tf.test.TestCase):
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
@@ -17,4 +17,3 @@
from official.nlp.configs import finetuning_experiments
from official.nlp.configs import pretraining_experiments
from official.nlp.configs import wmt_transformer_experiments
from official.nlp.projects.teams import teams_experiments
task:
init_checkpoint: ''
model:
cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
train_data:
drop_remainder: true
global_batch_size: 512
input_path: '[Your processed wiki data path]*,[Your processed books data path]*'
is_training: true
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: true
validation_data:
drop_remainder: false
global_batch_size: 512
input_path: '[Your processed wiki data path]-00000-of-00500,[Your processed books data path]-00000-of-00500'
is_training: false
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: true
trainer:
checkpoint_interval: 20000
max_to_keep: 5
optimizer_config:
learning_rate:
polynomial:
cycle: false
decay_steps: 1000000
end_learning_rate: 0.0
initial_learning_rate: 0.0001
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 10000
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
train_steps: 1000000
validation_interval: 1000
validation_steps: 64
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations."""
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......