ModelZoo / ResNet50_tensorflow

Commit 9269550c
Authored Sep 26, 2020 by Chen Chen; committed by A. Unique TensorFlower, Sep 26, 2020

Internal change

PiperOrigin-RevId: 333949668
Parent: 164ee0e0

Showing 1 changed file with 52 additions and 10 deletions:
  official/nlp/bert/tf2_encoder_checkpoint_converter.py  (+52, -10)
@@ -15,7 +15,8 @@
 """A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.

 The conversion will yield an object-oriented checkpoint that can be used
-to restore a TransformerEncoder object.
+to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
+FLAG below).
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -27,9 +28,10 @@ from absl import app
 from absl import flags

 import tensorflow as tf
-from official.modeling import activations
+from official.modeling import tf_utils
 from official.nlp.bert import configs
 from official.nlp.bert import tf1_checkpoint_converter_lib
+from official.nlp.modeling import models
 from official.nlp.modeling import networks

 FLAGS = flags.FLAGS
@@ -46,6 +48,10 @@ flags.DEFINE_string("checkpoint_model_name", "encoder",
                     "The name of the model when saving the checkpoint, i.e., "
                     "the checkpoint will be saved using: "
                     "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
+flags.DEFINE_enum("converted_model", "encoder", ["encoder", "pretrainer"],
+                  "Whether to convert the checkpoint to a `BertEncoder` model "
+                  "or a `BertPretrainerV2` model (with mlm but without "
+                  "classification heads).")


 def _create_bert_model(cfg):
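For context, the `checkpoint_model_name` help text above fixes the save-side naming contract. A minimal runnable sketch of that contract, with a stand-in Keras model and placeholder paths (both hypothetical, not part of this change):

    import tensorflow as tf

    # Stand-in for the converted model; a real run builds a BertEncoder or
    # BertPretrainerV2 from the BertConfig.
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])

    # The converter saves via tf.train.Checkpoint(<checkpoint_model_name>=model),
    # so with the default flag value the attribute is literally `encoder`.
    save_path = tf.train.Checkpoint(encoder=model).save("/tmp/v2_ckpt/ckpt")

    # Restoring must use the same attribute name on a freshly built model.
    restored_model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    tf.train.Checkpoint(encoder=restored_model).restore(save_path)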
@@ -55,7 +61,7 @@ def _create_bert_model(cfg):
     cfg: A `BertConfig` to create the core model.

   Returns:
-    A TransformerEncoder network.
+    A BertEncoder network.
   """
   bert_encoder = networks.BertEncoder(
       vocab_size=cfg.vocab_size,
@@ -63,7 +69,7 @@ def _create_bert_model(cfg):
       num_layers=cfg.num_hidden_layers,
       num_attention_heads=cfg.num_attention_heads,
       intermediate_size=cfg.intermediate_size,
-      activation=activations.gelu,
+      activation=tf_utils.get_activation(cfg.hidden_act),
       dropout_rate=cfg.hidden_dropout_prob,
       attention_dropout_rate=cfg.attention_probs_dropout_prob,
       max_sequence_length=cfg.max_position_embeddings,
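This swap stops hard-coding gelu and instead honors the config's `hidden_act` field. A quick hedged illustration of the helper that replaces `activations.gelu` (the "gelu" string here is just an example value):

    from official.modeling import tf_utils

    # tf_utils.get_activation maps an activation name such as
    # BertConfig.hidden_act (e.g. "gelu") to the corresponding callable.
    act_fn = tf_utils.get_activation("gelu")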
@@ -75,8 +81,29 @@ def _create_bert_model(cfg):
   return bert_encoder


-def convert_checkpoint(bert_config, output_path, v1_checkpoint,
-                       checkpoint_model_name="model"):
+def _create_bert_pretrainer_model(cfg):
+  """Creates a BERT keras core model from BERT configuration.
+
+  Args:
+    cfg: A `BertConfig` to create the core model.
+
+  Returns:
+    A BertPretrainerV2 model.
+  """
+  bert_encoder = _create_bert_model(cfg)
+  pretrainer = models.BertPretrainerV2(
+      encoder_network=bert_encoder,
+      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
+      mlm_initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=cfg.initializer_range))
+  return pretrainer
+
+
+def convert_checkpoint(bert_config,
+                       output_path,
+                       v1_checkpoint,
+                       checkpoint_model_name="model",
+                       converted_model="encoder"):
   """Converts a V1 checkpoint into an OO V2 checkpoint."""
   output_dir, _ = os.path.split(output_path)
   tf.io.gfile.makedirs(output_dir)
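Given the widened signature, a programmatic conversion (bypassing `main`) would presumably look like the sketch below; the paths are placeholders, while the module, function, and keyword names are the ones shown in this diff:

    from official.nlp.bert import configs
    from official.nlp.bert import tf2_encoder_checkpoint_converter as converter

    bert_config = configs.BertConfig.from_json_file("/path/to/bert_config.json")
    converter.convert_checkpoint(
        bert_config=bert_config,
        output_path="/path/to/v2/ckpt",
        v1_checkpoint="/path/to/v1/bert_model.ckpt",
        checkpoint_model_name="encoder",
        converted_model="pretrainer")  # or "encoder" (the default)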
@@ -84,6 +111,7 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
   # Create a temporary V1 name-converted checkpoint in the output directory.
   temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
   temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
+
   tf1_checkpoint_converter_lib.convert(
       checkpoint_from_path=v1_checkpoint,
       checkpoint_to_path=temporary_checkpoint,
@@ -92,8 +120,14 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
       permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
       exclude_patterns=["adam", "Adam"])

-  # Create a V2 checkpoint from the temporary checkpoint.
-  model = _create_bert_model(bert_config)
+  if converted_model == "encoder":
+    model = _create_bert_model(bert_config)
+  elif converted_model == "pretrainer":
+    model = _create_bert_pretrainer_model(bert_config)
+  else:
+    raise ValueError("Unsupported converted_model: %s" % converted_model)
+
+  # Create a V2 checkpoint from the temporary checkpoint.
   tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
                                                     output_path,
                                                     checkpoint_model_name)
@@ -106,13 +140,21 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
     pass


-def main(_):
+def main(argv):
+  if len(argv) > 1:
+    raise app.UsageError("Too many command-line arguments.")
+
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
   checkpoint_model_name = FLAGS.checkpoint_model_name
+  converted_model = FLAGS.converted_model
   bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
-  convert_checkpoint(bert_config, output_path, v1_checkpoint,
-                     checkpoint_model_name)
+  convert_checkpoint(
+      bert_config=bert_config,
+      output_path=output_path,
+      v1_checkpoint=v1_checkpoint,
+      checkpoint_model_name=checkpoint_model_name,
+      converted_model=converted_model)


 if __name__ == "__main__":
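End to end, a pretrainer conversion would presumably be invoked along these lines; all paths are hypothetical, and the flag names are the ones defined in this file:

    python3 tf2_encoder_checkpoint_converter.py \
        --bert_config_file=/path/to/bert_config.json \
        --checkpoint_to_convert=/path/to/v1/bert_model.ckpt \
        --converted_checkpoint_path=/path/to/v2/ckpt \
        --converted_model=pretrainer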