ModelZoo / ResNet50_tensorflow / Commits

Commit da228b42, authored Dec 02, 2019 by Chen Chen, committed by A. Unique TensorFlower on Dec 02, 2019

Move tf2_encoder_checkpoint_converter to public.

PiperOrigin-RevId: 283374562
parent 494cf0b3

Showing 2 changed files with 109 additions and 82 deletions (+109 −82)
official/nlp/bert/tf1_to_keras_checkpoint_converter.py  (+0 −82)
official/nlp/bert/tf2_encoder_checkpoint_converter.py  (+109 −0)
official/nlp/bert/tf1_to_keras_checkpoint_converter.py (deleted, 100644 → 0)
```python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible.

Keras manages variable names internally, which results in subtly different
names for variables between the Estimator and Keras version.

The script should be used with TF 1.x.

Usage:
  python checkpoint_convert.py \
      --checkpoint_from_path="/path/to/checkpoint" \
      --checkpoint_to_path="/path/to/new_checkpoint"
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
import tensorflow as tf  # TF 1.x

from official.nlp.bert import tf1_checkpoint_converter_lib

flags = tf.flags
FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string("checkpoint_from_path", None,
                    "Source BERT checkpoint path.")
flags.DEFINE_string("checkpoint_to_path", None,
                    "Destination BERT checkpoint path.")
flags.DEFINE_string("exclude_patterns", None,
                    "Comma-delimited string of a list of patterns to exclude"
                    " variables from source checkpoint.")
flags.DEFINE_integer(
    "num_heads", -1,
    "The number of attention heads, used to reshape variables. If it is -1, "
    "we do not reshape variables.")
flags.DEFINE_boolean(
    "create_v2_checkpoint", False,
    "Whether to create a checkpoint compatible with KerasBERT V2 modeling "
    "code.")


def main(_):
  exclude_patterns = None
  if FLAGS.exclude_patterns:
    exclude_patterns = FLAGS.exclude_patterns.split(",")

  if FLAGS.create_v2_checkpoint:
    name_replacements = tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS
    permutations = tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS
  else:
    name_replacements = tf1_checkpoint_converter_lib.BERT_NAME_REPLACEMENTS
    permutations = tf1_checkpoint_converter_lib.BERT_PERMUTATIONS

  tf1_checkpoint_converter_lib.convert(FLAGS.checkpoint_from_path,
                                       FLAGS.checkpoint_to_path,
                                       FLAGS.num_heads, name_replacements,
                                       permutations, exclude_patterns)


if __name__ == "__main__":
  flags.mark_flag_as_required("checkpoint_from_path")
  flags.mark_flag_as_required("checkpoint_to_path")
  app.run(main)
```
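The conversion above is driven entirely by the name-replacement, permutation, and exclude-pattern tables handed to `tf1_checkpoint_converter_lib.convert`; that library's internals are not part of this commit. As a rough illustration only, the sketch below applies hypothetical `(old_substring, new_substring)` replacement pairs and exclude patterns to a list of variable names — the general shape of a name-based checkpoint conversion. All names and tables here are invented examples, not the actual `BERT_NAME_REPLACEMENTS` contents.

```python
# Illustrative sketch only: tf1_checkpoint_converter_lib's real tables and
# logic are not shown in this commit. The replacement pairs and variable
# names below are hypothetical examples.

def convert_names(variable_names, name_replacements, exclude_patterns):
  """Maps TF1 variable names to new names, skipping excluded variables."""
  converted = {}
  for name in variable_names:
    # Skip variables matching any exclude pattern (e.g. optimizer slots).
    if exclude_patterns and any(p in name for p in exclude_patterns):
      continue
    new_name = name
    for old, new in name_replacements:
      new_name = new_name.replace(old, new)
    converted[name] = new_name
  return converted


if __name__ == "__main__":
  # Hypothetical replacement table in the spirit of BERT_NAME_REPLACEMENTS.
  replacements = [
      ("bert/", "bert_model/"),
      ("embeddings/word_embeddings", "word_embeddings/embeddings"),
  ]
  names = [
      "bert/embeddings/word_embeddings",
      "bert/encoder/layer_0/attention/self/query/kernel",
      "bert/encoder/layer_0/attention/self/query/kernel/adam_m",
  ]
  print(convert_names(names, replacements, exclude_patterns=["adam"]))
```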
official/nlp/bert/tf2_checkpoint_converter.py → official/nlp/bert/tf2_encoder_checkpoint_converter.py (renamed)

@@ -12,83 +12,97 @@

Old version (tf2_checkpoint_converter.py):

```python
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A converter for BERT name-based checkpoint to object-based checkpoint.

The conversion will yield an object-oriented checkpoint for TF2 BERT models,
when BertConfig.backward_compatible is true.

The variable/tensor shapes match the TF1 BERT model, but backward
compatibility introduces unnecessary reshape computation.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

import tensorflow as tf  # TF 1.x

from official.nlp import bert_modeling as modeling

FLAGS = flags.FLAGS

flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_string("converted_checkpoint", None,
                    "Path to object-based V2 checkpoint.")
flags.DEFINE_bool(
    "export_bert_as_layer", False,
    "Whether to use a layer rather than a model inside the checkpoint.")


def create_bert_model(bert_config):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    bert_config: A `BertConfig` to create the core model.

  Returns:
    A keras model.
  """
  max_seq_length = bert_config.max_position_embeddings

  # Adds input layers just as placeholders.
  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name="input_word_ids")
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name="input_mask")
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name="input_type_ids")
  core_model = modeling.get_bert_model(
      input_word_ids,
      input_mask,
      input_type_ids,
      config=bert_config,
      name="bert_model",
      float_type=tf.float32)
  return core_model


def convert_checkpoint():
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  core_model = create_bert_model(bert_config)

  # Uses streaming-restore in eager model to read V1 name-based checkpoints.
  core_model.load_weights(FLAGS.init_checkpoint)
  if FLAGS.export_bert_as_layer:
    bert_layer = core_model.get_layer("bert_model")
    checkpoint = tf.train.Checkpoint(bert_layer=bert_layer)
  else:
    checkpoint = tf.train.Checkpoint(model=core_model)

  checkpoint.save(FLAGS.converted_checkpoint)


def main(_):
  tf.enable_eager_execution()
  convert_checkpoint()


if __name__ == "__main__":
  ...
```

New version (tf2_encoder_checkpoint_converter.py):

```python
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.

The conversion will yield an object-oriented checkpoint that can be used
to restore a TransformerEncoder object.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags

import tensorflow as tf

from official.modeling import activations
from official.nlp import bert_modeling as modeling
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import networks

FLAGS = flags.FLAGS

flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
flags.DEFINE_string(
    "checkpoint_to_convert", None,
    "Initial checkpoint from a pretrained BERT model core (that is, only the "
    "BertModel, with no task heads).")
flags.DEFINE_string("converted_checkpoint_path", None,
                    "Name for the created object-based V2 checkpoint.")


def _create_bert_model(cfg):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    cfg: A `BertConfig` to create the core model.

  Returns:
    A keras model.
  """
  bert_encoder = networks.TransformerEncoder(
      vocab_size=cfg.vocab_size,
      hidden_size=cfg.hidden_size,
      num_layers=cfg.num_hidden_layers,
      num_attention_heads=cfg.num_attention_heads,
      intermediate_size=cfg.intermediate_size,
      activation=activations.gelu,
      dropout_rate=cfg.hidden_dropout_prob,
      attention_dropout_rate=cfg.attention_probs_dropout_prob,
      sequence_length=cfg.max_position_embeddings,
      type_vocab_size=cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range))

  return bert_encoder


def convert_checkpoint(bert_config, output_path, v1_checkpoint):
  """Converts a V1 checkpoint into an OO V2 checkpoint."""
  output_dir, _ = os.path.split(output_path)

  # Create a temporary V1 name-converted checkpoint in the output directory.
  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
  tf1_checkpoint_converter_lib.convert(
      checkpoint_from_path=v1_checkpoint,
      checkpoint_to_path=temporary_checkpoint,
      num_heads=bert_config.num_attention_heads,
      name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
      permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
      exclude_patterns=["adam", "Adam"])

  # Create a V2 checkpoint from the temporary checkpoint.
  model = _create_bert_model(bert_config)
  tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
                                                    output_path)

  # Clean up the temporary checkpoint, if it exists.
  try:
    tf.io.gfile.rmtree(temporary_checkpoint_dir)
  except tf.errors.OpError:
    # If it doesn't exist, we don't need to clean it up; continue.
    pass


def main(_):
  assert tf.version.VERSION.startswith('2.')
  output_path = FLAGS.converted_checkpoint_path
  v1_checkpoint = FLAGS.checkpoint_to_convert
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  convert_checkpoint(bert_config, output_path, v1_checkpoint)


if __name__ == "__main__":
  ...
```
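The new converter writes an object-based checkpoint meant to restore a `TransformerEncoder` built from the same `BertConfig`. A minimal restore sketch follows, assuming a checkpoint produced by this script; the paths are placeholders, and the `model=` attachment key mirrors the `tf.train.Checkpoint(model=...)` convention from the old converter rather than anything shown for `create_v2_checkpoint` in this diff.

```python
# Minimal restore sketch, assuming a checkpoint produced by
# tf2_encoder_checkpoint_converter.py; paths and the `model` key are
# assumptions, not confirmed by this commit.
import tensorflow as tf

from official.modeling import activations
from official.nlp import bert_modeling as modeling
from official.nlp.modeling import networks

cfg = modeling.BertConfig.from_json_file("/path/to/bert_config.json")
encoder = networks.TransformerEncoder(
    vocab_size=cfg.vocab_size,
    hidden_size=cfg.hidden_size,
    num_layers=cfg.num_hidden_layers,
    num_attention_heads=cfg.num_attention_heads,
    intermediate_size=cfg.intermediate_size,
    activation=activations.gelu,
    dropout_rate=cfg.hidden_dropout_prob,
    attention_dropout_rate=cfg.attention_probs_dropout_prob,
    sequence_length=cfg.max_position_embeddings,
    type_vocab_size=cfg.type_vocab_size,
    initializer=tf.keras.initializers.TruncatedNormal(
        stddev=cfg.initializer_range))

# Attach the encoder under the assumed `model` key and verify that every
# object in the encoder found a matching value in the checkpoint.
checkpoint = tf.train.Checkpoint(model=encoder)
status = checkpoint.restore("/path/to/converted/ckpt-1")
status.assert_existing_objects_matched()
```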