"tests/vscode:/vscode.git/clone" did not exist on "5704376d0309031a124fcb8a957fc70282ce13eb"
Commit 7687b1d3 authored by Chen Chen, committed by A. Unique TensorFlower

Open source mobilebert project.

PiperOrigin-RevId: 366313554
parent 41f71f6c
# MobileBERT (MobileBERT: A Compact Task-Agnostic BERT for Resource-Limited Devices)
[MobileBERT](https://arxiv.org/abs/2004.02984)
is a thin version of BERT_LARGE equipped with bottleneck structures and a
carefully designed balance between self-attention and feed-forward networks.
To train MobileBERT, we first train a specially designed teacher model, an
inverted-bottleneck-incorporated BERT_LARGE model. Then, we conduct knowledge
transfer from this teacher to MobileBERT. Empirical studies show that MobileBERT
is 4.3x smaller and 5.5x faster than BERT_BASE while achieving competitive
results on well-known benchmarks. This repository contains the TensorFlow 2.x
implementation of MobileBERT.
## Network Implementations
Following the
[MobileBERT TF1 implementation](https://github.com/google-research/google-research/tree/master/mobilebert),
we re-implemented the MobileBERT encoder and layers with `tf.keras` APIs in the
NLP modeling library:
* [mobile_bert_encoder.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/mobile_bert_encoder.py)
contains the `MobileBERTEncoder` implementation.
* [mobile_bert_layers.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/mobile_bert_layers.py)
contains the `MobileBertEmbedding`, `MobileBertTransformer` and
`MobileBertMaskedLM` implementations.
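As a quick illustration of how these pieces fit together, here is a minimal sketch that builds a small `MobileBERTEncoder` and runs a forward pass. The tiny configuration values and the dict-style inputs (`input_word_ids`, `input_mask`, `input_type_ids`) are illustrative assumptions, not the released configuration.
```python
import tensorflow as tf

from official.nlp.modeling import networks

# A deliberately tiny configuration, for illustration only.
encoder = networks.MobileBERTEncoder(
    word_vocab_size=30522,
    word_embed_size=128,
    num_blocks=4,
    hidden_size=512,
    num_attention_heads=4)

# Fake token ids, mask, and segment ids of shape [batch_size, seq_len].
batch_size, seq_len = 2, 16
inputs = dict(
    input_word_ids=tf.random.uniform(
        (batch_size, seq_len), maxval=30522, dtype=tf.int32),
    input_mask=tf.ones((batch_size, seq_len), dtype=tf.int32),
    input_type_ids=tf.zeros((batch_size, seq_len), dtype=tf.int32))

outputs = encoder(inputs)
# "pooled_output" is the whole-input representation, [batch_size, hidden_size].
print(outputs["pooled_output"].shape)
```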
## Pre-trained Models
We converted the original TF 1.x pretrained English MobileBERT checkpoint to a
TF 2.x checkpoint that is compatible with the implementations above.
In addition, we provide a new multilingual MobileBERT checkpoint trained on
multilingual Wikipedia data. Furthermore, we export the checkpoints as
TF-Hub SavedModels. Please find the details in the following table:
Model | Configuration | Number of Parameters | Training Data | Checkpoint & Vocabulary | TF-Hub SavedModel | Metrics
------------------------------ | :--------------------------------------: | :------------------- | :-----------: | :-----------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :-----:
MobileBERT uncased English | uncased_L-24_H-128_B-512_A-4_F-4_OPT | 25.3 Million | Wiki + Books | [Download](https://storage.cloud.google.com/model_garden_artifacts/official/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz) | [TF-Hub](https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1) | SQuAD v1.1 F1: 90.0, GLUE: 77.7
MobileBERT cased Multi-lingual | multi_cased_L-24_H-128_B-512_A-4_F-4_OPT | 36 Million | Wiki | [Download](https://storage.cloud.google.com/model_garden_artifacts/official/mobilebert/multi_cased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz) | [TF-Hub](https://tfhub.dev/tensorflow/mobilebert_multi_cased_L-24_H-128_B-512_A-4_F-4_OPT/1) | XNLI (zero-shot): 64.7
### Restoring from Checkpoints
To load the pre-trained MobileBERT checkpoint in your code, please follow the
example below:
```python
import tensorflow as tf
from official.nlp.projects.mobilebert import model_utils
bert_config_file = ...
model_checkpoint_path = ...
bert_config = model_utils.BertConfig.from_json_file(bert_config_file)
# `pretrainer` is an instance of `nlp.modeling.models.BertPretrainerV2`.
pretrainer = model_utils.create_mobilebert_pretrainer(bert_config)
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
# `mobilebert_encoder` is an instance of
# `nlp.modeling.networks.MobileBERTEncoder`.
mobilebert_encoder = pretrainer.encoder_network
```
### Using TF-Hub Models
For usage of the MobileBERT TF-Hub models, please see the TF-Hub pages for the
[English model](https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1)
or the
[Multilingual model](https://tfhub.dev/tensorflow/mobilebert_multi_cased_L-24_H-128_B-512_A-4_F-4_OPT/1).
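As a starting point, the sketch below wraps the exported English SavedModel in a `hub.KerasLayer` and builds a small Keras classifier on top of the pooled output. It assumes the `tensorflow_hub` package and the standard three-tensor BERT input signature (`input_word_ids`, `input_mask`, `input_type_ids`); adjust to your setup as needed.
```python
import tensorflow as tf
import tensorflow_hub as hub

# Load the exported English MobileBERT encoder from TF-Hub.
hub_url = "https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1"
encoder_layer = hub.KerasLayer(hub_url, trainable=True)

# Placeholder inputs; in practice the ids come from a BERT WordPiece tokenizer.
seq_len = 128
inputs = dict(
    input_word_ids=tf.keras.Input(shape=(seq_len,), dtype=tf.int32),
    input_mask=tf.keras.Input(shape=(seq_len,), dtype=tf.int32),
    input_type_ids=tf.keras.Input(shape=(seq_len,), dtype=tf.int32))
outputs = encoder_layer(inputs)

# "pooled_output" (aliased as "default" by the exporter) is the whole-input
# representation of shape [batch_size, hidden_size].
logits = tf.keras.layers.Dense(2)(outputs["pooled_output"])
model = tf.keras.Model(inputs=inputs, outputs=logits)
model.summary()
```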
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.mobilebert.distillation."""
import os
from absl import logging
import tensorflow as tf
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.projects.mobilebert import distillation
class DistillationTest(tf.test.TestCase):
def setUp(self):
super(DistillationTest, self).setUp()
# using small model for testing
self.model_block_num = 2
self.task_config = distillation.BertDistillationTaskConfig(
teacher_model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type='mobilebert',
mobilebert=encoders.MobileBertEncoderConfig(
num_blocks=self.model_block_num)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=256,
num_classes=2,
dropout_rate=0.1,
name='next_sentence')
],
mlm_activation='gelu'),
student_model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type='mobilebert',
mobilebert=encoders.MobileBertEncoderConfig(
num_blocks=self.model_block_num)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=256,
num_classes=2,
dropout_rate=0.1,
name='next_sentence')
],
mlm_activation='relu'),
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path='dummy',
max_predictions_per_seq=76,
seq_length=512,
global_batch_size=10),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
input_path='dummy',
max_predictions_per_seq=76,
seq_length=512,
global_batch_size=10))
# set only 1 step for each stage
progressive_config = distillation.BertDistillationProgressiveConfig()
progressive_config.layer_wise_distill_config.num_steps = 1
progressive_config.pretrain_distill_config.num_steps = 1
optimization_config = optimization.OptimizationConfig(
optimizer=optimization.OptimizerConfig(
type='lamb',
lamb=optimization.LAMBConfig(
weight_decay_rate=0.0001,
exclude_from_weight_decay=[
'LayerNorm', 'layer_norm', 'bias', 'no_norm'
])),
learning_rate=optimization.LrConfig(
type='polynomial',
polynomial=optimization.PolynomialLrConfig(
initial_learning_rate=1.5e-3,
decay_steps=10000,
end_learning_rate=1.5e-3)),
warmup=optimization.WarmupConfig(
type='linear',
linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))
self.exp_config = cfg.ExperimentConfig(
task=self.task_config,
trainer=prog_trainer_lib.ProgressiveTrainerConfig(
progressive=progressive_config,
optimizer_config=optimization_config))
# Create a teacher model checkpoint.
teacher_encoder = encoders.build_encoder(
self.task_config.teacher_model.encoder)
pretrainer_config = self.task_config.teacher_model
if pretrainer_config.cls_heads:
teacher_cls_heads = [
layers.ClassificationHead(**cfg.as_dict())
for cfg in pretrainer_config.cls_heads
]
else:
teacher_cls_heads = []
masked_lm = layers.MobileBertMaskedLM(
embedding_table=teacher_encoder.get_embedding_table(),
activation=tf_utils.get_activation(pretrainer_config.mlm_activation),
initializer=tf.keras.initializers.TruncatedNormal(
stddev=pretrainer_config.mlm_initializer_range),
name='cls/predictions')
teacher_pretrainer = models.BertPretrainerV2(
encoder_network=teacher_encoder,
classification_heads=teacher_cls_heads,
customized_masked_lm=masked_lm)
# The model variables will be created after the forward call.
_ = teacher_pretrainer(teacher_pretrainer.inputs)
teacher_pretrainer_ckpt = tf.train.Checkpoint(
**teacher_pretrainer.checkpoint_items)
teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
teacher_pretrainer_ckpt.save(teacher_ckpt_path)
self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()
def test_task(self):
bert_distillation_task = distillation.BertDistillationTask(
strategy=tf.distribute.get_strategy(),
progressive=self.exp_config.trainer.progressive,
optimizer_config=self.exp_config.trainer.optimizer_config,
task_config=self.task_config)
metrics = bert_distillation_task.build_metrics()
train_dataset = bert_distillation_task.get_train_dataset(stage_id=0)
train_iterator = iter(train_dataset)
eval_dataset = bert_distillation_task.get_eval_dataset(stage_id=0)
eval_iterator = iter(eval_dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
# test train/val step for all stages, including the last pretraining stage
for stage in range(self.model_block_num + 1):
step = stage
bert_distillation_task.update_pt_stage(step)
model = bert_distillation_task.get_model(stage, None)
bert_distillation_task.initialize(model)
bert_distillation_task.train_step(next(train_iterator), model, optimizer,
metrics=metrics)
bert_distillation_task.validation_step(next(eval_iterator), model,
metrics=metrics)
logging.info('begin to save and load model checkpoint')
ckpt = tf.train.Checkpoint(model=model)
ckpt.save(self.get_temp_dir())
if __name__ == '__main__':
tf.test.main()
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 512
hidden_activation: relu
hidden_dropout_prob: 0.0
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 128
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 4
normalization_type: no_norm
classifier_activation: false
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 4096
hidden_activation: gelu
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 1024
initializer_range: 0.02
key_query_shared_bottleneck: false
num_feedforward_networks: 1
normalization_type: layer_norm
classifier_activation: false
task:
train_data:
drop_remainder: true
global_batch_size: 2048
input_path: ""
is_training: true
max_predictions_per_seq: 20
seq_length: 512
use_next_sentence_label: true
use_position_id: false
validation_data:
drop_remainder: true
global_batch_size: 2048
input_path: ""
is_training: false
max_predictions_per_seq: 20
seq_length: 512
use_next_sentence_label: true
use_position_id: false
teacher_model:
cls_heads: []
mlm_activation: gelu
mlm_initializer_range: 0.02
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 4096
hidden_activation: gelu
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 1024
initializer_range: 0.02
key_query_shared_bottleneck: false
num_feedforward_networks: 1
normalization_type: layer_norm
classifier_activation: false
student_model:
cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.0, inner_dim: 512,
name: next_sentence, num_classes: 2}]
mlm_activation: relu
mlm_initializer_range: 0.02
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 512
hidden_activation: relu
hidden_dropout_prob: 0.0
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 128
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 4
normalization_type: no_norm
classifier_activation: false
teacher_model_init_checkpoint: ""
trainer:
progressive:
if_copy_embeddings: true
layer_wise_distill_config:
num_steps: 10000
pretrain_distill_config:
num_steps: 500000
decay_steps: 500000
train_steps: 740000
max_to_keep: 10
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export the MobileBERT encoder model as a TF-Hub SavedModel."""
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.projects.mobilebert import model_utils
FLAGS = flags.FLAGS
flags.DEFINE_string(
"bert_config_file", None,
"Bert configuration file to define core mobilebert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", True, "Whether to lowercase.")
def create_mobilebert_model(bert_config):
"""Creates a model for exporting to tfhub."""
pretrainer = model_utils.create_mobilebert_pretrainer(bert_config)
encoder = pretrainer.encoder_network
encoder_inputs_dict = {x.name: x for x in encoder.inputs}
encoder_output_dict = encoder(encoder_inputs_dict)
# For interchangeability with other text representations,
# add "default" as an alias for MobileBERT's whole-input representation.
encoder_output_dict["default"] = encoder_output_dict["pooled_output"]
core_model = tf.keras.Model(
inputs=encoder_inputs_dict, outputs=encoder_output_dict)
pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
mlm_model = tf.keras.Model(
inputs=pretrainer_inputs_dict, outputs=pretrainer_output_dict)
# Set `_auto_track_sub_layers` to False, so that the additional weights
# from the `mlm` sub-object will not be included in the core model.
# TODO(b/169210253): Use public API after the bug is resolved.
core_model._auto_track_sub_layers = False # pylint: disable=protected-access
core_model.mlm = mlm_model
return core_model, pretrainer
def export_bert_tfhub(bert_config, model_checkpoint_path, hub_destination,
vocab_file, do_lower_case):
"""Restores a tf.keras.Model and saves for TF-Hub."""
core_model, pretrainer = create_mobilebert_model(bert_config)
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
logging.info("Begin to load model")
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
logging.info("Loading model finished")
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
logging.info("Begin to save files for tfhub at %s", hub_destination)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
logging.info("tfhub files exported!")
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
bert_config = model_utils.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
FLAGS.vocab_file, FLAGS.do_lower_case)
if __name__ == "__main__":
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint converter for Mobilebert."""
import copy
import json
import tensorflow.compat.v1 as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.modeling import networks
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
embedding_size=None,
trigram_input=False,
use_bottleneck=False,
intra_bottleneck_size=None,
use_bottleneck_attention=False,
key_query_shared_bottleneck=False,
num_feedforward_networks=1,
normalization_type="layer_norm",
classifier_activation=True):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
embedding_size: The size of the token embedding.
trigram_input: Whether to use a trigram convolution as input.
use_bottleneck: Whether to use the bottleneck/inverted-bottleneck structure
in BERT.
intra_bottleneck_size: The hidden size in the bottleneck.
use_bottleneck_attention: Whether to use attention inputs from the bottleneck
transformation.
key_query_shared_bottleneck: Whether to use the same linear transformation for
the query and key in the bottleneck.
num_feedforward_networks: Number of FFNs in a block.
normalization_type: The normalization type in BERT.
classifier_activation: Whether to use the tanh activation for the final
representation of the [CLS] token in fine-tuning.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.embedding_size = embedding_size
self.trigram_input = trigram_input
self.use_bottleneck = use_bottleneck
self.intra_bottleneck_size = intra_bottleneck_size
self.use_bottleneck_attention = use_bottleneck_attention
self.key_query_shared_bottleneck = key_query_shared_bottleneck
self.num_feedforward_networks = num_feedforward_networks
self.normalization_type = normalization_type
self.classifier_activation = classifier_activation
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in json_object.items():
config.__dict__[key] = value
if config.embedding_size is None:
config.embedding_size = config.hidden_size
if config.intra_bottleneck_size is None:
config.intra_bottleneck_size = config.hidden_size
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def create_mobilebert_pretrainer(bert_config):
"""Creates a BertPretrainerV2 that wraps MobileBERTEncoder model."""
mobilebert_encoder = networks.MobileBERTEncoder(
word_vocab_size=bert_config.vocab_size,
word_embed_size=bert_config.embedding_size,
type_vocab_size=bert_config.type_vocab_size,
max_sequence_length=bert_config.max_position_embeddings,
num_blocks=bert_config.num_hidden_layers,
hidden_size=bert_config.hidden_size,
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
intermediate_act_fn=tf_utils.get_activation(bert_config.hidden_act),
hidden_dropout_prob=bert_config.hidden_dropout_prob,
attention_probs_dropout_prob=bert_config.attention_probs_dropout_prob,
intra_bottleneck_size=bert_config.intra_bottleneck_size,
initializer_range=bert_config.initializer_range,
use_bottleneck_attention=bert_config.use_bottleneck_attention,
key_query_shared_bottleneck=bert_config.key_query_shared_bottleneck,
num_feedforward_networks=bert_config.num_feedforward_networks,
normalization_type=bert_config.normalization_type,
classifier_activation=bert_config.classifier_activation)
masked_lm = layers.MobileBertMaskedLM(
embedding_table=mobilebert_encoder.get_embedding_table(),
activation=tf_utils.get_activation(bert_config.hidden_act),
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
name="cls/predictions")
pretrainer = models.BertPretrainerV2(
encoder_network=mobilebert_encoder, customized_masked_lm=masked_lm)
# Makes sure the pretrainer variables are created.
_ = pretrainer(pretrainer.inputs)
return pretrainer
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
"""Creating the task and start trainer."""
import pprint
from absl import app
from absl import flags
from absl import logging
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import config_definitions as cfg
from official.core import train_utils
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling import performance
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.mobilebert import distillation
FLAGS = flags.FLAGS
optimization_config = optimization.OptimizationConfig(
optimizer=optimization.OptimizerConfig(
type='lamb',
lamb=optimization.LAMBConfig(
weight_decay_rate=0.01,
exclude_from_weight_decay=['LayerNorm', 'bias', 'norm'],
clipnorm=1.0)),
learning_rate=optimization.LrConfig(
type='polynomial',
polynomial=optimization.PolynomialLrConfig(
initial_learning_rate=1.5e-3,
decay_steps=10000,
end_learning_rate=1.5e-3)),
warmup=optimization.WarmupConfig(
type='linear',
linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))
# Copied from progressive/utils.py due to the private visibility issue.
def config_override(params, flags_obj):
"""Override ExperimentConfig according to flags."""
# Change runtime.tpu to the real tpu.
params.override({
'runtime': {
'tpu': flags_obj.tpu,
}
})
# Get the first level of override from `--config_file`.
# `--config_file` is typically used as a template that specifies the common
# override for a particular experiment.
for config_file in flags_obj.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
# Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if flags_obj.params_override:
params = hyperparams.override_params_dict(
params, flags_obj.params_override, is_strict=True)
params.validate()
params.lock()
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s', pp.pformat(params.as_dict()))
model_dir = flags_obj.model_dir
if 'train' in flags_obj.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
return params
def get_exp_config():
"""Get ExperimentConfig."""
params = cfg.ExperimentConfig(
task=distillation.BertDistillationTaskConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False)),
trainer=prog_trainer_lib.ProgressiveTrainerConfig(
progressive=distillation.BertDistillationProgressiveConfig(),
optimizer_config=optimization_config,
train_steps=740000,
checkpoint_interval=20000))
return config_override(params, FLAGS)
def main(_):
logging.info('Parsing config files...')
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = get_exp_config()
# Sets the mixed-precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can significantly speed up the model by using float16 on GPUs and bfloat16
# on TPUs. loss_scale takes effect only when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = distillation.BertDistillationTask(
strategy=distribution_strategy,
progressive=params.trainer.progressive,
optimizer_config=params.trainer.optimizer_config,
task_config=params.task)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=FLAGS.model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint converter for Mobilebert."""
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from official.nlp.projects.mobilebert import model_utils
FLAGS = flags.FLAGS
flags.DEFINE_string(
"bert_config_file", None,
"Bert configuration file to define core mobilebert layers.")
flags.DEFINE_string("tf1_checkpoint_path", None,
"Path to load tf1 checkpoint.")
flags.DEFINE_string("tf2_checkpoint_path", None,
"Path to save tf2 checkpoint.")
flags.DEFINE_boolean("use_model_prefix", False,
("If use model name as prefix for variables. Turn this"
"flag on when the converted checkpoint is used for model"
"in subclass implementation, which uses the model name as"
"prefix for all variable names."))
def _bert_name_replacement(var_name, name_replacements):
"""Gets the variable name replacement."""
for src_pattern, tgt_pattern in name_replacements:
if src_pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(src_pattern, tgt_pattern)
logging.info("Converted: %s --> %s", old_var_name, var_name)
return var_name
def _has_exclude_patterns(name, exclude_patterns):
"""Checks if a string contains substrings that match patterns to exclude."""
for p in exclude_patterns:
if p in name:
return True
return False
def _get_permutation(name, permutations):
"""Checks whether a variable requires transposition by pattern matching."""
for src_pattern, permutation in permutations:
if src_pattern in name:
logging.info("Permuted: %s --> %s", name, permutation)
return permutation
return None
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "attention/attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "attention/attention_output/bias" in name:
return shape
patterns = [
"attention/query", "attention/value", "attention/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
def convert(checkpoint_from_path,
checkpoint_to_path,
name_replacements,
permutations,
bert_config,
exclude_patterns=None):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
name_replacements: A list of tuples of the form (match_str, replace_str)
describing variable names to adjust.
permutations: A list of tuples of the form (match_str, permutation)
describing permutations to apply to given variables. Note that match_str
should match the original variable name, not the replaced one.
bert_config: A `BertConfig` to create the core model.
exclude_patterns: A list of string patterns to exclude variables from
checkpoint conversion.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
last_ffn_layer_id = str(bert_config.num_feedforward_networks - 1)
name_replacements = [
(x[0], x[1].replace("LAST_FFN_LAYER_ID", last_ffn_layer_id))
for x in name_replacements
]
output_dir, _ = os.path.split(checkpoint_to_path)
tf.io.gfile.makedirs(output_dir)
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
with tf.Graph().as_default():
logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
new_variable_map = {}
conversion_map = {}
for var_name in name_shape_map:
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
continue
# Get the original tensor data.
tensor = reader.get_tensor(var_name)
# Look up the new variable name, if any.
new_var_name = _bert_name_replacement(var_name, name_replacements)
# See if we need to reshape the underlying tensor.
new_shape = None
if bert_config.num_attention_heads > 0:
new_shape = _get_new_shape(new_var_name, tensor.shape,
bert_config.num_attention_heads)
if new_shape:
logging.info("Veriable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
# See if we need to permute the underlying tensor.
permutation = _get_permutation(var_name, permutations)
if permutation:
tensor = np.transpose(tensor, permutation)
# Create a new variable with the possibly-reshaped or transposed tensor.
var = tf.Variable(tensor, name=var_name)
# Save the variable into the new variable map.
new_variable_map[new_var_name] = var
# Keep a record of converted variable names for sanity checking.
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
saver = tf.train.Saver(new_variable_map)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
logging.info("Writing checkpoint_to_path %s", temporary_checkpoint)
saver.save(sess, temporary_checkpoint, write_meta_graph=False)
logging.info("Summary:")
logging.info("Converted %d variable name(s).", len(new_variable_map))
logging.info("Converted: %s", str(conversion_map))
mobilebert_model = model_utils.create_mobilebert_pretrainer(bert_config)
create_v2_checkpoint(
mobilebert_model, temporary_checkpoint, checkpoint_to_path)
# Clean up the temporary checkpoint, if it exists.
try:
tf.io.gfile.rmtree(temporary_checkpoint_dir)
except tf.errors.OpError:
# If it doesn't exist, we don't need to clean it up; continue.
pass
def create_v2_checkpoint(model, src_checkpoint, output_path):
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager mode to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
checkpoint.save(output_path)
_NAME_REPLACEMENT = [
# prefix path replacement
("bert/", "mobile_bert_encoder/"),
("encoder/layer_", "transformer_layer_"),
# embedding layer
("embeddings/embedding_transformation",
"mobile_bert_embedding/embedding_projection"),
("embeddings/position_embeddings",
"mobile_bert_embedding/position_embedding/embeddings"),
("embeddings/token_type_embeddings",
"mobile_bert_embedding/type_embedding/embeddings"),
("embeddings/word_embeddings",
"mobile_bert_embedding/word_embedding/embeddings"),
("embeddings/FakeLayerNorm", "mobile_bert_embedding/embedding_norm"),
("embeddings/LayerNorm", "mobile_bert_embedding/embedding_norm"),
# attention layer
("attention/output/dense", "attention/attention_output"),
("attention/output/FakeLayerNorm", "attention/norm"),
("attention/output/LayerNorm", "attention/norm"),
("attention/self", "attention"),
# input bottleneck
("bottleneck/input/dense", "bottleneck_input/dense"),
("bottleneck/input/FakeLayerNorm", "bottleneck_input/norm"),
("bottleneck/input/LayerNorm", "bottleneck_input/norm"),
("bottleneck/attention/dense", "kq_shared_bottleneck/dense"),
("bottleneck/attention/FakeLayerNorm", "kq_shared_bottleneck/norm"),
("bottleneck/attention/LayerNorm", "kq_shared_bottleneck/norm"),
# ffn layer
("ffn_layer_0/output/dense", "ffn_layer_0/output_dense"),
("ffn_layer_1/output/dense", "ffn_layer_1/output_dense"),
("ffn_layer_2/output/dense", "ffn_layer_2/output_dense"),
("output/dense", "ffn_layer_LAST_FFN_LAYER_ID/output_dense"),
("ffn_layer_0/output/FakeLayerNorm", "ffn_layer_0/norm"),
("ffn_layer_0/output/LayerNorm", "ffn_layer_0/norm"),
("ffn_layer_1/output/FakeLayerNorm", "ffn_layer_1/norm"),
("ffn_layer_1/output/LayerNorm", "ffn_layer_1/norm"),
("ffn_layer_2/output/FakeLayerNorm", "ffn_layer_2/norm"),
("ffn_layer_2/output/LayerNorm", "ffn_layer_2/norm"),
("output/FakeLayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
("output/LayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
("ffn_layer_0/intermediate/dense", "ffn_layer_0/intermediate_dense"),
("ffn_layer_1/intermediate/dense", "ffn_layer_1/intermediate_dense"),
("ffn_layer_2/intermediate/dense", "ffn_layer_2/intermediate_dense"),
("intermediate/dense", "ffn_layer_LAST_FFN_LAYER_ID/intermediate_dense"),
# output bottleneck
("output/bottleneck/FakeLayerNorm", "bottleneck_output/norm"),
("output/bottleneck/LayerNorm", "bottleneck_output/norm"),
("output/bottleneck/dense", "bottleneck_output/dense"),
# pooler layer
("pooler/dense", "pooler"),
# MLM layer
("cls/predictions", "bert/cls/predictions"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias")
]
_EXCLUDE_PATTERNS = ["cls/seq_relationship", "global_step"]
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
if not FLAGS.use_model_prefix:
_NAME_REPLACEMENT[0] = ("bert/", "")
bert_config = model_utils.BertConfig.from_json_file(FLAGS.bert_config_file)
convert(FLAGS.tf1_checkpoint_path,
FLAGS.tf2_checkpoint_path,
_NAME_REPLACEMENT,
[],
bert_config,
_EXCLUDE_PATTERNS)
if __name__ == "__main__":
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions."""
import numpy as np
def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
"""Generate consistent fake integer input sequences."""
np.random.seed(seed)
fake_input = []
for _ in range(batch_size):
fake_input.append([])
for _ in range(seq_len):
fake_input[-1].append(np.random.randint(0, vocab_size))
fake_input = np.asarray(fake_input)
return fake_input