"packaging/vscode:/vscode.git/clone" did not exist on "0a701058b432dd602bba3461866bfb3c3a352e04"
Unverified commit 7a45b513, authored by Vishnu Banna, committed by GitHub

Merge branch 'tensorflow:master' into exp_pr2

parents 54115e16 12bbefce
@@ -27,7 +27,7 @@ task:
   intermediate_size: 3072
   max_position_embeddings: 512
   num_attention_heads: 12
-  num_layers: 6
+  num_layers: 12
   type_vocab_size: 2
   vocab_size: 30522
   train_data:
@@ -39,6 +39,7 @@ task:
   seq_length: 512
   use_next_sentence_label: false
   use_position_id: false
+  cycle_length: 8
   validation_data:
   drop_remainder: true
   global_batch_size: 256
...
@@ -39,6 +39,7 @@ task:
   seq_length: 512
   use_next_sentence_label: false
   use_position_id: false
+  cycle_length: 8
   validation_data:
   drop_remainder: true
   global_batch_size: 256
...
@@ -51,9 +51,7 @@ class TeamsPretrainerConfig(base_config.Config):
 @gin.configurable
-def get_encoder(bert_config,
-                embedding_network=None,
-                hidden_layers=layers.Transformer):
+def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
   """Gets a 'EncoderScaffold' object.
   Args:
@@ -85,7 +83,9 @@ def get_encoder(bert_config,
           stddev=bert_config.initializer_range),
   )
   if embedding_network is None:
-    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
+    embedding_network = networks.PackedSequenceEmbedding
+  if hidden_layers is None:
+    hidden_layers = layers.Transformer
   kwargs = dict(
       embedding_cfg=embedding_cfg,
       embedding_cls=embedding_network,
...
# MobileBERT-EdgeTPU
<figure align="center">
<img width=70% src=https://storage.googleapis.com/tf_model_garden/models/edgetpu/images/readme-mobilebert.png>
<figcaption>Performance of MobileBERT-EdgeTPU models on the SQuAD v1.1 dataset.</figcaption>
</figure>
Note: For the baseline MobileBERT float model, NNAPI delegates part of the
compute ops to the CPU, which makes the latency much higher.
Note: The accuracy numbers for BERT_base and BERT_large are taken from the
published [training results](https://arxiv.org/abs/1810.04805). These models
are too large to run on device.
Deploying low-latency, high-quality transformer-based language models on device
is highly desirable and can benefit multiple applications such as automatic
speech recognition (ASR), translation, sentence autocompletion, and even some
vision tasks. By co-designing the neural networks with the Edge TPU hardware
accelerator in the Google Tensor SoC, we have built EdgeTPU-customized
MobileBERT models that deliver datacenter-level model quality while
outperforming the baseline MobileBERT in latency.
We set up our model architecture search space based on
[MobileBERT](https://arxiv.org/abs/2004.02984) and leverage AutoML algorithms to
find models with up to 2x better hardware utilization. The higher utilization
allows us to run larger, more accurate models on chip while still beating the
baseline MobileBERT's latency. We built a customized distillation training
pipeline and performed an exhaustive hyperparameter search (e.g., learning rate,
dropout ratio) to achieve the best accuracy. As shown in the figure above, the
quantized MobileBERT-EdgeTPU models establish a new Pareto frontier for question
answering and also exceed the accuracy of the float BERT_base model, which is
over 400 MB and too large to run on edge devices.
We also observed that, unlike most vision models, the accuracy of
MobileBERT/MobileBERT-EdgeTPU drops significantly with plain post-training
quantization (PTQ) or quantization-aware training (QAT). Model modifications,
such as clipping the mask value, are necessary to retain the accuracy of a
quantized model. Therefore, as an alternative to the quantized models, we also
provide a set of Edge TPU-friendly float models that still produce a marginally
better roofline than the baseline quantized MobileBERT. Notably, the float
MobileBERT-EdgeTPU-M model reaches accuracy close to that of BERT_large, which
is 1.3 GB in float precision. Quantization thus becomes an optional optimization
rather than a prerequisite, which can unblock use cases where quantization is
infeasible or introduces a large accuracy drop, and can potentially reduce
time-to-market.
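For reference, the "plain PTQ" mentioned above corresponds to the standard
TFLite post-training quantization flow, sketched below. This is a generic
illustration rather than the exact export path used in this project (see
`serving/export_tflite_squad` for that); the SavedModel and output paths are
placeholders, and the snippet does not include the model modifications (such as
mask-value clipping) discussed above.

```python
import tensorflow as tf

# Placeholder paths; not files shipped with this project.
saved_model_dir = '/tmp/mobilebert_edgetpu_squad_saved_model'
output_tflite_path = '/tmp/mobilebert_edgetpu_squad_ptq.tflite'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Plain post-training quantization: weights are quantized to int8 while
# activations are handled dynamically at runtime.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open(output_tflite_path, 'wb') as f:
  f.write(tflite_model)
```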
## Pre-trained Models
Model name | # Parameters | # Ops | MLM | Checkpoint | TFhub link
--------------------- | :----------: | :----: | :---: | :---: | :--------:
MobileBERT-EdgeTPU-M | 50.9M | 18.8e9 | 73.8% | WIP | WIP
MobileBERT-EdgeTPU-S | 38.3M | 14.0e9 | 72.8% | WIP | WIP
MobileBERT-EdgeTPU-XS | 27.1M | 9.4e9 | 71.2% | WIP | WIP
### Restoring from Checkpoints
To load the pre-trained MobileBERT checkpoint in your code, please follow the
example below or check the `serving/export_tflite_squad` module:
```python
import tensorflow as tf
from official.nlp.projects.mobilebert_edgetpu import model_builder
from official.nlp.projects.mobilebert_edgetpu import params

bert_config_file = ...
model_checkpoint_path = ...

# Set up the experiment params and load the configs from file/files.
experiment_params = params.EdgeTPUBERTCustomParams()

# Change the input mask type to tf.float32 to avoid an additional casting op.
experiment_params.student_model.encoder.mobilebert.input_mask_dtype = 'float32'

pretrainer_model = model_builder.build_bert_pretrainer(
    experiment_params.student_model,
    name='pretrainer',
    quantization_friendly=True)

checkpoint_dict = {'model': pretrainer_model}
checkpoint = tf.train.Checkpoint(**checkpoint_dict)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
```
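The example above builds the default parameter set; experiment configurations
can also be loaded from a YAML file (as the unit tests do) and then overridden
field by field. A minimal sketch, continuing the example above and assuming a
YAML file with the same structure as the configs in this project (the path
below is a placeholder):

```python
config_path = '/path/to/mobilebert_edgetpu_m.yaml'  # placeholder path
experiment_params = params.EdgeTPUBERTCustomParams.from_yaml(config_path)
# Individual fields can still be overridden after loading, for example:
experiment_params.eval_dataset.global_batch_size = 32
```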
### Use TF-Hub models
TODO(longy): Update with instructions to use tf-hub models
@@ -12,8 +12,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Masked language model network."""
-from official.nlp.modeling import layers
-MaskedLM = layers.MaskedLM
@@ -12,5 +12,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Keras-NLP layers package definition."""
-from official.nlp.keras_nlp.encoders.bert_encoder import BertEncoder
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datastructures for all the configurations for MobileBERT-EdgeTPU training."""
import dataclasses
from typing import Optional
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.data import pretrain_dataloader
DatasetParams = pretrain_dataloader.BertPretrainDataConfig
PretrainerModelParams = bert.PretrainerConfig
@dataclasses.dataclass
class OrbitParams(base_config.Config):
"""Parameters that setup Orbit training/evaluation pipeline.
Attributes:
mode: Orbit controller mode, can be 'train', 'train_and_evaluate', or
'evaluate'.
steps_per_loop: The number of steps to run in each inner loop of training.
total_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If -1, the
evaluation runs over the entire evaluation dataset.
eval_interval: The number of training steps to run between evaluations. If
set, training will always stop every `eval_interval` steps, even if this
results in a shorter inner loop than specified by the `steps_per_loop`
setting. If None, evaluation will only be performed after training is
complete.
"""
mode: str = 'train'
steps_per_loop: int = 1000
total_steps: int = 1000000
eval_steps: int = -1
eval_interval: Optional[int] = None
@dataclasses.dataclass
class OptimizerParams(optimization.OptimizationConfig):
"""Optimizer parameters for MobileBERT-EdgeTPU."""
optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
type='adamw',
adamw=optimization.AdamWeightDecayConfig(
weight_decay_rate=0.01,
exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias']))
learning_rate: optimization.LrConfig = optimization.LrConfig(
type='polynomial',
polynomial=optimization.PolynomialLrConfig(
initial_learning_rate=1e-4,
decay_steps=1000000,
end_learning_rate=0.0))
warmup: optimization.WarmupConfig = optimization.WarmupConfig(
type='polynomial',
polynomial=optimization.PolynomialWarmupConfig(warmup_steps=10000))
@dataclasses.dataclass
class RuntimeParams(base_config.Config):
"""Parameters that set up the training runtime.
TODO(longy): Can reuse the Runtime Config in:
official/core/config_definitions.py
Attributes:
distribution_strategy: Keras distribution strategy
use_gpu: Whether to use GPU
use_tpu: Whether to use TPU
num_gpus: Number of gpus to use for training
num_workers: Number of parallel workers
tpu_address: The bns address of the TPU to use.
"""
distribution_strategy: str = 'off'
num_gpus: Optional[int] = 0
all_reduce_alg: Optional[str] = None
num_workers: int = 1
tpu_address: str = ''
use_gpu: Optional[bool] = None
use_tpu: Optional[bool] = None
@dataclasses.dataclass
class LayerWiseDistillationParams(base_config.Config):
"""Define the behavior of layer-wise distillation.
Layer-wise distillation is an optional step in which knowledge is transferred
layer by layer across all the transformer layers. End-to-end distillation is
performed after layer-wise distillation if the number of layer-wise
distillation steps is not zero.
"""
num_steps: int = 10000
warmup_steps: int = 10000
initial_learning_rate: float = 1.5e-3
end_learning_rate: float = 1.5e-3
decay_steps: int = 10000
hidden_distill_factor: float = 100.0
beta_distill_factor: float = 5000.0
gamma_distill_factor: float = 5.0
attention_distill_factor: float = 1.0
@dataclasses.dataclass
class EndToEndDistillationParams(base_config.Config):
"""Define the behavior of end2end pretrainer distillation."""
num_steps: int = 580000
warmup_steps: int = 20000
initial_learning_rate: float = 1.5e-3
end_learning_rate: float = 1.5e-7
decay_steps: int = 580000
distill_ground_truth_ratio: float = 0.5
@dataclasses.dataclass
class EdgeTPUBERTCustomParams(base_config.Config):
"""EdgeTPU-BERT custom params.
Attributes:
train_dataset: An instance of the DatasetParams.
eval_dataset: An instance of the DatasetParams.
teacher_model: An instance of the PretrainerModelParams. If None, then the
student model is trained independently without distillation.
student_model: An instance of the PretrainerModelParams
teacher_model_init_checkpoint: Path for the teacher model init checkpoint.
student_model_init_checkpoint: Path for the student model init checkpoint.
layer_wise_distillation: Distillation config for the layer-wise step.
end_to_end_distillation: Distillation config for the end2end step.
optimizer: An instance of the OptimizerParams.
runtime: An instance of the RuntimeParams.
learning_rate: An instance of the LearningRateParams.
orbit_config: An instance of the OrbitParams.
distill_ground_truth_ratio: A float number representing the ratio between
distillation output and ground truth.
"""
train_datasest: DatasetParams = DatasetParams()
eval_dataset: DatasetParams = DatasetParams()
teacher_model: Optional[PretrainerModelParams] = PretrainerModelParams()
student_model: PretrainerModelParams = PretrainerModelParams()
teacher_model_init_checkpoint: str = ''
student_model_init_checkpoint: str = ''
layer_wise_distillation: LayerWiseDistillationParams = (
LayerWiseDistillationParams())
end_to_end_distillation: EndToEndDistillationParams = (
EndToEndDistillationParams())
optimizer: OptimizerParams = OptimizerParams()
runtime: RuntimeParams = RuntimeParams()
orbit_config: OrbitParams = OrbitParams()
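For reference, a short sketch of how these dataclasses compose in practice: the
trainer below (`mobilebert_edgetpu_trainer.py`) rebuilds the optimizer for each
distillation phase by overriding the learning-rate and warmup sections of
`OptimizerParams` with values from the corresponding distillation config. The
snippet mirrors that usage; the variable names are illustrative.

```python
# Illustrative sketch of the override pattern used by the trainer below.
experiment_params = EdgeTPUBERTCustomParams()
distill_config = experiment_params.layer_wise_distillation
optimizer_config = experiment_params.optimizer.replace(
    learning_rate={
        'polynomial': {
            'decay_steps': distill_config.decay_steps,
            'initial_learning_rate': distill_config.initial_learning_rate,
            'end_learning_rate': distill_config.end_learning_rate,
        }
    },
    warmup={
        'type': 'linear',
        'linear': {'warmup_steps': distill_config.warmup_steps},
    })
```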
task:
# hub_module_url: 'gs://**/panzf/mobilebert/tfhub/'
init_checkpoint: 'gs://**/edgetpu_bert/edgetpu_bert_float_candidate_13_e2e_820k/exported_ckpt/'
model:
num_classes: 3
metric_type: 'accuracy'
train_data:
drop_remainder: true
global_batch_size: 32
input_path: gs://**/yo/bert/glue/tfrecords/MNLI/MNLI_matched_train.tf_record
is_training: true
seq_length: 128
label_type: 'int'
validation_data:
drop_remainder: false
global_batch_size: 32
input_path: gs://**/yo/bert/glue/tfrecords/MNLI/MNLI_matched_eval.tf_record
is_training: false
seq_length: 128
label_type: 'int'
trainer:
checkpoint_interval: 10000
optimizer_config:
learning_rate:
polynomial:
# 100% of train_steps.
decay_steps: 50000
end_learning_rate: 0.0
initial_learning_rate: 3.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
# ~10% of train_steps.
warmup_steps: 5000
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
# Training data size 392,702 examples, 8 epochs.
train_steps: 50000
validation_interval: 2000
# Eval data size = 9815 examples.
validation_steps: 307
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_eval_metric: 'cls_accuracy'
best_checkpoint_metric_comp: 'higher'
# MobileBERT model from https://arxiv.org/abs/2004.02984.
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 24
hidden_size: 512
num_attention_heads: 4
intermediate_size: 512
hidden_activation: relu
hidden_dropout_prob: 0.0
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 128
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 4
normalization_type: no_norm
classifier_activation: false
# MobileBERT-EdgeTPU model.
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 12
hidden_size: 512
num_attention_heads: 4
intermediate_size: 1024
hidden_activation: relu
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 256
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 6
normalization_type: no_norm
classifier_activation: false
# MobileBERT-EdgeTPU-S model.
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 12
hidden_size: 512
num_attention_heads: 4
intermediate_size: 1024
hidden_activation: relu
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 256
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 4
normalization_type: no_norm
classifier_activation: false
# MobileBERT-EdgeTPU-XS model.
task:
model:
encoder:
type: mobilebert
mobilebert:
word_vocab_size: 30522
word_embed_size: 128
type_vocab_size: 2
max_sequence_length: 512
num_blocks: 8
hidden_size: 512
num_attention_heads: 4
intermediate_size: 1024
hidden_activation: relu
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
intra_bottleneck_size: 256
initializer_range: 0.02
key_query_shared_bottleneck: true
num_feedforward_networks: 4
normalization_type: no_norm
classifier_activation: false
task:
# hub_module_url: 'gs://**/panzf/mobilebert/tfhub/'
max_answer_length: 30
n_best_size: 20
null_score_diff_threshold: 0.0
init_checkpoint: 'gs://**/edgetpu_bert/edgetpu_bert_float_candidate_13_e2e_820k/exported_ckpt/'
train_data:
drop_remainder: true
global_batch_size: 32
input_path: gs://**/tp/bert/squad_v1.1/train.tf_record
is_training: true
seq_length: 384
validation_data:
do_lower_case: true
doc_stride: 128
drop_remainder: false
global_batch_size: 48
input_path: gs://**/squad/dev-v1.1.json
is_training: false
query_length: 64
seq_length: 384
tokenization: WordPiece
version_2_with_negative: false
vocab_file: gs://**/panzf/ttl-30d/mobilebert/tf2_checkpoint/vocab.txt
trainer:
checkpoint_interval: 1000
max_to_keep: 5
optimizer_config:
learning_rate:
polynomial:
decay_steps: 19420
end_learning_rate: 0.0
initial_learning_rate: 8.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
# 10% of total training steps
warmup_steps: 1942
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
# 7 epochs for training
train_steps: 19420
validation_interval: 3000
validation_steps: 226
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_eval_metric: 'final_f1'
best_checkpoint_metric_comp: 'higher'
# Distillation pretraining for Mobilebert.
# The final MLM accuracy is around 70.8% for e2e only training and 71.4% for layer-wise + e2e.
layer_wise_distillation:
num_steps: 10000
warmup_steps: 0
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-3
decay_steps: 10000
end_to_end_distillation:
num_steps: 585000
warmup_steps: 20000
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-7
decay_steps: 585000
distill_ground_truth_ratio: 0.5
optimizer:
optimizer:
lamb:
beta_1: 0.9
beta_2: 0.999
clipnorm: 1.0
epsilon: 1.0e-06
exclude_from_layer_adaptation: null
exclude_from_weight_decay: ['LayerNorm', 'bias', 'norm']
global_clipnorm: null
name: LAMB
weight_decay_rate: 0.01
type: lamb
orbit_config:
eval_interval: 1000
eval_steps: -1
mode: train
steps_per_loop: 1000
total_steps: 825000
runtime:
distribution_strategy: 'tpu'
student_model:
cls_heads: [{'activation': 'tanh',
'cls_token_idx': 0,
'dropout_rate': 0.0,
'inner_dim': 512,
'name': 'next_sentence',
'num_classes': 2}]
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: relu
hidden_dropout_prob: 0.0
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 512
intra_bottleneck_size: 128
key_query_shared_bottleneck: true
max_sequence_length: 512
normalization_type: no_norm
num_attention_heads: 4
num_blocks: 24
num_feedforward_networks: 4
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: relu
mlm_initializer_range: 0.02
teacher_model:
cls_heads: []
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: gelu
hidden_dropout_prob: 0.1
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 4096
intra_bottleneck_size: 1024
key_query_shared_bottleneck: false
max_sequence_length: 512
normalization_type: layer_norm
num_attention_heads: 4
num_blocks: 24
num_feedforward_networks: 1
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: gelu
mlm_initializer_range: 0.02
teacher_model_init_checkpoint: gs://**/uncased_L-24_H-1024_B-512_A-4_teacher/tf2_checkpoint/bert_model.ckpt-1
student_model_init_checkpoint: ''
train_datasest:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord*,gs://**/seq_512_mask_20/books.tfrecord*
is_training: true
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
eval_dataset:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord-00141-of-00500,gs://**/seq_512_mask_20/books.tfrecord-00141-of-00500
is_training: false
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
layer_wise_distillation:
num_steps: 20000
warmup_steps: 0
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-3
decay_steps: 20000
end_to_end_distillation:
num_steps: 585000
warmup_steps: 20000
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-7
decay_steps: 585000
distill_ground_truth_ratio: 0.5
optimizer:
optimizer:
lamb:
beta_1: 0.9
beta_2: 0.999
clipnorm: 1.0
epsilon: 1.0e-06
exclude_from_layer_adaptation: null
exclude_from_weight_decay: ['LayerNorm', 'bias', 'norm']
global_clipnorm: null
name: LAMB
weight_decay_rate: 0.01
type: lamb
orbit_config:
eval_interval: 1000
eval_steps: -1
mode: train
steps_per_loop: 1000
total_steps: 825000
runtime:
distribution_strategy: 'tpu'
student_model:
cls_heads: [{'activation': 'tanh',
'cls_token_idx': 0,
'dropout_rate': 0.0,
'inner_dim': 512,
'name': 'next_sentence',
'num_classes': 2}]
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: relu
hidden_dropout_prob: 0.0
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 1024
intra_bottleneck_size: 256
key_query_shared_bottleneck: true
max_sequence_length: 512
normalization_type: no_norm
num_attention_heads: 4
num_blocks: 12
num_feedforward_networks: 6
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: relu
mlm_initializer_range: 0.02
teacher_model:
cls_heads: []
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: gelu
hidden_dropout_prob: 0.1
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 4096
intra_bottleneck_size: 1024
key_query_shared_bottleneck: false
max_sequence_length: 512
normalization_type: layer_norm
num_attention_heads: 4
num_blocks: 24
num_feedforward_networks: 1
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: gelu
mlm_initializer_range: 0.02
teacher_model_init_checkpoint: gs://**/uncased_L-24_H-1024_B-512_A-4_teacher/tf2_checkpoint/bert_model.ckpt-1
student_model_init_checkpoint: ''
train_datasest:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord*,gs://**/seq_512_mask_20/books.tfrecord*
is_training: true
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
eval_dataset:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord-00141-of-00500,gs://**/seq_512_mask_20/books.tfrecord-00141-of-00500
is_training: false
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
layer_wise_distillation:
num_steps: 20000
warmup_steps: 0
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-3
decay_steps: 20000
end_to_end_distillation:
num_steps: 585000
warmup_steps: 20000
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-7
decay_steps: 585000
distill_ground_truth_ratio: 0.5
optimizer:
optimizer:
lamb:
beta_1: 0.9
beta_2: 0.999
clipnorm: 1.0
epsilon: 1.0e-06
exclude_from_layer_adaptation: null
exclude_from_weight_decay: ['LayerNorm', 'bias', 'norm']
global_clipnorm: null
name: LAMB
weight_decay_rate: 0.01
type: lamb
orbit_config:
eval_interval: 1000
eval_steps: -1
mode: train
steps_per_loop: 1000
total_steps: 825000
runtime:
distribution_strategy: 'tpu'
student_model:
cls_heads: [{'activation': 'tanh',
'cls_token_idx': 0,
'dropout_rate': 0.0,
'inner_dim': 512,
'name': 'next_sentence',
'num_classes': 2}]
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: relu
hidden_dropout_prob: 0.0
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 1024
intra_bottleneck_size: 256
key_query_shared_bottleneck: true
max_sequence_length: 512
normalization_type: no_norm
num_attention_heads: 4
num_blocks: 12
num_feedforward_networks: 4
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: relu
mlm_initializer_range: 0.02
teacher_model:
cls_heads: []
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: gelu
hidden_dropout_prob: 0.1
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 4096
intra_bottleneck_size: 1024
key_query_shared_bottleneck: false
max_sequence_length: 512
normalization_type: layer_norm
num_attention_heads: 4
num_blocks: 24
num_feedforward_networks: 1
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: gelu
mlm_initializer_range: 0.02
teacher_model_init_checkpoint: gs://**/uncased_L-24_H-1024_B-512_A-4_teacher/tf2_checkpoint/bert_model.ckpt-1
student_model_init_checkpoint: ''
train_datasest:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord*,gs://**/seq_512_mask_20/books.tfrecord*
is_training: true
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
eval_dataset:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord-00141-of-00500,gs://**/seq_512_mask_20/books.tfrecord-00141-of-00500
is_training: false
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
layer_wise_distillation:
num_steps: 30000
warmup_steps: 0
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-3
decay_steps: 30000
end_to_end_distillation:
num_steps: 585000
warmup_steps: 20000
initial_learning_rate: 1.5e-3
end_learning_rate: 1.5e-7
decay_steps: 585000
distill_ground_truth_ratio: 0.5
optimizer:
optimizer:
lamb:
beta_1: 0.9
beta_2: 0.999
clipnorm: 1.0
epsilon: 1.0e-06
exclude_from_layer_adaptation: null
exclude_from_weight_decay: ['LayerNorm', 'bias', 'norm']
global_clipnorm: null
name: LAMB
weight_decay_rate: 0.01
type: lamb
orbit_config:
eval_interval: 1000
eval_steps: -1
mode: train
steps_per_loop: 1000
total_steps: 825000
runtime:
distribution_strategy: 'tpu'
student_model:
cls_heads: [{'activation': 'tanh',
'cls_token_idx': 0,
'dropout_rate': 0.0,
'inner_dim': 512,
'name': 'next_sentence',
'num_classes': 2}]
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: relu
hidden_dropout_prob: 0.0
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 1024
intra_bottleneck_size: 256
key_query_shared_bottleneck: true
max_sequence_length: 512
normalization_type: no_norm
num_attention_heads: 4
num_blocks: 8
num_feedforward_networks: 4
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: relu
mlm_initializer_range: 0.02
teacher_model:
cls_heads: []
encoder:
mobilebert:
attention_probs_dropout_prob: 0.1
classifier_activation: false
hidden_activation: gelu
hidden_dropout_prob: 0.1
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: int32
intermediate_size: 4096
intra_bottleneck_size: 1024
key_query_shared_bottleneck: false
max_sequence_length: 512
normalization_type: layer_norm
num_attention_heads: 4
num_blocks: 24
num_feedforward_networks: 1
type_vocab_size: 2
use_bottleneck_attention: false
word_embed_size: 128
word_vocab_size: 30522
type: mobilebert
mlm_activation: gelu
mlm_initializer_range: 0.02
teacher_model_init_checkpoint: gs://**/uncased_L-24_H-1024_B-512_A-4_teacher/tf2_checkpoint/bert_model.ckpt-1
student_model_init_checkpoint: ''
train_datasest:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord*,gs://**/seq_512_mask_20/books.tfrecord*
is_training: true
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
eval_dataset:
block_length: 1
cache: false
cycle_length: null
deterministic: null
drop_remainder: true
enable_tf_data_service: false
global_batch_size: 2048
input_path: gs://**/seq_512_mask_20/wikipedia.tfrecord-00141-of-00500,gs://**/seq_512_mask_20/books.tfrecord-00141-of-00500
is_training: false
max_predictions_per_seq: 20
seq_length: 512
sharding: true
shuffle_buffer_size: 100
tf_data_service_address: null
tf_data_service_job_name: null
tfds_as_supervised: false
tfds_data_dir: ''
tfds_name: ''
tfds_skip_decoding_feature: ''
tfds_split: ''
use_next_sentence_label: true
use_position_id: false
use_v2_feature_names: false
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distillation trainer for EdgeTPU-BERT."""
import enum
import os
from typing import Optional
from absl import logging
import orbit
import tensorflow as tf
from official.modeling import optimization
from official.nlp import modeling
from official.nlp.data import data_loader_factory
from official.projects.edgetpu.nlp.configs import params
class DistillationMode(enum.Enum):
"""enum.Enum class for different distillation mode.
A state machine is used to control the training progress. When the training
job starts from the beginning or resumes from a preemption, the state is INIT.
Then depends on the 'self.current_step', the state switches to either
'LAYER_WISE' or 'END2END'.
Options:
UNKNOWN: Unknown status, always raise errors.
INIT: The trainer is initialized or restarted from the preemption.
LAYER_WISE: Layer-wise distillation for each transformer layers.
END2END: End-to-end distillation after layer-wise distillaiton is done.
"""
UNKNOWN = 0
INIT = 1
LAYER_WISE = 2
END2END = 3
def _get_distribution_losses(teacher, student):
"""Returns the beta and gamma distall losses for feature distribution."""
teacher_mean = tf.math.reduce_mean(teacher, axis=-1, keepdims=True)
student_mean = tf.math.reduce_mean(student, axis=-1, keepdims=True)
teacher_var = tf.math.reduce_variance(teacher, axis=-1, keepdims=True)
student_var = tf.math.reduce_variance(student, axis=-1, keepdims=True)
beta_loss = tf.math.squared_difference(student_mean, teacher_mean)
beta_loss = tf.math.reduce_mean(beta_loss, axis=None, keepdims=False)
gamma_loss = tf.math.abs(student_var - teacher_var)
gamma_loss = tf.math.reduce_mean(gamma_loss, axis=None, keepdims=False)
return beta_loss, gamma_loss
def _get_attention_loss(teacher_score, student_score):
"""Function to calculate attention loss for transformer layers."""
# Note that the definition of KLDivergence here is a little different from
# the original one (tf.keras.losses.KLDivergence). We adopt this approach
# to stay consistent with the TF1 implementation.
teacher_weight = tf.keras.activations.softmax(teacher_score, axis=-1)
student_log_weight = tf.nn.log_softmax(student_score, axis=-1)
kl_divergence = -(teacher_weight * student_log_weight)
kl_divergence = tf.math.reduce_sum(kl_divergence, axis=-1, keepdims=True)
kl_divergence = tf.math.reduce_mean(kl_divergence, axis=None,
keepdims=False)
return kl_divergence
def _build_sub_encoder(encoder, stage_number):
"""Builds a partial model containing the first few transformer layers."""
input_ids = encoder.inputs[0]
input_mask = encoder.inputs[1]
type_ids = encoder.inputs[2]
attention_mask = modeling.layers.SelfAttentionMask()(
inputs=input_ids, to_mask=input_mask)
embedding_output = encoder.embedding_layer(input_ids, type_ids)
layer_output = embedding_output
attention_score = None
for layer_idx in range(stage_number + 1):
layer_output, attention_score = encoder.transformer_layers[layer_idx](
layer_output, attention_mask, return_attention_scores=True)
return tf.keras.Model(
inputs=[input_ids, input_mask, type_ids],
outputs=[layer_output, attention_score])
class MobileBERTEdgeTPUDistillationTrainer(orbit.StandardTrainer,
orbit.StandardEvaluator):
"""Orbit based distillation training pipeline for MobileBERT-EdgeTPU models."""
def __init__(self,
teacher_model: modeling.models.BertPretrainerV2,
student_model: modeling.models.BertPretrainerV2,
strategy: tf.distribute.Strategy,
experiment_params: params.EdgeTPUBERTCustomParams,
export_ckpt_path: Optional[str] = None,
reuse_teacher_embedding: Optional[bool] = True):
self.teacher_model = teacher_model
self.student_model = student_model
self.strategy = strategy
self.layer_wise_distill_config = experiment_params.layer_wise_distillation
self.e2e_distill_config = experiment_params.end_to_end_distillation
self.optimizer_config = experiment_params.optimizer
self.train_dataset_config = experiment_params.train_datasest
self.eval_dataset_config = experiment_params.eval_dataset
self.word_vocab_size = experiment_params.student_model.encoder.mobilebert.word_vocab_size
self.distill_gt_ratio = experiment_params.end_to_end_distillation.distill_ground_truth_ratio
self.teacher_transformer_layers = experiment_params.teacher_model.encoder.mobilebert.num_blocks
self.student_transformer_layers = experiment_params.student_model.encoder.mobilebert.num_blocks
self.exported_ckpt_path = export_ckpt_path
self.current_step = orbit.utils.create_global_step()
self.current_step.assign(0)
# The stage is updated every time distillation finishes for one transformer
# layer; self.stage is updated in the train_loop_begin() function. After the
# last stage is done, self.mode is changed to 'e2e'.
self.stage = 0
self.mode = DistillationMode.INIT
# The number of transformer layers in the teacher must be equal to, or
# divisible by, the number of transformer layers in the student.
if self.teacher_transformer_layers % self.student_transformer_layers != 0:
raise ValueError(
'The number of teacher transformer layers must be divisible by the number of student transformer layers.')
self.ratio = (self.teacher_transformer_layers //
self.student_transformer_layers)
# Create optimizers for different training stage.
self.layer_wise_optimizer = self.build_optimizer(
self.layer_wise_distill_config)
self.e2e_optimizer = self.build_optimizer(self.e2e_distill_config)
self.current_optimizer = self.layer_wise_optimizer
# A non-trainable layer for feature normalization for transfer loss.
self._layer_norm = tf.keras.layers.LayerNormalization(
axis=-1,
beta_initializer='zeros',
gamma_initializer='ones',
trainable=False)
self.build_dataset()
self.build_metrics()
# Create an empty exported checkpoint manager, it will be initialized once
# the training mode enters END2END.
self.exported_ckpt_manager = None
# Reuse the teacher's embedding table in student model.
if reuse_teacher_embedding:
logging.info('Copy word embedding from teacher model to student.')
teacher_encoder = self.teacher_model.encoder_network
student_encoder = self.student_model.encoder_network
embedding_weights = teacher_encoder.embedding_layer.get_weights()
student_encoder.embedding_layer.set_weights(embedding_weights)
orbit.StandardTrainer.__init__(self, self.train_dataset)
orbit.StandardEvaluator.__init__(self, self.eval_dataset)
def build_dataset(self):
"""Creates the training and evaluation dataset."""
# Returns None when the input_path is 'dummy'.
if self.train_dataset_config.input_path == 'dummy':
self.train_dataset = None
self.eval_dataset = None
return
# Non-distributed datasets.
train_dataset = data_loader_factory.get_data_loader(
self.train_dataset_config).load()
eval_dataset = data_loader_factory.get_data_loader(
self.eval_dataset_config).load()
# Distributed datasets.
self.train_dataset = orbit.utils.make_distributed_dataset(
self.strategy, train_dataset)
self.eval_dataset = orbit.utils.make_distributed_dataset(
self.strategy, eval_dataset)
def build_model(self):
"""Creates the fused model from teacher/student model."""
self.teacher_model.trainable = False
if self.mode == DistillationMode.LAYER_WISE:
# Build a model that outputs teacher's and student's transformer outputs.
inputs = self.student_model.encoder_network.inputs
student_sub_encoder = _build_sub_encoder(
encoder=self.student_model.encoder_network,
stage_number=self.stage)
student_output_feature, student_attention_score = student_sub_encoder(
inputs)
teacher_sub_encoder = _build_sub_encoder(
encoder=self.teacher_model.encoder_network,
stage_number=int(self.stage * self.ratio))
teacher_output_feature, teacher_attention_score = teacher_sub_encoder(
inputs)
return tf.keras.Model(
inputs=inputs,
outputs=dict(
student_output_feature=student_output_feature,
student_attention_score=student_attention_score,
teacher_output_feature=teacher_output_feature,
teacher_attention_score=teacher_attention_score))
elif self.mode == DistillationMode.END2END:
# Build a model that outputs teacher's and student's MLM/NSP outputs.
inputs = self.student_model.inputs
student_pretrainer_outputs = self.student_model(inputs)
teacher_pretrainer_outputs = self.teacher_model(inputs)
model = tf.keras.Model(
inputs=inputs,
outputs=dict(
student_pretrainer_outputs=student_pretrainer_outputs,
teacher_pretrainer_outputs=teacher_pretrainer_outputs,
))
# Checkpoint the student encoder which is the goal of distillation.
model.checkpoint_items = self.student_model.checkpoint_items
return model
else:
raise ValueError(f'Unknown distillation mode: {self.mode}.')
def build_optimizer(self, config):
"""Creates optimier for the fused model."""
optimizer_config = self.optimizer_config.replace(
learning_rate={
'polynomial': {
'decay_steps': config.decay_steps,
'initial_learning_rate': config.initial_learning_rate,
'end_learning_rate': config.end_learning_rate,
}
},
warmup={
'type': 'linear',
'linear': {
'warmup_steps': config.warmup_steps,
}
})
logging.info('The optimizer config is: %s', optimizer_config.as_dict())
optimizer_factory = optimization.OptimizerFactory(optimizer_config)
return optimizer_factory.build_optimizer(
optimizer_factory.build_learning_rate())
def build_metrics(self):
"""Creates metrics functions for the training."""
self.train_metrics = {
'feature_transfer_mse': tf.keras.metrics.Mean(),
'beta_transfer_loss': tf.keras.metrics.Mean(),
'gamma_transfer_loss': tf.keras.metrics.Mean(),
'attention_transfer_loss': tf.keras.metrics.Mean(),
'masked_lm_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
'lm_example_loss': tf.keras.metrics.Mean(),
'total_loss': tf.keras.metrics.Mean(),
'next_sentence_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
'next_sentence_loss': tf.keras.metrics.Mean(),
}
self.eval_metrics = {
'masked_lm_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
'next_sentence_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
}
def build_exported_ckpt_manager(self):
"""Creates checkpoint manager for exported models."""
if self.exported_ckpt_path is None:
logging.warning('exported_ckpt_path is not specified. The saved model '
'cannot be used for downstream tasks.')
return
checkpoint = tf.train.Checkpoint(global_step=self.current_step,
model=self.model,
optimizer=self.current_optimizer,
**self.model.checkpoint_items)
self.exported_ckpt_manager = tf.train.CheckpointManager(
checkpoint,
directory=os.path.join(self.exported_ckpt_path, 'exported_ckpt'),
max_to_keep=2,
step_counter=self.current_step,
checkpoint_interval=20000,
init_fn=None)
def calculate_loss_metrics(self, labels, outputs):
"""Calculates loss and metrics.
Args:
labels: Ground truth labels from the dataset.
outputs: Fused outputs from the teacher and student models.
Returns:
The total loss value.
"""
if self.mode == DistillationMode.LAYER_WISE:
teacher_feature = outputs['teacher_output_feature']
student_feature = outputs['student_output_feature']
feature_transfer_loss = tf.keras.losses.mean_squared_error(
self._layer_norm(teacher_feature), self._layer_norm(student_feature))
# feature_transfer_loss = tf.reduce_mean(feature_transfer_loss)
feature_transfer_loss *= self.layer_wise_distill_config.hidden_distill_factor
beta_loss, gamma_loss = _get_distribution_losses(teacher_feature,
student_feature)
beta_loss *= self.layer_wise_distill_config.beta_distill_factor
gamma_loss *= self.layer_wise_distill_config.gamma_distill_factor
total_loss = feature_transfer_loss + beta_loss + gamma_loss
teacher_attention = outputs['teacher_attention_score']
student_attention = outputs['student_attention_score']
attention_loss = _get_attention_loss(teacher_attention, student_attention)
attention_loss *= self.layer_wise_distill_config.attention_distill_factor
total_loss += attention_loss
total_loss /= tf.cast((self.stage + 1), tf.float32)
elif self.mode == DistillationMode.END2END:
lm_label = labels['masked_lm_ids']
# Shape: [batch, max_predictions_per_seq, word_vocab_size]
lm_label = tf.one_hot(indices=lm_label,
depth=self.word_vocab_size,
on_value=1.0,
off_value=0.0,
axis=-1,
dtype=tf.float32)
lm_label_weights = labels['masked_lm_weights']
teacher_mlm_logits = outputs['teacher_pretrainer_outputs']['mlm_logits']
teacher_labels = tf.nn.softmax(teacher_mlm_logits, axis=-1)
gt_label = self.distill_gt_ratio * lm_label
teacher_label = (1 - self.distill_gt_ratio) * teacher_labels
lm_label = gt_label + teacher_label
student_pretrainer_output = outputs['student_pretrainer_outputs']
# Shape: [batch, max_predictions_per_seq, word_vocab_size]
student_lm_log_probs = tf.nn.log_softmax(
student_pretrainer_output['mlm_logits'], axis=-1)
# Shape: [batch * max_predictions_per_seq]
per_example_loss = tf.reshape(
-tf.reduce_sum(student_lm_log_probs * lm_label, axis=[-1]), [-1])
lm_label_weights = tf.reshape(labels['masked_lm_weights'], [-1])
lm_numerator_loss = tf.reduce_sum(per_example_loss * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
total_loss = mlm_loss
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
student_pretrainer_output['next_sentence'], dtype=tf.float32)
sentence_loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels, sentence_outputs, from_logits=True))
total_loss += sentence_loss
else:
raise ValueError('Training mode has to be LAYER-WISE or END2END.')
if self.mode == DistillationMode.LAYER_WISE:
self.train_metrics['feature_transfer_mse'].update_state(
feature_transfer_loss)
self.train_metrics['beta_transfer_loss'].update_state(beta_loss)
self.train_metrics['gamma_transfer_loss'].update_state(gamma_loss)
self.train_metrics['attention_transfer_loss'].update_state(attention_loss)
elif self.mode == DistillationMode.END2END:
self.train_metrics['lm_example_loss'].update_state(mlm_loss)
self.train_metrics['next_sentence_loss'].update_state(sentence_loss)
self.train_metrics['total_loss'].update_state(total_loss)
return total_loss
def calculate_accuracy_metrics(self, labels, outputs, metrics):
"""Calculates metrics that are not related to the losses."""
if self.mode == DistillationMode.END2END:
student_pretrainer_output = outputs['student_pretrainer_outputs']
metrics['masked_lm_accuracy'].update_state(
labels['masked_lm_ids'],
student_pretrainer_output['mlm_logits'],
labels['masked_lm_weights'])
metrics['next_sentence_accuracy'].update_state(
labels['next_sentence_labels'],
student_pretrainer_output['next_sentence'])
def _rebuild_training_graph(self):
"""Rebuilds the training graph when one stage/step is done."""
self.stage = (self.current_step.numpy() //
self.layer_wise_distill_config.num_steps)
logging.info('Start distillation training for stage %d.', self.stage)
self.model = self.build_model()
self.layer_wise_optimizer = self.build_optimizer(
self.layer_wise_distill_config)
# Rebuild the dataset which can significantly improve the training
# accuracy.
logging.info('Rebuild the training dataset.')
self.build_dataset()
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated loss function.
logging.info('Rebuild the training and evaluation graph.')
self._train_loop_fn = None
self._eval_loop_fn = None
def train_loop_begin(self):
"""A train loop is similar with the concept of an epoch."""
self.train_metrics['feature_transfer_mse'].reset_states()
self.train_metrics['beta_transfer_loss'].reset_states()
self.train_metrics['gamma_transfer_loss'].reset_states()
self.train_metrics['attention_transfer_loss'].reset_states()
self.train_metrics['total_loss'].reset_states()
self.train_metrics['lm_example_loss'].reset_states()
self.train_metrics['next_sentence_loss'].reset_states()
self.train_metrics['masked_lm_accuracy'].reset_states()
self.train_metrics['next_sentence_accuracy'].reset_states()
if self.mode == DistillationMode.INIT:
if (self.current_step.numpy() < self.layer_wise_distill_config.num_steps *
self.student_transformer_layers):
logging.info('Start or resume layer-wise training.')
self.mode = DistillationMode.LAYER_WISE
self.stage = (self.current_step.numpy() //
self.layer_wise_distill_config.num_steps)
self.model = self.build_model()
self.build_dataset()
self.current_optimizer = self.layer_wise_optimizer
else:
self.mode = DistillationMode.END2END
logging.info('Start or resume e2e training.')
self.model = self.build_model()
self.current_optimizer = self.e2e_optimizer
elif self.mode == DistillationMode.LAYER_WISE:
if (self.current_step.numpy() < self.layer_wise_distill_config.num_steps *
self.student_transformer_layers):
if (self.current_step.numpy() %
self.layer_wise_distill_config.num_steps) == 0:
self._rebuild_training_graph()
self.current_optimizer = self.layer_wise_optimizer
else:
self.mode = DistillationMode.END2END
self.model = self.build_model()
logging.info('Start e2e distillation training.')
self.current_optimizer = self.e2e_optimizer
logging.info('Rebuild the training dataset.')
self.build_dataset()
logging.info('Rebuild the training and evaluation graph.')
self._train_loop_fn = None
self._eval_loop_fn = None
def train_step(self, iterator):
"""A single step of train."""
def step_fn(inputs):
with tf.GradientTape() as tape:
outputs = self.model(inputs, training=True)
loss = self.calculate_loss_metrics(inputs, outputs)
self.calculate_accuracy_metrics(inputs, outputs, self.train_metrics)
grads = tape.gradient(loss, self.model.trainable_variables)
self.current_optimizer.apply_gradients(
zip(grads, self.model.trainable_variables))
self.current_step.assign_add(1)
self.strategy.run(step_fn, args=(next(iterator),))
def train_loop_end(self):
"""A train loop is similar with the concept of an epoch."""
if self.mode == DistillationMode.END2END:
# Save the exported checkpoint (used for downstream tasks) after every
# 'checkpoint_interval' steps, and only export checkpoints after entering
# the e2e distillation training stage.
if self.exported_ckpt_manager is None:
self.build_exported_ckpt_manager()
self.exported_ckpt_manager.save(
checkpoint_number=self.current_step.numpy(),
check_interval=True)
return {
'feature_transfer_mse':
self.train_metrics['feature_transfer_mse'].result(),
'beta_transfer_loss':
self.train_metrics['beta_transfer_loss'].result(),
'gamma_transfer_loss':
self.train_metrics['gamma_transfer_loss'].result(),
'attention_transfer_loss':
self.train_metrics['attention_transfer_loss'].result(),
'total_loss':
self.train_metrics['total_loss'].result(),
'lm_example_loss':
self.train_metrics['lm_example_loss'].result(),
'next_sentence_loss':
self.train_metrics['next_sentence_loss'].result(),
'masked_lm_accuracy':
self.train_metrics['masked_lm_accuracy'].result(),
'next_sentence_accuracy':
self.train_metrics['next_sentence_accuracy'].result(),
'learning_rate':
self.current_optimizer.learning_rate(
self.current_optimizer.iterations),
'current_step':
self.current_step,
'optimizer_step':
self.current_optimizer.iterations,
}
# TODO(longy): We only run evaluation on downstream tasks.
def eval_begin(self):
self.eval_metrics['masked_lm_accuracy'].reset_states()
self.eval_metrics['next_sentence_accuracy'].reset_states()
def eval_step(self, iterator):
def step_fn(inputs):
outputs = self.model(inputs, training=False)
self.calculate_accuracy_metrics(inputs, outputs, self.eval_metrics)
self.strategy.run(step_fn, args=(next(iterator),))
def eval_end(self):
return {'masked_lm_accuracy':
self.eval_metrics['masked_lm_accuracy'].result(),
'next_sentence_accuracy':
self.eval_metrics['next_sentence_accuracy'].result()}
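For reference, a minimal sketch of how this trainer could be driven with
`orbit.Controller`, mirroring the `orbit_config` fields defined in `params.py`.
This is an assumption about the surrounding training binary (which is not part
of this commit); `teacher_model`, `student_model`, `strategy`, and
`experiment_params` are assumed to be built as in the unit test below.

```python
# Hypothetical driver sketch; only the trainer class and the orbit/params APIs
# used here appear elsewhere in this commit.
trainer = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    strategy=strategy,
    experiment_params=experiment_params)
controller = orbit.Controller(
    strategy=strategy,
    trainer=trainer,
    evaluator=trainer,  # the trainer also implements orbit.StandardEvaluator
    global_step=trainer.current_step,
    steps_per_loop=experiment_params.orbit_config.steps_per_loop)
controller.train(steps=experiment_params.orbit_config.total_steps)
```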
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mobilebert_edgetpu_trainer.py."""
import tensorflow as tf
from official.projects.edgetpu.nlp import mobilebert_edgetpu_trainer
from official.projects.edgetpu.nlp.configs import params
from official.projects.edgetpu.nlp.modeling import model_builder
# Helper function to create dummy dataset
def _dummy_dataset():
def dummy_data(_):
dummy_ids = tf.zeros((1, 64), dtype=tf.int32)
dummy_lm = tf.zeros((1, 64), dtype=tf.int32)
return dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_lm,
masked_lm_ids=dummy_lm,
masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
class EdgetpuBertTrainerTest(tf.test.TestCase):
def setUp(self):
super(EdgetpuBertTrainerTest, self).setUp()
config_path = 'third_party/tensorflow_models/official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_m.yaml'
self.experiment_params = params.EdgeTPUBERTCustomParams.from_yaml(
config_path)
self.strategy = tf.distribute.get_strategy()
self.experiment_params.train_datasest.input_path = 'dummy'
self.experiment_params.eval_dataset.input_path = 'dummy'
def test_train_model_locally(self):
"""Tests training a model locally with one step."""
teacher_model = model_builder.build_bert_pretrainer(
pretrainer_cfg=self.experiment_params.teacher_model,
name='teacher')
_ = teacher_model(teacher_model.inputs)
student_model = model_builder.build_bert_pretrainer(
pretrainer_cfg=self.experiment_params.student_model,
name='student')
_ = student_model(student_model.inputs)
trainer = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
teacher_model=teacher_model,
student_model=student_model,
strategy=self.strategy,
experiment_params=self.experiment_params)
# Rebuild a dummy dataset since loading the real dataset would cause a timeout.
trainer.train_dataset = _dummy_dataset()
trainer.eval_dataset = _dummy_dataset()
train_dataset_iter = iter(trainer.train_dataset)
eval_dataset_iter = iter(trainer.eval_dataset)
trainer.train_loop_begin()
trainer.train_step(train_dataset_iter)
trainer.eval_step(eval_dataset_iter)
if __name__ == '__main__':
tf.test.main()
@@ -12,8 +12,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Keras-based one-hot embedding layer."""
-from official.nlp.modeling import layers
-OnDeviceEmbedding = layers.OnDeviceEmbedding