ModelZoo / ResNet50_tensorflow / Commits

Commit 7687b1d3, authored Apr 01, 2021 by Chen Chen, committed by A. Unique TensorFlower on Apr 01, 2021
Open source mobilebert project.
PiperOrigin-RevId: 366313554
Parent: 41f71f6c
Showing 12 changed files with 1642 additions and 0 deletions (+1642, -0)
official/nlp/projects/mobilebert/README.md (+70, -0)
official/nlp/projects/mobilebert/__init__.py (+14, -0)
official/nlp/projects/mobilebert/distillation.py (+557, -0)
official/nlp/projects/mobilebert/distillation_test.py (+166, -0)
official/nlp/projects/mobilebert/experiments/en_uncased_student.yaml (+22, -0)
official/nlp/projects/mobilebert/experiments/en_uncased_teacher.yaml (+22, -0)
official/nlp/projects/mobilebert/experiments/mobilebert_distillation_en_uncased.yaml (+79, -0)
official/nlp/projects/mobilebert/export_tfhub.py (+86, -0)
official/nlp/projects/mobilebert/model_utils.py (+170, -0)
official/nlp/projects/mobilebert/run_distillation.py (+149, -0)
official/nlp/projects/mobilebert/tf2_model_checkpoint_converter.py (+278, -0)
official/nlp/projects/mobilebert/utils.py (+29, -0)
official/nlp/projects/mobilebert/README.md
0 → 100644
# MobileBERT (MobileBERT: A Compact Task-Agnostic BERT for Resource-Limited Devices)
[MobileBERT](https://arxiv.org/abs/2004.02984) is a thin version of BERT_LARGE
equipped with bottleneck structures and a carefully designed balance between
self-attention and feed-forward networks.

To train MobileBERT, we first train a specially designed teacher model, a
BERT_LARGE model that incorporates inverted bottlenecks. Then we conduct
knowledge transfer from this teacher to MobileBERT. Empirical studies show that
MobileBERT is 4.3x smaller and 5.5x faster than BERT_BASE while achieving
competitive results on well-known benchmarks. This repository contains a
TensorFlow 2.x implementation of MobileBERT.
## Network Implementations
Following the
[MobileBERT TF1 implementation](https://github.com/google-research/google-research/tree/master/mobilebert),
we re-implemented the MobileBERT encoder and layers with `tf.keras` APIs in the
NLP modeling library (a minimal construction sketch follows this list):

*   [mobile_bert_encoder.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/mobile_bert_encoder.py)
    contains the `MobileBERTEncoder` implementation.
*   [mobile_bert_layers.py](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/mobile_bert_layers.py)
    contains the `MobileBertEmbedding` and `MobileBertMaskedLM` implementations.
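
As an illustrative sketch (not part of this commit), the encoder can be
instantiated directly. The argument names below mirror what
`model_utils.create_mobilebert_pretrainer` in this commit passes through from
`BertConfig`, and the concrete values mirror
`experiments/en_uncased_student.yaml`; treat the exact values as assumptions for
the sketch only.

```python
# Illustrative sketch: argument names follow create_mobilebert_pretrainer() in
# model_utils.py of this commit; values mirror en_uncased_student.yaml.
import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling import networks

encoder = networks.MobileBERTEncoder(
    word_vocab_size=30522,
    word_embed_size=128,
    type_vocab_size=2,
    max_sequence_length=512,
    num_blocks=24,
    hidden_size=512,
    num_attention_heads=4,
    intermediate_size=512,
    intermediate_act_fn=tf_utils.get_activation('relu'),
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.1,
    intra_bottleneck_size=128,
    initializer_range=0.02,
    key_query_shared_bottleneck=True,
    num_feedforward_networks=4,
    normalization_type='no_norm',
    classifier_activation=False)

# Feed dummy inputs keyed by the encoder's own input names, the same pattern
# export_tfhub.py in this commit uses; the output dict includes "pooled_output".
dummy_inputs = {
    x.name: tf.ones((1, 512), dtype=tf.int32) for x in encoder.inputs
}
outputs = encoder(dummy_inputs)
```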
## Pre-trained Models
We converted the original TF 1.x pretrained English MobileBERT checkpoint to a
TF 2.x checkpoint that is compatible with the implementations above. In
addition, we provide a new multilingual MobileBERT checkpoint trained on
multilingual Wikipedia data. We also export the checkpoints as TF-Hub
SavedModels. Details are listed in the following table:
Model | Configuration | Number of Parameters | Training Data | Checkpoint & Vocabulary | TF-Hub SavedModel | Metrics
------------------------------ | :--------------------------------------: | :------------------- | :-----------: | :---------------------: | :---------------: | :-----:
MobileBERT uncased English | uncased_L-24_H-128_B-512_A-4_F-4_OPT | 25.3 Million | Wiki + Books | [Download](https://storage.cloud.google.com/model_garden_artifacts/official/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz) | [TF-Hub](https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1) | SQuAD v1.1 F1 90.0, GLUE 77.7
MobileBERT cased Multilingual | multi_cased_L-24_H-128_B-512_A-4_F-4_OPT | 36 Million | Wiki | [Download](https://storage.cloud.google.com/model_garden_artifacts/official/mobilebert/multi_cased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz) | [TF-Hub](https://tfhub.dev/tensorflow/mobilebert_multi_cased_L-24_H-128_B-512_A-4_F-4_OPT/1) | XNLI (zero-shot): 64.7
### Restoring from Checkpoints
To load the pre-trained MobileBERT checkpoint in your code, please follow the
example below:
```python
import tensorflow as tf
from official.nlp.projects.mobilebert import model_utils

bert_config_file = ...
model_checkpoint_path = ...
bert_config = model_utils.BertConfig.from_json_file(bert_config_file)

# `pretrainer` is an instance of `nlp.modeling.models.BertPretrainerV2`.
pretrainer = model_utils.create_mobilebert_pretrainer(bert_config)
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()

# `mobilebert_encoder` is an instance of
# `nlp.modeling.networks.MobileBERTEncoder`.
mobilebert_encoder = pretrainer.encoder_network
```
### Use TF-Hub models
For usage of the MobileBERT TF-Hub models, please see the TF-Hub pages for the
[English model](https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1)
or the
[Multilingual model](https://tfhub.dev/tensorflow/mobilebert_multi_cased_L-24_H-128_B-512_A-4_F-4_OPT/1).
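
As a minimal loading sketch (an assumption based on the common TF2 BERT encoder
signature on TF-Hub, not something shipped in this commit), the exported
SavedModel can be wrapped with `hub.KerasLayer`:

```python
# Illustrative sketch: the input/output keys below follow the common TF2 BERT
# encoder signature on TF-Hub (input_word_ids / input_mask / input_type_ids ->
# dict with "pooled_output") and are an assumption here, not taken from this
# commit.
import tensorflow as tf
import tensorflow_hub as hub

encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1",
    trainable=True)

seq_length = 128  # arbitrary choice for this sketch
inputs = dict(
    input_word_ids=tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32),
    input_mask=tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32),
    input_type_ids=tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32))
outputs = encoder(inputs)  # e.g. outputs["pooled_output"] for classification
model = tf.keras.Model(inputs, outputs)
```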
official/nlp/projects/mobilebert/__init__.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/projects/mobilebert/distillation.py
0 → 100644
(This diff is collapsed; the file contents are not shown here.)
official/nlp/projects/mobilebert/distillation_test.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.mobilebert.distillation."""
import os

from absl import logging
import tensorflow as tf

from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.projects.mobilebert import distillation


class DistillationTest(tf.test.TestCase):

  def setUp(self):
    super(DistillationTest, self).setUp()

    # using small model for testing
    self.model_block_num = 2
    self.task_config = distillation.BertDistillationTaskConfig(
        teacher_model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                type='mobilebert',
                mobilebert=encoders.MobileBertEncoderConfig(
                    num_blocks=self.model_block_num)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=256,
                    num_classes=2,
                    dropout_rate=0.1,
                    name='next_sentence')
            ],
            mlm_activation='gelu'),
        student_model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                type='mobilebert',
                mobilebert=encoders.MobileBertEncoderConfig(
                    num_blocks=self.model_block_num)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=256,
                    num_classes=2,
                    dropout_rate=0.1,
                    name='next_sentence')
            ],
            mlm_activation='relu'),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path='dummy',
            max_predictions_per_seq=76,
            seq_length=512,
            global_batch_size=10),
        validation_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path='dummy',
            max_predictions_per_seq=76,
            seq_length=512,
            global_batch_size=10))

    # set only 1 step for each stage
    progressive_config = distillation.BertDistillationProgressiveConfig()
    progressive_config.layer_wise_distill_config.num_steps = 1
    progressive_config.pretrain_distill_config.num_steps = 1

    optimization_config = optimization.OptimizationConfig(
        optimizer=optimization.OptimizerConfig(
            type='lamb',
            lamb=optimization.LAMBConfig(
                weight_decay_rate=0.0001,
                exclude_from_weight_decay=[
                    'LayerNorm', 'layer_norm', 'bias', 'no_norm'
                ])),
        learning_rate=optimization.LrConfig(
            type='polynomial',
            polynomial=optimization.PolynomialLrConfig(
                initial_learning_rate=1.5e-3,
                decay_steps=10000,
                end_learning_rate=1.5e-3)),
        warmup=optimization.WarmupConfig(
            type='linear',
            linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))

    self.exp_config = cfg.ExperimentConfig(
        task=self.task_config,
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=progressive_config,
            optimizer_config=optimization_config))

    # Create a teacher model checkpoint.
    teacher_encoder = encoders.build_encoder(
        self.task_config.teacher_model.encoder)
    pretrainer_config = self.task_config.teacher_model
    if pretrainer_config.cls_heads:
      teacher_cls_heads = [
          layers.ClassificationHead(**cfg.as_dict())
          for cfg in pretrainer_config.cls_heads
      ]
    else:
      teacher_cls_heads = []

    masked_lm = layers.MobileBertMaskedLM(
        embedding_table=teacher_encoder.get_embedding_table(),
        activation=tf_utils.get_activation(pretrainer_config.mlm_activation),
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=pretrainer_config.mlm_initializer_range),
        name='cls/predictions')
    teacher_pretrainer = models.BertPretrainerV2(
        encoder_network=teacher_encoder,
        classification_heads=teacher_cls_heads,
        customized_masked_lm=masked_lm)
    # The model variables will be created after the forward call.
    _ = teacher_pretrainer(teacher_pretrainer.inputs)
    teacher_pretrainer_ckpt = tf.train.Checkpoint(
        **teacher_pretrainer.checkpoint_items)
    teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
    teacher_pretrainer_ckpt.save(teacher_ckpt_path)
    self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()

  def test_task(self):
    bert_distillation_task = distillation.BertDistillationTask(
        strategy=tf.distribute.get_strategy(),
        progressive=self.exp_config.trainer.progressive,
        optimizer_config=self.exp_config.trainer.optimizer_config,
        task_config=self.task_config)

    metrics = bert_distillation_task.build_metrics()
    train_dataset = bert_distillation_task.get_train_dataset(stage_id=0)
    train_iterator = iter(train_dataset)
    eval_dataset = bert_distillation_task.get_eval_dataset(stage_id=0)
    eval_iterator = iter(eval_dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)

    # test train/val step for all stages, including the last pretraining stage
    for stage in range(self.model_block_num + 1):
      step = stage
      bert_distillation_task.update_pt_stage(step)
      model = bert_distillation_task.get_model(stage, None)
      bert_distillation_task.initialize(model)
      bert_distillation_task.train_step(
          next(train_iterator), model, optimizer, metrics=metrics)
      bert_distillation_task.validation_step(
          next(eval_iterator), model, metrics=metrics)

    logging.info('begin to save and load model checkpoint')
    ckpt = tf.train.Checkpoint(model=model)
    ckpt.save(self.get_temp_dir())


if __name__ == '__main__':
  tf.test.main()
official/nlp/projects/mobilebert/experiments/en_uncased_student.yaml
0 → 100644
task:
  model:
    encoder:
      type: mobilebert
      mobilebert:
        word_vocab_size: 30522
        word_embed_size: 128
        type_vocab_size: 2
        max_sequence_length: 512
        num_blocks: 24
        hidden_size: 512
        num_attention_heads: 4
        intermediate_size: 512
        hidden_activation: relu
        hidden_dropout_prob: 0.0
        attention_probs_dropout_prob: 0.1
        intra_bottleneck_size: 128
        initializer_range: 0.02
        key_query_shared_bottleneck: true
        num_feedforward_networks: 4
        normalization_type: no_norm
        classifier_activation: false
official/nlp/projects/mobilebert/experiments/en_uncased_teacher.yaml
0 → 100644
task:
  model:
    encoder:
      type: mobilebert
      mobilebert:
        word_vocab_size: 30522
        word_embed_size: 128
        type_vocab_size: 2
        max_sequence_length: 512
        num_blocks: 24
        hidden_size: 512
        num_attention_heads: 4
        intermediate_size: 4096
        hidden_activation: gelu
        hidden_dropout_prob: 0.1
        attention_probs_dropout_prob: 0.1
        intra_bottleneck_size: 1024
        initializer_range: 0.02
        key_query_shared_bottleneck: false
        num_feedforward_networks: 1
        normalization_type: layer_norm
        classifier_activation: false
official/nlp/projects/mobilebert/experiments/mobilebert_distillation_en_uncased.yaml
0 → 100644
task:
  train_data:
    drop_remainder: true
    global_batch_size: 2048
    input_path: ""
    is_training: true
    max_predictions_per_seq: 20
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
  validation_data:
    drop_remainder: true
    global_batch_size: 2048
    input_path: ""
    is_training: false
    max_predictions_per_seq: 20
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
  teacher_model:
    cls_heads: []
    mlm_activation: gelu
    mlm_initializer_range: 0.02
    encoder:
      type: mobilebert
      mobilebert:
        word_vocab_size: 30522
        word_embed_size: 128
        type_vocab_size: 2
        max_sequence_length: 512
        num_blocks: 24
        hidden_size: 512
        num_attention_heads: 4
        intermediate_size: 4096
        hidden_activation: gelu
        hidden_dropout_prob: 0.1
        attention_probs_dropout_prob: 0.1
        intra_bottleneck_size: 1024
        initializer_range: 0.02
        key_query_shared_bottleneck: false
        num_feedforward_networks: 1
        normalization_type: layer_norm
        classifier_activation: false
  student_model:
    cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.0,
                 inner_dim: 512, name: next_sentence, num_classes: 2}]
    mlm_activation: relu
    mlm_initializer_range: 0.02
    encoder:
      type: mobilebert
      mobilebert:
        word_vocab_size: 30522
        word_embed_size: 128
        type_vocab_size: 2
        max_sequence_length: 512
        num_blocks: 24
        hidden_size: 512
        num_attention_heads: 4
        intermediate_size: 512
        hidden_activation: relu
        hidden_dropout_prob: 0.0
        attention_probs_dropout_prob: 0.1
        intra_bottleneck_size: 128
        initializer_range: 0.02
        key_query_shared_bottleneck: true
        num_feedforward_networks: 4
        normalization_type: no_norm
        classifier_activation: false
  teacher_model_init_checkpoint: ""
trainer:
  progressive:
    if_copy_embeddings: true
    layer_wise_distill_config:
      num_steps: 10000
    pretrain_distill_config:
      num_steps: 500000
      decay_steps: 500000
  train_steps: 740000
  max_to_keep: 10
official/nlp/projects/mobilebert/export_tfhub.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export the MobileBERT encoder model as a TF-Hub SavedModel."""
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from official.nlp.projects.mobilebert import model_utils

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "bert_config_file", None,
    "Bert configuration file to define core mobilebert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
                    "File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
                    "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool("do_lower_case", True, "Whether to lowercase.")


def create_mobilebert_model(bert_config):
  """Creates a model for exporting to tfhub."""
  pretrainer = model_utils.create_mobilebert_pretrainer(bert_config)
  encoder = pretrainer.encoder_network

  encoder_inputs_dict = {x.name: x for x in encoder.inputs}
  encoder_output_dict = encoder(encoder_inputs_dict)
  # For interchangeability with other text representations,
  # add "default" as an alias for MobileBERT's whole-input representations.
  encoder_output_dict["default"] = encoder_output_dict["pooled_output"]
  core_model = tf.keras.Model(
      inputs=encoder_inputs_dict, outputs=encoder_output_dict)

  pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
  pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
  mlm_model = tf.keras.Model(
      inputs=pretrainer_inputs_dict, outputs=pretrainer_output_dict)

  # Set `_auto_track_sub_layers` to False, so that the additional weights
  # from `mlm` sub-object will not be included in the core model.
  # TODO(b/169210253): Use public API after the bug is resolved.
  core_model._auto_track_sub_layers = False  # pylint: disable=protected-access
  core_model.mlm = mlm_model
  return core_model, pretrainer


def export_bert_tfhub(bert_config, model_checkpoint_path, hub_destination,
                      vocab_file, do_lower_case):
  """Restores a tf.keras.Model and saves for TF-Hub."""
  core_model, pretrainer = create_mobilebert_model(bert_config)
  checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
  logging.info("Begin to load model")
  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
  logging.info("Loading model finished")

  core_model.vocab_file = tf.saved_model.Asset(vocab_file)
  core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  logging.info("Begin to save files for tfhub at %s", hub_destination)
  core_model.save(hub_destination, include_optimizer=False, save_format="tf")
  logging.info("tfhub files exported!")


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  bert_config = model_utils.BertConfig.from_json_file(FLAGS.bert_config_file)
  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
                    FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)


if __name__ == "__main__":
  app.run(main)
official/nlp/projects/mobilebert/model_utils.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint converter for Mobilebert."""
import copy
import json

import tensorflow.compat.v1 as tf

from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.modeling import networks


class BertConfig(object):
  """Configuration for `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02,
               embedding_size=None,
               trigram_input=False,
               use_bottleneck=False,
               intra_bottleneck_size=None,
               use_bottleneck_attention=False,
               key_query_shared_bottleneck=False,
               num_feedforward_networks=1,
               normalization_type="layer_norm",
               classifier_activation=True):
    """Constructs BertConfig.

    Args:
      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer
        in the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in
        the encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model
        might ever be used with. Typically set this to something large just in
        case (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
      embedding_size: The size of the token embedding.
      trigram_input: Use a convolution of trigram as input.
      use_bottleneck: Use the bottleneck/inverted-bottleneck structure in BERT.
      intra_bottleneck_size: The hidden size in the bottleneck.
      use_bottleneck_attention: Use attention inputs from the bottleneck
        transformation.
      key_query_shared_bottleneck: Use the same linear transformation for
        query&key in the bottleneck.
      num_feedforward_networks: Number of FFNs in a block.
      normalization_type: The normalization type in BERT.
      classifier_activation: Using the tanh activation for the final
        representation of the [CLS] token in fine-tuning.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range
    self.embedding_size = embedding_size
    self.trigram_input = trigram_input
    self.use_bottleneck = use_bottleneck
    self.intra_bottleneck_size = intra_bottleneck_size
    self.use_bottleneck_attention = use_bottleneck_attention
    self.key_query_shared_bottleneck = key_query_shared_bottleneck
    self.num_feedforward_networks = num_feedforward_networks
    self.normalization_type = normalization_type
    self.classifier_activation = classifier_activation

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `BertConfig` from a Python dictionary of parameters."""
    config = BertConfig(vocab_size=None)
    for (key, value) in json_object.items():
      config.__dict__[key] = value
    if config.embedding_size is None:
      config.embedding_size = config.hidden_size
    if config.intra_bottleneck_size is None:
      config.intra_bottleneck_size = config.hidden_size
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `BertConfig` from a json file of parameters."""
    with tf.gfile.GFile(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text))

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


def create_mobilebert_pretrainer(bert_config):
  """Creates a BertPretrainerV2 that wraps MobileBERTEncoder model."""
  mobilebert_encoder = networks.MobileBERTEncoder(
      word_vocab_size=bert_config.vocab_size,
      word_embed_size=bert_config.embedding_size,
      type_vocab_size=bert_config.type_vocab_size,
      max_sequence_length=bert_config.max_position_embeddings,
      num_blocks=bert_config.num_hidden_layers,
      hidden_size=bert_config.hidden_size,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      intermediate_act_fn=tf_utils.get_activation(bert_config.hidden_act),
      hidden_dropout_prob=bert_config.hidden_dropout_prob,
      attention_probs_dropout_prob=bert_config.attention_probs_dropout_prob,
      intra_bottleneck_size=bert_config.intra_bottleneck_size,
      initializer_range=bert_config.initializer_range,
      use_bottleneck_attention=bert_config.use_bottleneck_attention,
      key_query_shared_bottleneck=bert_config.key_query_shared_bottleneck,
      num_feedforward_networks=bert_config.num_feedforward_networks,
      normalization_type=bert_config.normalization_type,
      classifier_activation=bert_config.classifier_activation)

  masked_lm = layers.MobileBertMaskedLM(
      embedding_table=mobilebert_encoder.get_embedding_table(),
      activation=tf_utils.get_activation(bert_config.hidden_act),
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      name="cls/predictions")

  pretrainer = models.BertPretrainerV2(
      encoder_network=mobilebert_encoder, customized_masked_lm=masked_lm)
  # Makes sure the pretrainer variables are created.
  _ = pretrainer(pretrainer.inputs)
  return pretrainer
official/nlp/projects/mobilebert/run_distillation.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
"""Creating the task and start trainer."""
import pprint

from absl import app
from absl import flags
from absl import logging
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import config_definitions as cfg
from official.core import train_utils
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling import performance
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.mobilebert import distillation

FLAGS = flags.FLAGS

optimization_config = optimization.OptimizationConfig(
    optimizer=optimization.OptimizerConfig(
        type='lamb',
        lamb=optimization.LAMBConfig(
            weight_decay_rate=0.01,
            exclude_from_weight_decay=['LayerNorm', 'bias', 'norm'],
            clipnorm=1.0)),
    learning_rate=optimization.LrConfig(
        type='polynomial',
        polynomial=optimization.PolynomialLrConfig(
            initial_learning_rate=1.5e-3,
            decay_steps=10000,
            end_learning_rate=1.5e-3)),
    warmup=optimization.WarmupConfig(
        type='linear',
        linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))


# copy from progressive/utils.py due to the private visibility issue.
def config_override(params, flags_obj):
  """Override ExperimentConfig according to flags."""
  # Change runtime.tpu to the real tpu.
  params.override({
      'runtime': {
          'tpu': flags_obj.tpu,
      }
  })

  # Get the first level of override from `--config_file`.
  # `--config_file` is typically used as a template that specifies the common
  # override for a particular experiment.
  for config_file in flags_obj.config_file or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)

  # Get the second level of override from `--params_override`.
  # `--params_override` is typically used as a further override over the
  # template. For example, one may define a particular template for training
  # ResNet50 on ImageNet in a config file and pass it via `--config_file`,
  # then define different learning rates and pass it via `--params_override`.
  if flags_obj.params_override:
    params = hyperparams.override_params_dict(
        params, flags_obj.params_override, is_strict=True)

  params.validate()
  params.lock()

  pp = pprint.PrettyPrinter()
  logging.info('Final experiment parameters: %s', pp.pformat(params.as_dict()))

  model_dir = flags_obj.model_dir
  if 'train' in flags_obj.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  return params


def get_exp_config():
  """Get ExperimentConfig."""
  params = cfg.ExperimentConfig(
      task=distillation.BertDistillationTaskConfig(
          train_data=pretrain_dataloader.BertPretrainDataConfig(),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              is_training=False)),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          progressive=distillation.BertDistillationProgressiveConfig(),
          optimizer_config=optimization_config,
          train_steps=740000,
          checkpoint_interval=20000))

  return config_override(params, FLAGS)


def main(_):
  logging.info('Parsing config files...')
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = get_exp_config()

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype,
        params.runtime.loss_scale,
        use_experimental_api=True)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  with distribution_strategy.scope():
    task = distillation.BertDistillationTask(
        strategy=distribution_strategy,
        progressive=params.trainer.progressive,
        optimizer_config=params.trainer.optimizer_config,
        task_config=params.task)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=FLAGS.model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
official/nlp/projects/mobilebert/tf2_model_checkpoint_converter.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint converter for Mobilebert."""
import os

from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

from official.nlp.projects.mobilebert import model_utils

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "bert_config_file", None,
    "Bert configuration file to define core mobilebert layers.")
flags.DEFINE_string("tf1_checkpoint_path", None,
                    "Path to load tf1 checkpoint.")
flags.DEFINE_string("tf2_checkpoint_path", None,
                    "Path to save tf2 checkpoint.")
flags.DEFINE_boolean(
    "use_model_prefix", False,
    ("If use model name as prefix for variables. Turn this "
     "flag on when the converted checkpoint is used for model "
     "in subclass implementation, which uses the model name as "
     "prefix for all variable names."))


def _bert_name_replacement(var_name, name_replacements):
  """Gets the variable name replacement."""
  for src_pattern, tgt_pattern in name_replacements:
    if src_pattern in var_name:
      old_var_name = var_name
      var_name = var_name.replace(src_pattern, tgt_pattern)
      logging.info("Converted: %s --> %s", old_var_name, var_name)
  return var_name


def _has_exclude_patterns(name, exclude_patterns):
  """Checks if a string contains substrings that match patterns to exclude."""
  for p in exclude_patterns:
    if p in name:
      return True
  return False


def _get_permutation(name, permutations):
  """Checks whether a variable requires transposition by pattern matching."""
  for src_pattern, permutation in permutations:
    if src_pattern in name:
      logging.info("Permuted: %s --> %s", name, permutation)
      return permutation
  return None


def _get_new_shape(name, shape, num_heads):
  """Checks whether a variable requires reshape by pattern matching."""
  if "attention/attention_output/kernel" in name:
    return tuple([num_heads, shape[0] // num_heads, shape[1]])
  if "attention/attention_output/bias" in name:
    return shape

  patterns = ["attention/query", "attention/value", "attention/key"]
  for pattern in patterns:
    if pattern in name:
      if "kernel" in name:
        return tuple([shape[0], num_heads, shape[1] // num_heads])
      if "bias" in name:
        return tuple([num_heads, shape[0] // num_heads])
  return None


def convert(checkpoint_from_path,
            checkpoint_to_path,
            name_replacements,
            permutations,
            bert_config,
            exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    name_replacements: A list of tuples of the form (match_str, replace_str)
      describing variable names to adjust.
    permutations: A list of tuples of the form (match_str, permutation)
      describing permutations to apply to given variables. Note that match_str
      should match the original variable name, not the replaced one.
    bert_config: A `BertConfig` to create the core model.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.

  Returns:
    A dictionary that maps the new variable names to the Variable objects.
    A dictionary that maps the old variable names to the new variable names.
  """
  last_ffn_layer_id = str(bert_config.num_feedforward_networks - 1)
  name_replacements = [
      (x[0], x[1].replace("LAST_FFN_LAYER_ID", last_ffn_layer_id))
      for x in name_replacements
  ]

  output_dir, _ = os.path.split(checkpoint_to_path)
  tf.io.gfile.makedirs(output_dir)

  # Create a temporary V1 name-converted checkpoint in the output directory.
  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")

  with tf.Graph().as_default():
    logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    new_variable_map = {}
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name,
                                                    exclude_patterns):
        continue
      # Get the original tensor data.
      tensor = reader.get_tensor(var_name)

      # Look up the new variable name, if any.
      new_var_name = _bert_name_replacement(var_name, name_replacements)

      # See if we need to reshape the underlying tensor.
      new_shape = None
      if bert_config.num_attention_heads > 0:
        new_shape = _get_new_shape(new_var_name, tensor.shape,
                                   bert_config.num_attention_heads)
      if new_shape:
        logging.info("Variable %s has a shape change from %s to %s",
                     var_name, tensor.shape, new_shape)
        tensor = np.reshape(tensor, new_shape)

      # See if we need to permute the underlying tensor.
      permutation = _get_permutation(var_name, permutations)
      if permutation:
        tensor = np.transpose(tensor, permutation)

      # Create a new variable with the possibly-reshaped or transposed tensor.
      var = tf.Variable(tensor, name=var_name)

      # Save the variable into the new variable map.
      new_variable_map[new_var_name] = var

      # Keep a list of converter variables for sanity checking.
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    saver = tf.train.Saver(new_variable_map)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      logging.info("Writing checkpoint_to_path %s", temporary_checkpoint)
      saver.save(sess, temporary_checkpoint, write_meta_graph=False)

  logging.info("Summary:")
  logging.info("Converted %d variable name(s).", len(new_variable_map))
  logging.info("Converted: %s", str(conversion_map))

  mobilebert_model = model_utils.create_mobilebert_pretrainer(bert_config)
  create_v2_checkpoint(mobilebert_model, temporary_checkpoint,
                       checkpoint_to_path)

  # Clean up the temporary checkpoint, if it exists.
  try:
    tf.io.gfile.rmtree(temporary_checkpoint_dir)
  except tf.errors.OpError:
    # If it doesn't exist, we don't need to clean it up; continue.
    pass


def create_v2_checkpoint(model, src_checkpoint, output_path):
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
  # Uses streaming-restore in eager mode to read V1 name-based checkpoints.
  model.load_weights(src_checkpoint).assert_existing_objects_matched()
  checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
  checkpoint.save(output_path)


_NAME_REPLACEMENT = [
    # prefix path replacement
    ("bert/", "mobile_bert_encoder/"),
    ("encoder/layer_", "transformer_layer_"),
    # embedding layer
    ("embeddings/embedding_transformation",
     "mobile_bert_embedding/embedding_projection"),
    ("embeddings/position_embeddings",
     "mobile_bert_embedding/position_embedding/embeddings"),
    ("embeddings/token_type_embeddings",
     "mobile_bert_embedding/type_embedding/embeddings"),
    ("embeddings/word_embeddings",
     "mobile_bert_embedding/word_embedding/embeddings"),
    ("embeddings/FakeLayerNorm", "mobile_bert_embedding/embedding_norm"),
    ("embeddings/LayerNorm", "mobile_bert_embedding/embedding_norm"),
    # attention layer
    ("attention/output/dense", "attention/attention_output"),
    ("attention/output/FakeLayerNorm", "attention/norm"),
    ("attention/output/LayerNorm", "attention/norm"),
    ("attention/self", "attention"),
    # input bottleneck
    ("bottleneck/input/dense", "bottleneck_input/dense"),
    ("bottleneck/input/FakeLayerNorm", "bottleneck_input/norm"),
    ("bottleneck/input/LayerNorm", "bottleneck_input/norm"),
    ("bottleneck/attention/dense", "kq_shared_bottleneck/dense"),
    ("bottleneck/attention/FakeLayerNorm", "kq_shared_bottleneck/norm"),
    ("bottleneck/attention/LayerNorm", "kq_shared_bottleneck/norm"),
    # ffn layer
    ("ffn_layer_0/output/dense", "ffn_layer_0/output_dense"),
    ("ffn_layer_1/output/dense", "ffn_layer_1/output_dense"),
    ("ffn_layer_2/output/dense", "ffn_layer_2/output_dense"),
    ("output/dense", "ffn_layer_LAST_FFN_LAYER_ID/output_dense"),
    ("ffn_layer_0/output/FakeLayerNorm", "ffn_layer_0/norm"),
    ("ffn_layer_0/output/LayerNorm", "ffn_layer_0/norm"),
    ("ffn_layer_1/output/FakeLayerNorm", "ffn_layer_1/norm"),
    ("ffn_layer_1/output/LayerNorm", "ffn_layer_1/norm"),
    ("ffn_layer_2/output/FakeLayerNorm", "ffn_layer_2/norm"),
    ("ffn_layer_2/output/LayerNorm", "ffn_layer_2/norm"),
    ("output/FakeLayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
    ("output/LayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
    ("ffn_layer_0/intermediate/dense", "ffn_layer_0/intermediate_dense"),
    ("ffn_layer_1/intermediate/dense", "ffn_layer_1/intermediate_dense"),
    ("ffn_layer_2/intermediate/dense", "ffn_layer_2/intermediate_dense"),
    ("intermediate/dense", "ffn_layer_LAST_FFN_LAYER_ID/intermediate_dense"),
    # output bottleneck
    ("output/bottleneck/FakeLayerNorm", "bottleneck_output/norm"),
    ("output/bottleneck/LayerNorm", "bottleneck_output/norm"),
    ("output/bottleneck/dense", "bottleneck_output/dense"),
    # pooler layer
    ("pooler/dense", "pooler"),
    # MLM layer
    ("cls/predictions", "bert/cls/predictions"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias")
]

_EXCLUDE_PATTERNS = ["cls/seq_relationship", "global_step"]


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  if not FLAGS.use_model_prefix:
    _NAME_REPLACEMENT[0] = ("bert/", "")

  bert_config = model_utils.BertConfig.from_json_file(FLAGS.bert_config_file)
  convert(FLAGS.tf1_checkpoint_path, FLAGS.tf2_checkpoint_path,
          _NAME_REPLACEMENT, [], bert_config, _EXCLUDE_PATTERNS)


if __name__ == "__main__":
  app.run(main)
official/nlp/projects/mobilebert/utils.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions."""
import numpy as np


def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
  """Generate consistent fake integer input sequences."""
  np.random.seed(seed)
  fake_input = []
  for _ in range(batch_size):
    fake_input.append([])
    for _ in range(seq_len):
      fake_input[-1].append(np.random.randint(0, vocab_size))
  fake_input = np.asarray(fake_input)
  return fake_input