Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
12f9403f
"vscode:/vscode.git/clone" did not exist on "f565b808ed3208c2065b1ba889589eafadea0102"
Commit
12f9403f
authored
Aug 26, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 26, 2019
Browse files
Open source checkpoint conversion tool.
PiperOrigin-RevId: 265490374
parent
4d09de12
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
237 additions
and
0 deletions
+237
-0
official/bert/tools/tf1_to_keras_checkpoint_converter.py
official/bert/tools/tf1_to_keras_checkpoint_converter.py
+139
-0
official/bert/tools/tf2_checkpoint_converter.py
official/bert/tools/tf2_checkpoint_converter.py
+98
-0
No files found.
official/bert/tools/tf1_to_keras_checkpoint_converter.py
0 → 100644
View file @
12f9403f
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible.

Keras manages variable names internally, which results in subtly different
names for variables between the Estimator and Keras versions.

The script should be run with TF 1.x.

Usage:
  python tf1_to_keras_checkpoint_converter.py \
    --checkpoint_from_path="/path/to/checkpoint" \
    --checkpoint_to_path="/path/to/new_checkpoint"
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
app
import
tensorflow
as
tf
# `tf.flags` is the TF1 alias of absl flags; this module targets TF 1.x.
flags = tf.flags
FLAGS = flags.FLAGS

## Required parameters
# Source checkpoint to read variables from.
flags.DEFINE_string("checkpoint_from_path", None, "Source BERT checkpoint path.")
# Destination path the renamed checkpoint is written to.
flags.DEFINE_string("checkpoint_to_path", None,
                    "Destination BERT checkpoint path.")
# Optional, despite the "Required parameters" header above: only the two path
# flags are marked required in the __main__ guard below.
flags.DEFINE_string("exclude_patterns", None,
                    "Comma-delimited string of a list of patterns to exclude"
                    " variables from source checkpoint.")
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
# NOTE: replacements are applied by substring and in list order (see
# _bert_name_replacement), so the first entry rewrites the leading "bert"
# scope before the later, more specific patterns are matched.
BERT_NAME_REPLACEMENTS = [
    ("bert", "bert_model"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings",
     "embedding_postprocessor/type_embeddings"),
    ("embeddings/position_embeddings",
     "embedding_postprocessor/position_embeddings"),
    ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
]
def _bert_name_replacement(var_name):
  """Rewrites a TF1 BERT variable name into its Keras-style counterpart.

  Every (source, target) pair in BERT_NAME_REPLACEMENTS is applied in order;
  each matching source substring is replaced and the rename is logged.

  Args:
    var_name: Original variable name from the source checkpoint.

  Returns:
    The (possibly unchanged) converted variable name.
  """
  for pattern, replacement in BERT_NAME_REPLACEMENTS:
    if pattern not in var_name:
      continue
    previous_name = var_name
    var_name = var_name.replace(pattern, replacement)
    tf.logging.info("Converted: %s --> %s", previous_name, var_name)
  return var_name
