ModelZoo / ResNet50_tensorflow · Commits

Commit 3a9ed6bd
Authored Oct 03, 2020 by Chen Chen; committed by A. Unique TensorFlower on Oct 03, 2020
Support converting an ALBERT TF1 checkpoint to TF2's BertPretrainerV2
PiperOrigin-RevId: 335226408
Parent: f82ade47

Showing 1 changed file with 41 additions and 14 deletions:
  official/nlp/albert/tf2_albert_encoder_checkpoint_converter.py (+41 / -14)
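Reviewer note: with this change the script can emit either an encoder or a pretrainer checkpoint. Below is a minimal programmatic sketch of the new path, assuming the Model Garden package is importable; the config and checkpoint paths are hypothetical placeholders, not files from this repo.

```python
# Hypothetical usage sketch: convert a TF1 ALBERT checkpoint into a TF2
# BertPretrainerV2 checkpoint with the updated converter. All paths are
# placeholders.
from official.nlp.albert import configs
from official.nlp.albert import tf2_albert_encoder_checkpoint_converter as converter

albert_config = configs.AlbertConfig.from_json_file("albert_config.json")
converter.convert_checkpoint(
    bert_config=albert_config,
    output_path="/tmp/tf2_albert/converted.ckpt",
    v1_checkpoint="/tmp/tf1_albert/model.ckpt",
    checkpoint_model_name="encoder",
    converted_model="pretrainer")  # option added in this commit
```

The same conversion is available from the command line via the new `--converted_model=pretrainer` flag, alongside the existing `--albert_config_file`, `--checkpoint_to_convert`, and `--converted_checkpoint_path` flags.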
official/nlp/albert/tf2_albert_encoder_checkpoint_converter.py @ 3a9ed6bd
@@ -17,20 +17,16 @@
 The conversion will yield an object-oriented checkpoint that can be used
 to restore an AlbertEncoder object.
 """
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import os

+# Import libraries
 from absl import app
 from absl import flags
-
 import tensorflow as tf
-from official.modeling import activations
+from official.modeling import tf_utils
 from official.nlp.albert import configs
 from official.nlp.bert import tf1_checkpoint_converter_lib
+from official.nlp.modeling import models
 from official.nlp.modeling import networks

 FLAGS = flags.FLAGS
@@ -47,6 +43,10 @@ flags.DEFINE_string("checkpoint_model_name", "encoder",
                     "The name of the model when saving the checkpoint, i.e., "
                     "the checkpoint will be saved using: "
                     "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
+flags.DEFINE_enum("converted_model", "encoder",
+                  ["encoder", "pretrainer"],
+                  "Whether to convert the checkpoint to a `AlbertEncoder` model or a "
+                  "`BertPretrainerV2` model (with mlm but without classification heads).")


 ALBERT_NAME_REPLACEMENTS = (
@@ -60,10 +60,10 @@ ALBERT_NAME_REPLACEMENTS = (
     ("group_0/inner_group_0/", ""),
     ("attention_1/self", "self_attention"),
     ("attention_1/output/dense", "self_attention/attention_output"),
-    ("LayerNorm/", "self_attention_layer_norm/"),
+    ("transformer/LayerNorm/", "transformer/self_attention_layer_norm/"),
     ("ffn_1/intermediate/dense", "intermediate"),
     ("ffn_1/intermediate/output/dense", "output"),
-    ("LayerNorm_1/", "output_layer_norm/"),
+    ("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
     ("pooler/dense", "pooler_transform"),
     ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
     ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
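For context: this table is consumed by tf1_checkpoint_converter_lib, which (roughly) applies each (old, new) pair as a substring substitution on TF1 variable names; the `transformer/` prefixes added here make the two LayerNorm rules match only transformer-scoped variables. A minimal sketch of that idea, not the library's actual implementation:

```python
# Illustrative sketch only: apply (old, new) substring replacements to a TF1
# variable name, the way the ALBERT_NAME_REPLACEMENTS table above is used.
def apply_replacements(var_name, replacements):
  for old, new in replacements:
    if old in var_name:
      var_name = var_name.replace(old, new)
  return var_name

# A TF1 ALBERT layer-norm variable maps onto the TF2 encoder's naming:
print(apply_replacements(
    "bert/encoder/transformer/LayerNorm/gamma",
    (("transformer/LayerNorm/", "transformer/self_attention_layer_norm/"),)))
# -> bert/encoder/transformer/self_attention_layer_norm/gamma
```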
@@ -73,10 +73,10 @@ ALBERT_NAME_REPLACEMENTS = (

 def _create_albert_model(cfg):
-  """Creates a BERT keras core model from BERT configuration.
+  """Creates an ALBERT keras core model from BERT configuration.

   Args:
-    cfg: A `BertConfig` to create the core model.
+    cfg: A `AlbertConfig` to create the core model.

   Returns:
     A keras model.
@@ -88,7 +88,7 @@ def _create_albert_model(cfg):
       num_layers=cfg.num_hidden_layers,
       num_attention_heads=cfg.num_attention_heads,
       intermediate_size=cfg.intermediate_size,
-      activation=activations.gelu,
+      activation=tf_utils.get_activation(cfg.hidden_act),
       dropout_rate=cfg.hidden_dropout_prob,
       attention_dropout_rate=cfg.attention_probs_dropout_prob,
       max_sequence_length=cfg.max_position_embeddings,
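The switch from a hard-coded `activations.gelu` to `tf_utils.get_activation(cfg.hidden_act)` lets the activation come from the ALBERT config rather than being fixed. A quick sketch of the resolution step; with the standard ALBERT config value of "gelu", behavior is unchanged:

```python
# Sketch: resolve the config's activation string to a callable.
from official.modeling import tf_utils

activation_fn = tf_utils.get_activation("gelu")  # returns the gelu callable
```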
@@ -98,8 +98,27 @@ def _create_albert_model(cfg):
   return albert_encoder


+def _create_pretrainer_model(cfg):
+  """Creates a pretrainer with AlbertEncoder from ALBERT configuration.
+
+  Args:
+    cfg: A `BertConfig` to create the core model.
+
+  Returns:
+    A BertPretrainerV2 model.
+  """
+  albert_encoder = _create_albert_model(cfg)
+  pretrainer = models.BertPretrainerV2(
+      encoder_network=albert_encoder,
+      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
+      mlm_initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=cfg.initializer_range))
+  return pretrainer
+
+
 def convert_checkpoint(bert_config, output_path, v1_checkpoint,
-                       checkpoint_model_name):
+                       checkpoint_model_name,
+                       converted_model="encoder"):
   """Converts a V1 checkpoint into an OO V2 checkpoint."""
   output_dir, _ = os.path.split(output_path)
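Per the `checkpoint_model_name` flag docstring above, the converted checkpoint is written via `tf.train.Checkpoint(<checkpoint_model_name>=model)`. A hedged sketch of restoring a converted pretrainer checkpoint under that convention; the config and checkpoint paths are placeholders, and it assumes the flag's default name "encoder" was used:

```python
# Sketch: restore a converted BertPretrainerV2 checkpoint. Assumes it was
# saved as tf.train.Checkpoint(encoder=pretrainer), i.e. with the default
# checkpoint_model_name="encoder"; paths are placeholders.
import tensorflow as tf
from official.nlp.albert import configs
from official.nlp.albert import tf2_albert_encoder_checkpoint_converter as converter

albert_config = configs.AlbertConfig.from_json_file("albert_config.json")
pretrainer = converter._create_pretrainer_model(albert_config)  # added in this diff
ckpt = tf.train.Checkpoint(encoder=pretrainer)
ckpt.restore("/tmp/tf2_albert/converted.ckpt").assert_existing_objects_matched()
```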
@@ -115,7 +134,13 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
       exclude_patterns=["adam", "Adam"])

   # Create a V2 checkpoint from the temporary checkpoint.
-  model = _create_albert_model(bert_config)
+  if converted_model == "encoder":
+    model = _create_albert_model(bert_config)
+  elif converted_model == "pretrainer":
+    model = _create_pretrainer_model(bert_config)
+  else:
+    raise ValueError("Unsupported converted_model: %s" % converted_model)
   tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
                                                     output_path,
                                                     checkpoint_model_name)
@@ -132,9 +157,11 @@ def main(_):
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
   checkpoint_model_name = FLAGS.checkpoint_model_name
+  converted_model = FLAGS.converted_model
   albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
   convert_checkpoint(albert_config, output_path, v1_checkpoint,
-                     checkpoint_model_name)
+                     checkpoint_model_name,
+                     converted_model=converted_model)


 if __name__ == "__main__":