Commit 7687b1d3 authored by Chen Chen, committed by A. Unique TensorFlower

Open source mobilebert project.

PiperOrigin-RevId: 366313554
parent 41f71f6c
# MobileBERT (MobileBERT: A Compact Task-Agnostic BERT for Resource-Limited Devices)
[MobileBERT](https://arxiv.org/abs/2004.02984)
is a thin version of BERT_LARGE, equipped with bottleneck structures and a
carefully designed balance between self-attention and feed-forward networks.
To train MobileBERT, we first train a specially designed teacher model, an
inverted-bottleneck incorporated BERT_LARGE model. Then, we conduct knowledge
transfer from this teacher to MobileBERT. Empirical studies show that MobileBERT
is 4.3x smaller and 5.5x faster than BERT_BASE while achieving competitive
results on well-known benchmarks. This repository contains a TensorFlow 2.x
implementation of MobileBERT.
## Network Implementations
Following
[MobileBERT TF1 implementation](https://github.com/google-research/google-research/tree/master/mobilebert),
we re-implemented the MobileBERT encoder and layers using `tf.keras` APIs in the
NLP modeling library (a minimal construction sketch follows the list below):
* [mobile_bert_encoder.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/mobile_bert_encoder.py)
contains `MobileBERTEncoder` implementation.
* [mobile_bert_layers.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/mobile_bert_layers.py)
contains the `MobileBertEmbedding`, `MobileBertTransformer`, and
`MobileBertMaskedLM` implementations.
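
For reference, the encoder can be constructed directly from the modeling library.
The snippet below is a minimal sketch: the argument names mirror those used by
`model_utils.create_mobilebert_pretrainer` in this project, and the values follow
the `uncased_L-24_H-128_B-512_A-4_F-4_OPT` configuration, so adjust them for your
own setup.

```python
from official.nlp.modeling import networks

# Illustrative values taken from the uncased English MobileBERT configuration.
encoder = networks.MobileBERTEncoder(
    word_vocab_size=30522,
    word_embed_size=128,
    type_vocab_size=2,
    max_sequence_length=512,
    num_blocks=24,
    hidden_size=512,
    num_attention_heads=4,
    intermediate_size=512,
    intra_bottleneck_size=128,
    key_query_shared_bottleneck=True,
    num_feedforward_networks=4,
    normalization_type='no_norm')
```
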
## Pre-trained Models
We converted the original TF 1.x pretrained English MobileBERT checkpoint to a
TF 2.x checkpoint, which is compatible with the above implementations.
In addition, we also provide a new multilingual MobileBERT checkpoint trained on
multilingual Wikipedia data. Furthermore, we export the checkpoints as TF-Hub
SavedModels. Please find the details in the following table:

Model | Configuration | Number of Parameters | Training Data | Checkpoint & Vocabulary | TF-Hub SavedModel | Metrics
------------------------------ | :--------------------------------------: | :------------------- | :-----------: | :-----------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :-----:
MobileBERT uncased English | uncased_L-24_H-128_B-512_A-4_F-4_OPT | 25.3 Million | Wiki + Books | [Download](https://storage.cloud.google.com/model_garden_artifacts/official/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz) | [TF-Hub](https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1) | SQuAD v1.1 F1: 90.0, GLUE: 77.7
MobileBERT cased Multilingual | multi_cased_L-24_H-128_B-512_A-4_F-4_OPT | 36 Million | Wiki | [Download](https://storage.cloud.google.com/model_garden_artifacts/official/mobilebert/multi_cased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz) | [TF-Hub](https://tfhub.dev/tensorflow/mobilebert_multi_cased_L-24_H-128_B-512_A-4_F-4_OPT/1) | XNLI (zero-shot): 64.7
### Restoring from Checkpoints
To load the pre-trained MobileBERT checkpoint in your code, please follow the
example below:
```python
import tensorflow as tf
from official.nlp.projects.mobilebert import model_utils
bert_config_file = ...
model_checkpoint_path = ...
bert_config = model_utils.BertConfig.from_json_file(bert_config_file)
# `pretrainer` is an instance of `nlp.modeling.models.BertPretrainerV2`.
pretrainer = model_utils.create_mobilebert_pretrainer(bert_config)
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
# `mobilebert_encoder` is an instance of
# `nlp.modeling.networks.MobileBERTEncoder`.
mobilebert_encoder = pretrainer.encoder_network
```
### Using TF-Hub Models
For usage of the MobileBERT TF-Hub models, please see the TF-Hub pages
([English model](https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1)
or
[Multilingual model](https://tfhub.dev/tensorflow/mobilebert_multi_cased_L-24_H-128_B-512_A-4_F-4_OPT/1)).
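
As a quick reference, the exported SavedModels can typically be loaded as a Keras
layer. The snippet below is a sketch assuming the BERT-style input dictionary
(`input_word_ids`, `input_mask`, `input_type_ids`) produced by the export script in
this project; see the TF-Hub pages above for the authoritative usage.

```python
import tensorflow as tf
import tensorflow_hub as hub

hub_url = ("https://tfhub.dev/tensorflow/"
           "mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1")
encoder = hub.KerasLayer(hub_url, trainable=True)

max_seq_length = 128  # Illustrative value.
inputs = dict(
    input_word_ids=tf.keras.Input(shape=(max_seq_length,), dtype=tf.int32),
    input_mask=tf.keras.Input(shape=(max_seq_length,), dtype=tf.int32),
    input_type_ids=tf.keras.Input(shape=(max_seq_length,), dtype=tf.int32))
outputs = encoder(inputs)
pooled_output = outputs["pooled_output"]      # Shape: [batch_size, 512].
sequence_output = outputs["sequence_output"]  # Shape: [batch_size, max_seq_length, 512].
```
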
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Progressive distillation for MobileBERT student model."""
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp import keras_nlp
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import layers
from official.nlp.modeling import models
@dataclasses.dataclass
class LayerWiseDistillConfig(base_config.Config):
"""Defines the behavior of layerwise distillation."""
num_steps: int = 10000
warmup_steps: int = 0
initial_learning_rate: float = 1.5e-3
end_learning_rate: float = 1.5e-3
decay_steps: int = 10000
hidden_distill_factor: float = 100.0
beta_distill_factor: float = 5000.0
gamma_distill_factor: float = 5.0
if_transfer_attention: bool = True
attention_distill_factor: float = 1.0
if_freeze_previous_layers: bool = False
@dataclasses.dataclass
class PretrainDistillConfig(base_config.Config):
"""Defines the behavior of pretrain distillation."""
num_steps: int = 500000
warmup_steps: int = 10000
initial_learning_rate: float = 1.5e-3
end_learning_rate: float = 1.5e-7
decay_steps: int = 500000
if_use_nsp_loss: bool = True
distill_ground_truth_ratio: float = 0.5
@dataclasses.dataclass
class BertDistillationProgressiveConfig(policies.ProgressiveConfig):
"""Defines the specific distillation behavior."""
if_copy_embeddings: bool = True
layer_wise_distill_config: LayerWiseDistillConfig = LayerWiseDistillConfig()
pretrain_distill_config: PretrainDistillConfig = PretrainDistillConfig()
@dataclasses.dataclass
class BertDistillationTaskConfig(cfg.TaskConfig):
"""Defines the teacher/student model architecture and training data."""
teacher_model: bert.PretrainerConfig = bert.PretrainerConfig(
encoder=encoders.EncoderConfig(type='mobilebert'))
student_model: bert.PretrainerConfig = bert.PretrainerConfig(
encoder=encoders.EncoderConfig(type='mobilebert'))
# The path to the teacher model checkpoint or its directory.
teacher_model_init_checkpoint: str = ''
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
def build_sub_encoder(encoder, target_layer_id):
"""Builds an encoder that only computes first few transformer layers."""
input_ids = encoder.inputs[0]
input_mask = encoder.inputs[1]
type_ids = encoder.inputs[2]
attention_mask = keras_nlp.layers.SelfAttentionMask()(
inputs=input_ids, to_mask=input_mask)
embedding_output = encoder.embedding_layer(input_ids, type_ids)
layer_output = embedding_output
attention_score = None
for layer_idx in range(target_layer_id + 1):
layer_output, attention_score = encoder.transformer_layers[layer_idx](
layer_output, attention_mask, return_attention_scores=True)
return tf.keras.Model(
inputs=[input_ids, input_mask, type_ids],
outputs=[layer_output, attention_score])
class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
"""Distillation language modeling task progressively."""
def __init__(self,
strategy,
progressive: BertDistillationProgressiveConfig,
optimizer_config: optimization.OptimizationConfig,
task_config: BertDistillationTaskConfig,
logging_dir=None):
self._strategy = strategy
self._task_config = task_config
self._progressive_config = progressive
self._optimizer_config = optimizer_config
self._train_data_config = task_config.train_data
self._eval_data_config = task_config.validation_data
self._the_only_train_dataset = None
self._the_only_eval_dataset = None
ratio = progressive.pretrain_distill_config.distill_ground_truth_ratio
if ratio < 0 or ratio > 1:
raise ValueError('distill_ground_truth_ratio has to be within [0, 1].')
# A non-trainable layer for feature normalization for transfer loss
self._layer_norm = tf.keras.layers.LayerNormalization(
axis=-1,
beta_initializer='zeros',
gamma_initializer='ones',
trainable=False)
# Build the teacher and student pretrainer model.
self._teacher_pretrainer = self._build_pretrainer(
self._task_config.teacher_model, name='teacher')
self._student_pretrainer = self._build_pretrainer(
self._task_config.student_model, name='student')
base_task.Task.__init__(
self, params=task_config, logging_dir=logging_dir)
policies.ProgressivePolicy.__init__(self)
def _build_pretrainer(self, pretrainer_cfg: bert.PretrainerConfig, name: str):
"""Builds pretrainer from config and encoder."""
encoder = encoders.build_encoder(pretrainer_cfg.encoder)
if pretrainer_cfg.cls_heads:
cls_heads = [
layers.ClassificationHead(**cfg.as_dict())
for cfg in pretrainer_cfg.cls_heads
]
else:
cls_heads = []
masked_lm = layers.MobileBertMaskedLM(
embedding_table=encoder.get_embedding_table(),
activation=tf_utils.get_activation(pretrainer_cfg.mlm_activation),
initializer=tf.keras.initializers.TruncatedNormal(
stddev=pretrainer_cfg.mlm_initializer_range),
name='cls/predictions')
pretrainer = models.BertPretrainerV2(
encoder_network=encoder,
classification_heads=cls_heads,
customized_masked_lm=masked_lm,
name=name)
return pretrainer
# override policies.ProgressivePolicy
def num_stages(self):
# One stage for each layer, plus an additional stage for pre-training.
return self._task_config.teacher_model.encoder.mobilebert.num_blocks + 1
# override policies.ProgressivePolicy
def num_steps(self, stage_id) -> int:
"""Return the total number of steps in this stage."""
if stage_id + 1 < self.num_stages():
return self._progressive_config.layer_wise_distill_config.num_steps
else:
return self._progressive_config.pretrain_distill_config.num_steps
# override policies.ProgressivePolicy
def get_model(self, stage_id, old_model=None) -> tf.keras.Model:
del old_model
return self.build_model(stage_id)
# override policies.ProgressivePolicy
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
if stage_id + 1 < self.num_stages():
distill_config = self._progressive_config.layer_wise_distill_config
else:
distill_config = self._progressive_config.pretrain_distill_config
params = self._optimizer_config.replace(
learning_rate={
'polynomial': {
'decay_steps':
distill_config.decay_steps,
'initial_learning_rate':
distill_config.initial_learning_rate,
'end_learning_rate':
distill_config.end_learning_rate,
}
},
warmup={
'linear':
{'warmup_steps':
distill_config.warmup_steps,
}
})
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
return optimizer
# override policies.ProgressivePolicy
def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
"""Return Dataset for this stage."""
del stage_id
if self._the_only_train_dataset is None:
self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
self._strategy, self.build_inputs, self._train_data_config)
return self._the_only_train_dataset
# overrides policies.ProgressivePolicy
def get_eval_dataset(self, stage_id):
del stage_id
if self._the_only_eval_dataset is None:
self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
self._strategy, self.build_inputs, self._eval_data_config)
return self._the_only_eval_dataset
# overrides base_task.Task
def build_model(self, stage_id) -> tf.keras.Model:
"""Build teacher/student keras models with outputs for current stage."""
# Freeze the teacher model.
self._teacher_pretrainer.trainable = False
layer_wise_config = self._progressive_config.layer_wise_distill_config
freeze_previous_layers = layer_wise_config.if_freeze_previous_layers
student_encoder = self._student_pretrainer.encoder_network
if stage_id != self.num_stages() - 1:
# Build a model that outputs teacher's and student's transformer outputs.
inputs = student_encoder.inputs
student_sub_encoder = build_sub_encoder(
encoder=student_encoder, target_layer_id=stage_id)
student_output_feature, student_attention_score = student_sub_encoder(
inputs)
teacher_sub_encoder = build_sub_encoder(
encoder=self._teacher_pretrainer.encoder_network,
target_layer_id=stage_id)
teacher_output_feature, teacher_attention_score = teacher_sub_encoder(
inputs)
if freeze_previous_layers:
student_encoder.embedding_layer.trainable = False
for i in range(stage_id):
student_encoder.transformer_layers[i].trainable = False
return tf.keras.Model(
inputs=inputs,
outputs=dict(
student_output_feature=student_output_feature,
student_attention_score=student_attention_score,
teacher_output_feature=teacher_output_feature,
teacher_attention_score=teacher_attention_score))
else:
# Build a model that outputs teacher's and student's MLM/NSP outputs.
inputs = self._student_pretrainer.inputs
student_pretrainer_output = self._student_pretrainer(inputs)
teacher_pretrainer_output = self._teacher_pretrainer(inputs)
# Set all student's transformer blocks to trainable.
if freeze_previous_layers:
student_encoder.embedding_layer.trainable = True
for layer in student_encoder.transformer_layers:
layer.trainable = True
model = tf.keras.Model(
inputs=inputs,
outputs=dict(
student_pretrainer_output=student_pretrainer_output,
teacher_pretrainer_output=teacher_pretrainer_output,
))
# Checkpoint the student encoder which is the goal of distillation.
model.checkpoint_items = self._student_pretrainer.checkpoint_items
return model
# overrides base_task.Task
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining."""
# copy from masked_lm.py for testing
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
return dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_lm,
masked_lm_ids=dummy_lm,
masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
return data_loader_factory.get_data_loader(params).load(input_context)
def _get_distribution_losses(self, teacher, student):
"""Return the beta and gamma distall losses for feature distribution."""
teacher_mean = tf.math.reduce_mean(teacher, axis=-1, keepdims=True)
student_mean = tf.math.reduce_mean(student, axis=-1, keepdims=True)
teacher_var = tf.math.reduce_variance(teacher, axis=-1, keepdims=True)
student_var = tf.math.reduce_variance(student, axis=-1, keepdims=True)
beta_loss = tf.math.squared_difference(student_mean, teacher_mean)
beta_loss = tf.math.reduce_mean(beta_loss, axis=None, keepdims=False)
gamma_loss = tf.math.abs(student_var - teacher_var)
gamma_loss = tf.math.reduce_mean(gamma_loss, axis=None, keepdims=False)
return beta_loss, gamma_loss
def _get_attention_loss(self, teacher_score, student_score):
# Note that the definition of KLDivergence here is a little different from
# the original one (tf.keras.losses.KLDivergence). We adopt this approach
# to stay consistent with the TF1 implementation.
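# (Concretely, the code below computes the cross-entropy
# -sum(softmax(teacher_score) * log_softmax(student_score)), which differs from
# the true KL divergence only by the teacher's entropy, a term that is constant
# with respect to the student.)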
teacher_weight = tf.keras.activations.softmax(teacher_score, axis=-1)
student_log_weight = tf.nn.log_softmax(student_score, axis=-1)
kl_divergence = -(teacher_weight * student_log_weight)
kl_divergence = tf.math.reduce_sum(kl_divergence, axis=-1, keepdims=True)
kl_divergence = tf.math.reduce_mean(kl_divergence, axis=None,
keepdims=False)
return kl_divergence
def build_losses(self, labels, outputs, metrics) -> tf.Tensor:
"""Builds losses and update loss-related metrics for the current stage."""
last_stage = 'student_pretrainer_output' in outputs
# Layer-wise warmup stage
if not last_stage:
distill_config = self._progressive_config.layer_wise_distill_config
teacher_feature = outputs['teacher_output_feature']
student_feature = outputs['student_output_feature']
feature_transfer_loss = tf.keras.losses.mean_squared_error(
self._layer_norm(teacher_feature), self._layer_norm(student_feature))
feature_transfer_loss *= distill_config.hidden_distill_factor
beta_loss, gamma_loss = self._get_distribution_losses(teacher_feature,
student_feature)
beta_loss *= distill_config.beta_distill_factor
gamma_loss *= distill_config.gamma_distill_factor
total_loss = feature_transfer_loss + beta_loss + gamma_loss
if distill_config.if_transfer_attention:
teacher_attention = outputs['teacher_attention_score']
student_attention = outputs['student_attention_score']
attention_loss = self._get_attention_loss(teacher_attention,
student_attention)
attention_loss *= distill_config.attention_distill_factor
total_loss += attention_loss
total_loss /= tf.cast((self._stage_id + 1), tf.float32)
# Last stage to distill pretraining layer.
else:
distill_config = self._progressive_config.pretrain_distill_config
lm_label = labels['masked_lm_ids']
vocab_size = (
self._task_config.student_model.encoder.mobilebert.word_vocab_size)
# Shape: [batch, max_predictions_per_seq, vocab_size]
lm_label = tf.one_hot(indices=lm_label, depth=vocab_size, on_value=1.0,
off_value=0.0, axis=-1, dtype=tf.float32)
lm_label_weights = labels['masked_lm_weights']
gt_ratio = distill_config.distill_ground_truth_ratio
if gt_ratio != 1.0:
teacher_mlm_logits = outputs['teacher_pretrainer_output']['mlm_logits']
teacher_labels = tf.nn.softmax(teacher_mlm_logits, axis=-1)
lm_label = gt_ratio * lm_label + (1-gt_ratio) * teacher_labels
student_pretrainer_output = outputs['student_pretrainer_output']
# Shape: [batch, max_predictions_per_seq, vocab_size]
student_lm_log_probs = tf.nn.log_softmax(
student_pretrainer_output['mlm_logits'], axis=-1)
# Shape: [batch * max_predictions_per_seq]
per_example_loss = tf.reshape(
-tf.reduce_sum(student_lm_log_probs * lm_label, axis=[-1]), [-1])
lm_label_weights = tf.reshape(labels['masked_lm_weights'], [-1])
lm_numerator_loss = tf.reduce_sum(per_example_loss * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
total_loss = mlm_loss
if 'next_sentence_labels' in labels:
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
student_pretrainer_output['next_sentence'], dtype=tf.float32)
sentence_loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels, sentence_outputs, from_logits=True))
total_loss += sentence_loss
# Also update loss-related metrics here, instead of in `process_metrics`.
metrics = dict([(metric.name, metric) for metric in metrics])
if not last_stage:
metrics['feature_transfer_mse'].update_state(feature_transfer_loss)
metrics['beta_transfer_loss'].update_state(beta_loss)
metrics['gamma_transfer_loss'].update_state(gamma_loss)
layer_wise_config = self._progressive_config.layer_wise_distill_config
if layer_wise_config.if_transfer_attention:
metrics['attention_transfer_loss'].update_state(attention_loss)
else:
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels:
metrics['next_sentence_loss'].update_state(sentence_loss)
metrics['total_loss'].update_state(total_loss)
return total_loss
# overrides base_task.Task
def build_metrics(self, training=None):
del training
metrics = [
tf.keras.metrics.Mean(name='feature_transfer_mse'),
tf.keras.metrics.Mean(name='beta_transfer_loss'),
tf.keras.metrics.Mean(name='gamma_transfer_loss'),
tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
tf.keras.metrics.Mean(name='lm_example_loss'),
tf.keras.metrics.Mean(name='total_loss')]
if self._progressive_config.layer_wise_distill_config.if_transfer_attention:
metrics.append(tf.keras.metrics.Mean(name='attention_transfer_loss'))
if self._task_config.train_data.use_next_sentence_label:
metrics.append(tf.keras.metrics.SparseCategoricalAccuracy(
name='next_sentence_accuracy'))
metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
return metrics
# overrides base_task.Task
# process non-loss metrics
def process_metrics(self, metrics, labels, student_pretrainer_output):
metrics = dict([(metric.name, metric) for metric in metrics])
# Final pretrainer layer distillation stage.
if student_pretrainer_output is not None:
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(
labels['masked_lm_ids'], student_pretrainer_output['mlm_logits'],
labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
labels['next_sentence_labels'],
student_pretrainer_output['next_sentence'])
# overrides base_task.Task
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(
labels=inputs,
outputs=outputs,
metrics=metrics)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# TODO(b/154564893): enable loss scaling.
# scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
# get trainable variables for current stage
tvars = model.trainable_variables
last_stage = 'student_pretrainer_output' in outputs
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(
metrics, inputs,
outputs['student_pretrainer_output'] if last_stage else None)
return {self.loss: loss}
# overrides base_task.Task
def validation_step(self, inputs, model: tf.keras.Model, metrics):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
outputs = model(inputs, training=False)
# Computes per-replica loss.
loss = self.build_losses(labels=inputs, outputs=outputs, metrics=metrics)
last_stage = 'student_pretrainer_output' in outputs
self.process_metrics(
metrics, inputs,
outputs['student_pretrainer_output'] if last_stage else None)
return {self.loss: loss}
@property
def cur_checkpoint_items(self):
"""Checkpoints for model, stage_id, optimizer for preemption handling."""
return dict(
stage_id=self._stage_id,
volatiles=self._volatiles,
student_pretrainer=self._student_pretrainer,
teacher_pretrainer=self._teacher_pretrainer,
encoder=self._student_pretrainer.encoder_network)
def initialize(self, model):
"""Loads teacher's pretrained checkpoint and copy student's embedding."""
# This function will be called when no checkpoint found for the model,
# i.e., when the training starts (not preemption case).
# The weights of teacher pretrainer and student pretrainer will be
# initialized, rather than the passed-in `model`.
del model
logging.info('Begin to load checkpoint for teacher pretrainer model.')
ckpt_dir_or_file = self._task_config.teacher_model_init_checkpoint
if not ckpt_dir_or_file:
raise ValueError('`teacher_model_init_checkpoint` is not specified.')
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Makes sure the teacher pretrainer variables are created.
_ = self._teacher_pretrainer(self._teacher_pretrainer.inputs)
teacher_checkpoint = tf.train.Checkpoint(
**self._teacher_pretrainer.checkpoint_items)
teacher_checkpoint.read(ckpt_dir_or_file).assert_existing_objects_matched()
logging.info('Begin to copy word embedding from teacher model to student.')
teacher_encoder = self._teacher_pretrainer.encoder_network
student_encoder = self._student_pretrainer.encoder_network
embedding_weights = teacher_encoder.embedding_layer.get_weights()
student_encoder.embedding_layer.set_weights(embedding_weights)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.mobilebert.distillation."""
import os
from absl import logging
import tensorflow as tf
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.projects.mobilebert import distillation
class DistillationTest(tf.test.TestCase):
def setUp(self):
super(DistillationTest, self).setUp()
# using small model for testing
self.model_block_num = 2
self.task_config = distillation.BertDistillationTaskConfig(
teacher_model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type='mobilebert',
mobilebert=encoders.MobileBertEncoderConfig(
num_blocks=self.model_block_num)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=256,
num_classes=2,
dropout_rate=0.1,
name='next_sentence')
],
mlm_activation='gelu'),
student_model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type='mobilebert',
mobilebert=encoders.MobileBertEncoderConfig(
num_blocks=self.model_block_num)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=256,
num_classes=2,
dropout_rate=0.1,
name='next_sentence')
],
mlm_activation='relu'),
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path='dummy',
max_predictions_per_seq=76,
seq_length=512,
global_batch_size=10),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
input_path='dummy',
max_predictions_per_seq=76,
seq_length=512,
global_batch_size=10))
# set only 1 step for each stage
progressive_config = distillation.BertDistillationProgressiveConfig()
progressive_config.layer_wise_distill_config.num_steps = 1
progressive_config.pretrain_distill_config.num_steps = 1
optimization_config = optimization.OptimizationConfig(
optimizer=optimization.OptimizerConfig(
type='lamb',
lamb=optimization.LAMBConfig(
weight_decay_rate=0.0001,
exclude_from_weight_decay=[
'LayerNorm', 'layer_norm', 'bias', 'no_norm'
])),
learning_rate=optimization.LrConfig(
type='polynomial',
polynomial=optimization.PolynomialLrConfig(
initial_learning_rate=1.5e-3,
decay_steps=10000,
end_learning_rate=1.5e-3)),
warmup=optimization.WarmupConfig(
type='linear',
linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))
self.exp_config = cfg.ExperimentConfig(
task=self.task_config,
trainer=prog_trainer_lib.ProgressiveTrainerConfig(
progressive=progressive_config,
optimizer_config=optimization_config))
# Create a teacher model checkpoint.
teacher_encoder = encoders.build_encoder(
self.task_config.teacher_model.encoder)
pretrainer_config = self.task_config.teacher_model
if pretrainer_config.cls_heads:
teacher_cls_heads = [
layers.ClassificationHead(**cfg.as_dict())
for cfg in pretrainer_config.cls_heads
]
else:
teacher_cls_heads = []
masked_lm = layers.MobileBertMaskedLM(
embedding_table=teacher_encoder.get_embedding_table(),
activation=tf_utils.get_activation(pretrainer_config.mlm_activation),
initializer=tf.keras.initializers.TruncatedNormal(
stddev=pretrainer_config.mlm_initializer_range),
name='cls/predictions')
teacher_pretrainer = models.BertPretrainerV2(
encoder_network=teacher_encoder,
classification_heads=teacher_cls_heads,
customized_masked_lm=masked_lm)
# The model variables will be created after the forward call.
_ = teacher_pretrainer(teacher_pretrainer.inputs)
teacher_pretrainer_ckpt = tf.train.Checkpoint(
**teacher_pretrainer.checkpoint_items)
teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
teacher_pretrainer_ckpt.save(teacher_ckpt_path)
self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()
def test_task(self):
bert_distillation_task = distillation.BertDistillationTask(
strategy=tf.distribute.get_strategy(),
progressive=self.exp_config.trainer.progressive,
optimizer_config=self.exp_config.trainer.optimizer_config,
task_config=self.task_config)
metrics = bert_distillation_task.build_metrics()
train_dataset = bert_distillation_task.get_train_dataset(stage_id=0)
train_iterator = iter(train_dataset)
eval_dataset = bert_distillation_task.get_eval_dataset(stage_id=0)
eval_iterator = iter(eval_dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
# test train/val step for all stages, including the last pretraining stage
for stage in range(self.model_block_num + 1):
step = stage
bert_distillation_task.update_pt_stage(step)
model = bert_distillation_task.get_model(stage, None)
bert_distillation_task.initialize(model)
bert_distillation_task.train_step(next(train_iterator), model, optimizer,
metrics=metrics)
bert_distillation_task.validation_step(next(eval_iterator), model,
metrics=metrics)
logging.info('begin to save and load model checkpoint')
ckpt = tf.train.Checkpoint(model=model)
ckpt.save(self.get_temp_dir())
if __name__ == '__main__':
tf.test.main()
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 512
hidden_activation: relu
hidden_dropout_prob: 0.0
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 128
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 4
normalization_type: no_norm
classifier_activation: false
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 4096
hidden_activation: gelu
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 1024
initializer_range: 0.02
key_query_shared_bottleneck: false
num_feedforward_networks: 1
normalization_type: layer_norm
classifier_activation: false
task:
train_data:
drop_remainder: true
global_batch_size: 2048
input_path: ""
is_training: true
max_predictions_per_seq: 20
seq_length: 512
use_next_sentence_label: true
use_position_id: false
validation_data:
drop_remainder: true
global_batch_size: 2048
input_path: ""
is_training: false
max_predictions_per_seq: 20
seq_length: 512
use_next_sentence_label: true
use_position_id: false
teacher_model:
cls_heads: []
mlm_activation: gelu
mlm_initializer_range: 0.02
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 4096
hidden_activation: gelu
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 1024
initializer_range: 0.02
key_query_shared_bottleneck: false
num_feedforward_networks: 1
normalization_type: layer_norm
classifier_activation: false
student_model:
cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.0, inner_dim: 512,
name: next_sentence, num_classes: 2}]
mlm_activation: relu
mlm_initializer_range: 0.02
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 512
hidden_activation: relu
hidden_dropout_prob: 0.0
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 128
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 4
normalization_type: no_norm
classifier_activation: false
teacher_model_init_checkpoint: ""
trainer:
progressive:
if_copy_embeddings: true
layer_wise_distill_config:
num_steps: 10000
pretrain_distill_config:
num_steps: 500000
decay_steps: 500000
train_steps: 740000
max_to_keep: 10
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export the MobileBERT encoder model as a TF-Hub SavedModel."""
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.projects.mobilebert import model_utils
FLAGS = flags.FLAGS
flags.DEFINE_string(
"bert_config_file", None,
"Bert configuration file to define core mobilebert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", True, "Whether to lowercase.")
def create_mobilebert_model(bert_config):
"""Creates a model for exporting to tfhub."""
pretrainer = model_utils.create_mobilebert_pretrainer(bert_config)
encoder = pretrainer.encoder_network
encoder_inputs_dict = {x.name: x for x in encoder.inputs}
encoder_output_dict = encoder(encoder_inputs_dict)
# For interchangeability with other text representations,
# add "default" as an alias for MobileBERT's whole-input reptesentations.
encoder_output_dict["default"] = encoder_output_dict["pooled_output"]
core_model = tf.keras.Model(
inputs=encoder_inputs_dict, outputs=encoder_output_dict)
pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
mlm_model = tf.keras.Model(
inputs=pretrainer_inputs_dict, outputs=pretrainer_output_dict)
# Set `_auto_track_sub_layers` to False, so that the additional weights
# from `mlm` sub-object will not be included in the core model.
# TODO(b/169210253): Use public API after the bug is resolved.
core_model._auto_track_sub_layers = False # pylint: disable=protected-access
core_model.mlm = mlm_model
return core_model, pretrainer
def export_bert_tfhub(bert_config, model_checkpoint_path, hub_destination,
vocab_file, do_lower_case):
"""Restores a tf.keras.Model and saves for TF-Hub."""
core_model, pretrainer = create_mobilebert_model(bert_config)
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
logging.info("Begin to load model")
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
logging.info("Loading model finished")
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
logging.info("Begin to save files for tfhub at %s", hub_destination)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
logging.info("tfhub files exported!")
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
bert_config = model_utils.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
FLAGS.vocab_file, FLAGS.do_lower_case)
if __name__ == "__main__":
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint converter for Mobilebert."""
import copy
import json
import tensorflow.compat.v1 as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.modeling import networks
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
embedding_size=None,
trigram_input=False,
use_bottleneck=False,
intra_bottleneck_size=None,
use_bottleneck_attention=False,
key_query_shared_bottleneck=False,
num_feedforward_networks=1,
normalization_type="layer_norm",
classifier_activation=True):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
embedding_size: The size of the token embedding.
trigram_input: Use a convolution of trigram as input.
use_bottleneck: Use the bottleneck/inverted-bottleneck structure in BERT.
intra_bottleneck_size: The hidden size in the bottleneck.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation.
key_query_shared_bottleneck: Use the same linear transformation for
the query and key in the bottleneck.
num_feedforward_networks: Number of FFNs in a block.
normalization_type: The normalization type in BERT.
classifier_activation: Whether to use the tanh activation for the final
representation of the [CLS] token in fine-tuning.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.embedding_size = embedding_size
self.trigram_input = trigram_input
self.use_bottleneck = use_bottleneck
self.intra_bottleneck_size = intra_bottleneck_size
self.use_bottleneck_attention = use_bottleneck_attention
self.key_query_shared_bottleneck = key_query_shared_bottleneck
self.num_feedforward_networks = num_feedforward_networks
self.normalization_type = normalization_type
self.classifier_activation = classifier_activation
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in json_object.items():
config.__dict__[key] = value
if config.embedding_size is None:
config.embedding_size = config.hidden_size
if config.intra_bottleneck_size is None:
config.intra_bottleneck_size = config.hidden_size
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def create_mobilebert_pretrainer(bert_config):
"""Creates a BertPretrainerV2 that wraps MobileBERTEncoder model."""
mobilebert_encoder = networks.MobileBERTEncoder(
word_vocab_size=bert_config.vocab_size,
word_embed_size=bert_config.embedding_size,
type_vocab_size=bert_config.type_vocab_size,
max_sequence_length=bert_config.max_position_embeddings,
num_blocks=bert_config.num_hidden_layers,
hidden_size=bert_config.hidden_size,
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
intermediate_act_fn=tf_utils.get_activation(bert_config.hidden_act),
hidden_dropout_prob=bert_config.hidden_dropout_prob,
attention_probs_dropout_prob=bert_config.attention_probs_dropout_prob,
intra_bottleneck_size=bert_config.intra_bottleneck_size,
initializer_range=bert_config.initializer_range,
use_bottleneck_attention=bert_config.use_bottleneck_attention,
key_query_shared_bottleneck=bert_config.key_query_shared_bottleneck,
num_feedforward_networks=bert_config.num_feedforward_networks,
normalization_type=bert_config.normalization_type,
classifier_activation=bert_config.classifier_activation)
masked_lm = layers.MobileBertMaskedLM(
embedding_table=mobilebert_encoder.get_embedding_table(),
activation=tf_utils.get_activation(bert_config.hidden_act),
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
name="cls/predictions")
pretrainer = models.BertPretrainerV2(
encoder_network=mobilebert_encoder, customized_masked_lm=masked_lm)
# Makes sure the pretrainer variables are created.
_ = pretrainer(pretrainer.inputs)
return pretrainer
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
"""Creating the task and start trainer."""
import pprint
from absl import app
from absl import flags
from absl import logging
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import config_definitions as cfg
from official.core import train_utils
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling import performance
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.mobilebert import distillation
FLAGS = flags.FLAGS
optimization_config = optimization.OptimizationConfig(
optimizer=optimization.OptimizerConfig(
type='lamb',
lamb=optimization.LAMBConfig(
weight_decay_rate=0.01,
exclude_from_weight_decay=['LayerNorm', 'bias', 'norm'],
clipnorm=1.0)),
learning_rate=optimization.LrConfig(
type='polynomial',
polynomial=optimization.PolynomialLrConfig(
initial_learning_rate=1.5e-3,
decay_steps=10000,
end_learning_rate=1.5e-3)),
warmup=optimization.WarmupConfig(
type='linear',
linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))
# copy from progressive/utils.py due to the private visibility issue.
def config_override(params, flags_obj):
"""Override ExperimentConfig according to flags."""
# Change runtime.tpu to the real tpu.
params.override({
'runtime': {
'tpu': flags_obj.tpu,
}
})
# Get the first level of override from `--config_file`.
# `--config_file` is typically used as a template that specifies the common
# override for a particular experiment.
for config_file in flags_obj.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
# Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if flags_obj.params_override:
params = hyperparams.override_params_dict(
params, flags_obj.params_override, is_strict=True)
params.validate()
params.lock()
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s', pp.pformat(params.as_dict()))
model_dir = flags_obj.model_dir
if 'train' in flags_obj.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
return params
def get_exp_config():
"""Get ExperimentConfig."""
params = cfg.ExperimentConfig(
task=distillation.BertDistillationTaskConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False)),
trainer=prog_trainer_lib.ProgressiveTrainerConfig(
progressive=distillation.BertDistillationProgressiveConfig(),
optimizer_config=optimization_config,
train_steps=740000,
checkpoint_interval=20000))
return config_override(params, FLAGS)
def main(_):
logging.info('Parsing config files...')
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = get_exp_config()
# Sets the mixed-precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = distillation.BertDistillationTask(
strategy=distribution_strategy,
progressive=params.trainer.progressive,
optimizer_config=params.trainer.optimizer_config,
task_config=params.task)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=FLAGS.model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint converter for Mobilebert."""
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from official.nlp.projects.mobilebert import model_utils
FLAGS = flags.FLAGS
flags.DEFINE_string(
"bert_config_file", None,
"Bert configuration file to define core mobilebert layers.")
flags.DEFINE_string("tf1_checkpoint_path", None,
"Path to load tf1 checkpoint.")
flags.DEFINE_string("tf2_checkpoint_path", None,
"Path to save tf2 checkpoint.")
flags.DEFINE_boolean("use_model_prefix", False,
("If use model name as prefix for variables. Turn this"
"flag on when the converted checkpoint is used for model"
"in subclass implementation, which uses the model name as"
"prefix for all variable names."))
def _bert_name_replacement(var_name, name_replacements):
"""Gets the variable name replacement."""
for src_pattern, tgt_pattern in name_replacements:
if src_pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(src_pattern, tgt_pattern)
logging.info("Converted: %s --> %s", old_var_name, var_name)
return var_name
def _has_exclude_patterns(name, exclude_patterns):
"""Checks if a string contains substrings that match patterns to exclude."""
for p in exclude_patterns:
if p in name:
return True
return False
def _get_permutation(name, permutations):
"""Checks whether a variable requires transposition by pattern matching."""
for src_pattern, permutation in permutations:
if src_pattern in name:
logging.info("Permuted: %s --> %s", name, permutation)
return permutation
return None
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "attention/attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "attention/attention_output/bias" in name:
return shape
patterns = [
"attention/query", "attention/value", "attention/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
def convert(checkpoint_from_path,
checkpoint_to_path,
name_replacements,
permutations,
bert_config,
exclude_patterns=None):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
name_replacements: A list of tuples of the form (match_str, replace_str)
describing variable names to adjust.
permutations: A list of tuples of the form (match_str, permutation)
describing permutations to apply to given variables. Note that match_str
should match the original variable name, not the replaced one.
bert_config: A `BertConfig` to create the core model.
exclude_patterns: A list of string patterns to exclude variables from
checkpoint conversion.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
last_ffn_layer_id = str(bert_config.num_feedforward_networks - 1)
name_replacements = [
(x[0], x[1].replace("LAST_FFN_LAYER_ID", last_ffn_layer_id))
for x in name_replacements
]
output_dir, _ = os.path.split(checkpoint_to_path)
tf.io.gfile.makedirs(output_dir)
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
with tf.Graph().as_default():
logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
new_variable_map = {}
conversion_map = {}
for var_name in name_shape_map:
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
continue
# Get the original tensor data.
tensor = reader.get_tensor(var_name)
# Look up the new variable name, if any.
new_var_name = _bert_name_replacement(var_name, name_replacements)
# See if we need to reshape the underlying tensor.
new_shape = None
if bert_config.num_attention_heads > 0:
new_shape = _get_new_shape(new_var_name, tensor.shape,
bert_config.num_attention_heads)
if new_shape:
logging.info("Veriable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
# See if we need to permute the underlying tensor.
permutation = _get_permutation(var_name, permutations)
if permutation:
tensor = np.transpose(tensor, permutation)
# Create a new variable with the possibly-reshaped or transposed tensor.
var = tf.Variable(tensor, name=var_name)
# Save the variable into the new variable map.
new_variable_map[new_var_name] = var
# Keep a list of converter variables for sanity checking.
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
saver = tf.train.Saver(new_variable_map)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
logging.info("Writing checkpoint_to_path %s", temporary_checkpoint)
saver.save(sess, temporary_checkpoint, write_meta_graph=False)
logging.info("Summary:")
logging.info("Converted %d variable name(s).", len(new_variable_map))
logging.info("Converted: %s", str(conversion_map))
mobilebert_model = model_utils.create_mobilebert_pretrainer(bert_config)
create_v2_checkpoint(
mobilebert_model, temporary_checkpoint, checkpoint_to_path)
# Clean up the temporary checkpoint, if it exists.
try:
tf.io.gfile.rmtree(temporary_checkpoint_dir)
except tf.errors.OpError:
# If it doesn't exist, we don't need to clean it up; continue.
pass
def create_v2_checkpoint(model, src_checkpoint, output_path):
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager mode to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
checkpoint.save(output_path)
_NAME_REPLACEMENT = [
# prefix path replacement
("bert/", "mobile_bert_encoder/"),
("encoder/layer_", "transformer_layer_"),
# embedding layer
("embeddings/embedding_transformation",
"mobile_bert_embedding/embedding_projection"),
("embeddings/position_embeddings",
"mobile_bert_embedding/position_embedding/embeddings"),
("embeddings/token_type_embeddings",
"mobile_bert_embedding/type_embedding/embeddings"),
("embeddings/word_embeddings",
"mobile_bert_embedding/word_embedding/embeddings"),
("embeddings/FakeLayerNorm", "mobile_bert_embedding/embedding_norm"),
("embeddings/LayerNorm", "mobile_bert_embedding/embedding_norm"),
# attention layer
("attention/output/dense", "attention/attention_output"),
("attention/output/FakeLayerNorm", "attention/norm"),
("attention/output/LayerNorm", "attention/norm"),
("attention/self", "attention"),
# input bottleneck
("bottleneck/input/dense", "bottleneck_input/dense"),
("bottleneck/input/FakeLayerNorm", "bottleneck_input/norm"),
("bottleneck/input/LayerNorm", "bottleneck_input/norm"),
("bottleneck/attention/dense", "kq_shared_bottleneck/dense"),
("bottleneck/attention/FakeLayerNorm", "kq_shared_bottleneck/norm"),
("bottleneck/attention/LayerNorm", "kq_shared_bottleneck/norm"),
# ffn layer
("ffn_layer_0/output/dense", "ffn_layer_0/output_dense"),
("ffn_layer_1/output/dense", "ffn_layer_1/output_dense"),
("ffn_layer_2/output/dense", "ffn_layer_2/output_dense"),
("output/dense", "ffn_layer_LAST_FFN_LAYER_ID/output_dense"),
("ffn_layer_0/output/FakeLayerNorm", "ffn_layer_0/norm"),
("ffn_layer_0/output/LayerNorm", "ffn_layer_0/norm"),
("ffn_layer_1/output/FakeLayerNorm", "ffn_layer_1/norm"),
("ffn_layer_1/output/LayerNorm", "ffn_layer_1/norm"),
("ffn_layer_2/output/FakeLayerNorm", "ffn_layer_2/norm"),
("ffn_layer_2/output/LayerNorm", "ffn_layer_2/norm"),
("output/FakeLayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
("output/LayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
("ffn_layer_0/intermediate/dense", "ffn_layer_0/intermediate_dense"),
("ffn_layer_1/intermediate/dense", "ffn_layer_1/intermediate_dense"),
("ffn_layer_2/intermediate/dense", "ffn_layer_2/intermediate_dense"),
("intermediate/dense", "ffn_layer_LAST_FFN_LAYER_ID/intermediate_dense"),
# output bottleneck
("output/bottleneck/FakeLayerNorm", "bottleneck_output/norm"),
("output/bottleneck/LayerNorm", "bottleneck_output/norm"),
("output/bottleneck/dense", "bottleneck_output/dense"),
# pooler layer
("pooler/dense", "pooler"),
# MLM layer
("cls/predictions", "bert/cls/predictions"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias")
]
_EXCLUDE_PATTERNS = ["cls/seq_relationship", "global_step"]
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
if not FLAGS.use_model_prefix:
_NAME_REPLACEMENT[0] = ("bert/", "")
bert_config = model_utils.BertConfig.from_json_file(FLAGS.bert_config_file)
convert(FLAGS.tf1_checkpoint_path,
FLAGS.tf2_checkpoint_path,
_NAME_REPLACEMENT,
[],
bert_config,
_EXCLUDE_PATTERNS)
if __name__ == "__main__":
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions."""
import numpy as np
def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
"""Generate consistent fake integer input sequences."""
np.random.seed(seed)
fake_input = []
for _ in range(batch_size):
fake_input.append([])
for _ in range(seq_len):
fake_input[-1].append(np.random.randint(0, vocab_size))
fake_input = np.asarray(fake_input)
return fake_input
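# Example usage (a sketch; the numbers are illustrative):
#   fake_ids = generate_fake_input(batch_size=2, seq_len=8, vocab_size=100, seed=0)
#   # fake_ids is an integer array of shape (2, 8) with values in [0, 100).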