def
_has_exclude_patterns
(
name
,
exclude_patterns
):
"""Checks if a string contains substrings that match patterns to exclude."""
for
p
in
exclude_patterns
:
if
p
in
name
:
return
True
return
False
def convert_names(checkpoint_from_path,
                  checkpoint_to_path,
                  exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Reads every variable from the source checkpoint, renames it via
  _bert_name_replacement, and writes a new checkpoint containing the renamed
  variables. Variables matching any of `exclude_patterns` are dropped.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.

  Returns:
    None. (The conversion result is written to `checkpoint_to_path`; the
    name maps built below are only logged, not returned.)
  """
  # Build a fresh graph so the tf.Variable objects created here don't leak
  # into any default graph the caller may have.
  with tf.Graph().as_default():
    tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    # new name -> tf.Variable holding the original tensor value.
    new_variable_map = {}
    # old name -> new name, only for variables that actually changed.
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
        continue
      new_var_name = _bert_name_replacement(var_name)
      tensor = reader.get_tensor(var_name)
      # The Variable keeps the old graph name; the Saver dict below is what
      # stores it under the new checkpoint key.
      var = tf.Variable(tensor, name=var_name)
      new_variable_map[new_var_name] = var
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    # Saving from a {checkpoint_key: variable} dict writes each variable
    # under its new name.
    saver = tf.train.Saver(new_variable_map)

    with tf.Session() as sess:
      # Variables were constructed from in-memory tensors; initialize them
      # before saving.
      sess.run(tf.global_variables_initializer())
      tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
      saver.save(sess, checkpoint_to_path)

  tf.logging.info("Summary:")
  tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
  tf.logging.info(" Converted: %s", str(conversion_map))
def main(_):
  """Script entry point: splits the exclude list and runs the conversion."""
  raw_patterns = FLAGS.exclude_patterns
  exclude_patterns = raw_patterns.split(",") if raw_patterns else None
  convert_names(FLAGS.checkpoint_from_path, FLAGS.checkpoint_to_path,
                exclude_patterns)
if __name__ == "__main__":
  # Fail at flag parsing, not mid-conversion, when a path is missing.
  flags.mark_flag_as_required("checkpoint_from_path")
  flags.mark_flag_as_required("checkpoint_to_path")
  app.run(main)
official/bert/tools/tf2_checkpoint_converter.py
0 → 100644
View file @
12f9403f
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A converter for BERT name-based checkpoint to object-based checkpoint.
The conversion will yield objected-oriented checkpoint for TF2 Bert models,
when BergConfig.backward_compatible is true.
The variable/tensor shapes matches TF1 BERT model, but backward compatiblity
introduces unnecessary reshape compuation.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
from
official.bert
import
modeling
FLAGS = flags.FLAGS

# JSON config describing the BERT architecture to instantiate.
flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
# TF1 name-based checkpoint to be converted.
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")
# Output path for the TF2 object-based checkpoint.
flags.DEFINE_string("converted_checkpoint", None,
                    "Path to objected-based V2 checkpoint.")
# When true, the checkpoint root is the inner "bert_model" layer instead of
# the whole Keras model (see convert_checkpoint).
flags.DEFINE_bool(
    "export_bert_as_layer", False,
    "Whether to use a layer rather than a model inside the checkpoint.")
def create_bert_model(bert_config):
  """Creates a BERT Keras core model from a BERT configuration.

  Args:
    bert_config: A `BertConfig` to create the core model.

  Returns:
    A Keras model.
  """
  seq_length = bert_config.max_position_embeddings

  def _placeholder_input(name):
    # Symbolic Keras input of shape [batch, seq_length]; acts as a
    # placeholder for the model's integer id/mask tensors.
    return tf.keras.layers.Input(
        shape=(seq_length,), dtype=tf.int32, name=name)

  word_ids = _placeholder_input("input_word_ids")
  mask = _placeholder_input("input_mask")
  type_ids = _placeholder_input("input_type_ids")
  return modeling.get_bert_model(
      word_ids,
      mask,
      type_ids,
      config=bert_config,
      name="bert_model",
      float_type=tf.float32)
def convert_checkpoint():
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  # backward_compatible keeps TF1-compatible variable shapes so the V1
  # name-based checkpoint can be matched against the Keras model.
  bert_config.backward_compatible = True

  core_model = create_bert_model(bert_config)
  # In eager mode, load_weights streaming-restores the V1 name-based
  # checkpoint into the model's variables.
  core_model.load_weights(FLAGS.init_checkpoint)

  # Root the object-based checkpoint either at the inner BERT layer or at
  # the full Keras model, depending on the flag.
  if FLAGS.export_bert_as_layer:
    checkpoint = tf.train.Checkpoint(
        bert_layer=core_model.get_layer("bert_model"))
  else:
    checkpoint = tf.train.Checkpoint(model=core_model)
  checkpoint.save(FLAGS.converted_checkpoint)
def main(_):
  # Eager execution is required for the streaming name-based restore in
  # convert_checkpoint (this script runs under TF 1.x, where eager is opt-in).
  tf.enable_eager_execution()
  convert_checkpoint()
if __name__ == "__main__":
  # Consistency fix: mirror the TF1 converter and validate required path
  # flags at parse time, instead of failing later inside
  # BertConfig.from_json_file / load_weights with an obscure error.
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("init_checkpoint")
  flags.mark_flag_as_required("converted_checkpoint")
  app.run(main)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment