Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
31711f66
Commit
31711f66
authored
Apr 15, 2020
by
Elizabeth Kemp
Committed by
A. Unique TensorFlower
Apr 15, 2020
Browse files
Add flag for do_lower_case in export_tfhub.py
PiperOrigin-RevId: 306669872
parent
89599a23
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
official/nlp/bert/export_tfhub.py
official/nlp/bert/export_tfhub.py
+13
-4
No files found.
official/nlp/bert/export_tfhub.py
View file @
31711f66
...
...
@@ -20,6 +20,7 @@ from __future__ import print_function
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
from
typing
import
Text
from
official.nlp.bert
import
bert_models
...
...
@@ -34,6 +35,9 @@ flags.DEFINE_string("model_checkpoint_path", None,
flags
.
DEFINE_string
(
"export_path"
,
None
,
"TF-Hub SavedModel destination path."
)
flags
.
DEFINE_string
(
"vocab_file"
,
None
,
"The vocabulary file that the BERT model was trained on."
)
flags
.
DEFINE_bool
(
"do_lower_case"
,
None
,
"Whether to lowercase. If None, "
"do_lower_case will be enabled if 'uncased' appears in the "
"name of --vocab_file"
)
def
create_bert_model
(
bert_config
:
configs
.
BertConfig
)
->
tf
.
keras
.
Model
:
...
...
@@ -65,21 +69,26 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
def
export_bert_tfhub
(
bert_config
:
configs
.
BertConfig
,
model_checkpoint_path
:
Text
,
hub_destination
:
Text
,
vocab_file
:
Text
):
vocab_file
:
Text
,
do_lower_case
:
bool
=
None
):
"""Restores a tf.keras.Model and saves for TF-Hub."""
# If do_lower_case is not explicit, default to checking whether "uncased" is
# in the vocab file name
if
do_lower_case
is
None
:
do_lower_case
=
"uncased"
in
vocab_file
logging
.
info
(
"Using do_lower_case=%s based on name of vocab_file=%s"
,
do_lower_case
,
vocab_file
)
core_model
,
encoder
=
create_bert_model
(
bert_config
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
encoder
)
checkpoint
.
restore
(
model_checkpoint_path
).
assert_consumed
()
core_model
.
vocab_file
=
tf
.
saved_model
.
Asset
(
vocab_file
)
core_model
.
do_lower_case
=
tf
.
Variable
(
"uncased"
in
vocab_file
,
trainable
=
False
)
core_model
.
do_lower_case
=
tf
.
Variable
(
do_lower_case
,
trainable
=
False
)
core_model
.
save
(
hub_destination
,
include_optimizer
=
False
,
save_format
=
"tf"
)
def
main
(
_
):
bert_config
=
configs
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
export_bert_tfhub
(
bert_config
,
FLAGS
.
model_checkpoint_path
,
FLAGS
.
export_path
,
FLAGS
.
vocab_file
)
FLAGS
.
vocab_file
,
FLAGS
.
do_lower_case
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment