Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4085c19a
Commit
4085c19a
authored
Oct 26, 2020
by
Elizabeth Kemp
Committed by
A. Unique TensorFlower
Oct 26, 2020
Browse files
Add support for SQuAD BERT export
PiperOrigin-RevId: 339018438
parent
03ae8d2d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
9 deletions
+42
-9
official/nlp/bert/export_tfhub.py
official/nlp/bert/export_tfhub.py
+42
-9
No files found.
official/nlp/bert/export_tfhub.py
View file @
4085c19a
...
@@ -36,9 +36,12 @@ flags.DEFINE_string("model_checkpoint_path", None,
...
@@ -36,9 +36,12 @@ flags.DEFINE_string("model_checkpoint_path", None,
flags
.
DEFINE_string
(
"export_path"
,
None
,
"TF-Hub SavedModel destination path."
)
flags
.
DEFINE_string
(
"export_path"
,
None
,
"TF-Hub SavedModel destination path."
)
flags
.
DEFINE_string
(
"vocab_file"
,
None
,
flags
.
DEFINE_string
(
"vocab_file"
,
None
,
"The vocabulary file that the BERT model was trained on."
)
"The vocabulary file that the BERT model was trained on."
)
flags
.
DEFINE_bool
(
"do_lower_case"
,
None
,
"Whether to lowercase. If None, "
flags
.
DEFINE_bool
(
"do_lower_case"
,
None
,
"Whether to lowercase. If None, "
"do_lower_case will be enabled if 'uncased' appears in the "
"do_lower_case will be enabled if 'uncased' appears in the "
"name of --vocab_file"
)
"name of --vocab_file"
)
flags
.
DEFINE_enum
(
"model_type"
,
"encoder"
,
[
"encoder"
,
"squad"
],
"What kind of BERT model to export."
)
def
create_bert_model
(
bert_config
:
configs
.
BertConfig
)
->
tf
.
keras
.
Model
:
def
create_bert_model
(
bert_config
:
configs
.
BertConfig
)
->
tf
.
keras
.
Model
:
...
@@ -69,8 +72,10 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
...
@@ -69,8 +72,10 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
def
export_bert_tfhub
(
bert_config
:
configs
.
BertConfig
,
def
export_bert_tfhub
(
bert_config
:
configs
.
BertConfig
,
model_checkpoint_path
:
Text
,
hub_destination
:
Text
,
model_checkpoint_path
:
Text
,
vocab_file
:
Text
,
do_lower_case
:
bool
=
None
):
hub_destination
:
Text
,
vocab_file
:
Text
,
do_lower_case
:
bool
=
None
):
"""Restores a tf.keras.Model and saves for TF-Hub."""
"""Restores a tf.keras.Model and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
# in the vocab file name
...
@@ -79,7 +84,8 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
...
@@ -79,7 +84,8 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
logging
.
info
(
"Using do_lower_case=%s based on name of vocab_file=%s"
,
logging
.
info
(
"Using do_lower_case=%s based on name of vocab_file=%s"
,
do_lower_case
,
vocab_file
)
do_lower_case
,
vocab_file
)
core_model
,
encoder
=
create_bert_model
(
bert_config
)
core_model
,
encoder
=
create_bert_model
(
bert_config
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
encoder
,
# Legacy checkpoints.
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
encoder
,
# Legacy checkpoints.
encoder
=
encoder
)
encoder
=
encoder
)
checkpoint
.
restore
(
model_checkpoint_path
).
assert_existing_objects_matched
()
checkpoint
.
restore
(
model_checkpoint_path
).
assert_existing_objects_matched
()
core_model
.
vocab_file
=
tf
.
saved_model
.
Asset
(
vocab_file
)
core_model
.
vocab_file
=
tf
.
saved_model
.
Asset
(
vocab_file
)
...
@@ -87,10 +93,37 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
...
@@ -87,10 +93,37 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
core_model
.
save
(
hub_destination
,
include_optimizer
=
False
,
save_format
=
"tf"
)
core_model
.
save
(
hub_destination
,
include_optimizer
=
False
,
save_format
=
"tf"
)
def
export_bert_squad_tfhub
(
bert_config
:
configs
.
BertConfig
,
model_checkpoint_path
:
Text
,
hub_destination
:
Text
,
vocab_file
:
Text
,
do_lower_case
:
bool
=
None
):
"""Restores a tf.keras.Model for BERT with SQuAD and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if
do_lower_case
is
None
:
do_lower_case
=
"uncased"
in
vocab_file
logging
.
info
(
"Using do_lower_case=%s based on name of vocab_file=%s"
,
do_lower_case
,
vocab_file
)
span_labeling
,
_
=
bert_models
.
squad_model
(
bert_config
,
max_seq_length
=
None
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
span_labeling
)
checkpoint
.
restore
(
model_checkpoint_path
).
assert_existing_objects_matched
()
span_labeling
.
vocab_file
=
tf
.
saved_model
.
Asset
(
vocab_file
)
span_labeling
.
do_lower_case
=
tf
.
Variable
(
do_lower_case
,
trainable
=
False
)
span_labeling
.
save
(
hub_destination
,
include_optimizer
=
False
,
save_format
=
"tf"
)
def
main
(
_
):
def
main
(
_
):
bert_config
=
configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
bert_config
=
configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
export_bert_tfhub
(
bert_config
,
FLAGS
.
model_checkpoint_path
,
FLAGS
.
export_path
,
if
FLAGS
.
model_type
==
"encoder"
:
FLAGS
.
vocab_file
,
FLAGS
.
do_lower_case
)
export_bert_tfhub
(
bert_config
,
FLAGS
.
model_checkpoint_path
,
FLAGS
.
export_path
,
FLAGS
.
vocab_file
,
FLAGS
.
do_lower_case
)
elif
FLAGS
.
model_type
==
"squad"
:
export_bert_squad_tfhub
(
bert_config
,
FLAGS
.
model_checkpoint_path
,
FLAGS
.
export_path
,
FLAGS
.
vocab_file
,
FLAGS
.
do_lower_case
)
else
:
raise
ValueError
(
"Unsupported model_type %s."
%
FLAGS
.
model_type
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment