Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aa870ff4
Commit
aa870ff4
authored
Feb 27, 2022
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Feb 27, 2022
Browse files
Fix export_tfhub module with BertV2.
PiperOrigin-RevId: 431236080
parent
34a93745
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
303 additions
and
199 deletions
+303
-199
official/nlp/tools/export_tfhub_lib.py
official/nlp/tools/export_tfhub_lib.py
+47
-30
official/nlp/tools/export_tfhub_lib_test.py
official/nlp/tools/export_tfhub_lib_test.py
+256
-169
No files found.
official/nlp/tools/export_tfhub_lib.py
View file @
aa870ff4
...
...
@@ -84,13 +84,13 @@ def _create_model(
"""Creates the model to export and the model to restore the checkpoint.
Args:
bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
Exactly
one of encoder_config and bert_config must be set.
bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
Exactly
one of encoder_config and bert_config must be set.
encoder_config: An `EncoderConfig` to create an encoder of the configured
type (`BertEncoder` or other).
with_mlm: A bool to control the second component of the result.
If True,
will create a `BertPretrainerV2` object; otherwise, will
create a
`BertEncoder` object.
with_mlm: A bool to control the second component of the result.
If True,
will create a `BertPretrainerV2` object; otherwise, will
create a
`BertEncoder` object.
Returns:
A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
...
...
@@ -110,7 +110,11 @@ def _create_model(
# Convert from list of named inputs to dict of inputs keyed by name.
# Only the latter accepts a dict of inputs after restoring from SavedModel.
if
isinstance
(
encoder
.
inputs
,
list
)
or
isinstance
(
encoder
.
inputs
,
tuple
):
encoder_inputs_dict
=
{
x
.
name
:
x
for
x
in
encoder
.
inputs
}
else
:
# encoder.inputs by default is dict for BertEncoderV2.
encoder_inputs_dict
=
encoder
.
inputs
encoder_output_dict
=
encoder
(
encoder_inputs_dict
)
# For interchangeability with other text representations,
# add "default" as an alias for BERT's whole-input reptesentations.
...
...
@@ -206,15 +210,16 @@ def export_model(export_path: Text,
encoder_config: An optional `encoders.EncoderConfig` object.
model_checkpoint_path: The path to the checkpoint.
with_mlm: Whether to export the additional mlm sub-object.
copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer
used
in the next sentence prediction task to the encoder.
copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer
used
in the next sentence prediction task to the encoder.
vocab_file: The path to the wordpiece vocab file, or None.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
do_lower_case: Whether to lower-case text before tokenization.
"""
if
with_mlm
:
core_model
,
pretrainer
=
_create_model
(
bert_config
=
bert_config
,
core_model
,
pretrainer
=
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
with_mlm
)
encoder
=
pretrainer
.
encoder_network
...
...
@@ -223,7 +228,8 @@ def export_model(export_path: Text,
checkpoint_items
=
pretrainer
.
checkpoint_items
checkpoint
=
tf
.
train
.
Checkpoint
(
**
checkpoint_items
)
else
:
core_model
,
encoder
=
_create_model
(
bert_config
=
bert_config
,
core_model
,
encoder
=
_create_model
(
bert_config
=
bert_config
,
encoder_config
=
encoder_config
,
with_mlm
=
with_mlm
)
checkpoint
=
tf
.
train
.
Checkpoint
(
...
...
@@ -279,21 +285,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
# overridable. Having this dynamically determined default argument
# requires self.__call__ to be defined in this indirect way.
default_seq_length
=
bert_pack_inputs
.
seq_length
@
tf
.
function
(
autograph
=
False
)
def
call
(
inputs
,
seq_length
=
default_seq_length
):
return
layers
.
BertPackInputs
.
bert_pack_inputs
(
inputs
,
seq_length
=
seq_length
,
inputs
,
seq_length
=
seq_length
,
start_of_sequence_id
=
bert_pack_inputs
.
start_of_sequence_id
,
end_of_segment_id
=
bert_pack_inputs
.
end_of_segment_id
,
padding_id
=
bert_pack_inputs
.
padding_id
)
self
.
__call__
=
call
for
ragged_rank
in
range
(
1
,
3
):
for
num_segments
in
range
(
1
,
3
):
_
=
self
.
__call__
.
get_concrete_function
(
[
tf
.
RaggedTensorSpec
([
None
]
*
(
ragged_rank
+
1
),
dtype
=
tf
.
int32
)
for
_
in
range
(
num_segments
)],
seq_length
=
tf
.
TensorSpec
([],
tf
.
int32
))
_
=
self
.
__call__
.
get_concrete_function
([
tf
.
RaggedTensorSpec
([
None
]
*
(
ragged_rank
+
1
),
dtype
=
tf
.
int32
)
for
_
in
range
(
num_segments
)
],
seq_length
=
tf
.
TensorSpec
(
[],
tf
.
int32
))
def
create_preprocessing
(
*
,
...
...
@@ -311,14 +322,14 @@ def create_preprocessing(*,
Args:
vocab_file: The path to the wordpiece vocab file, or None.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
This determines the type
of tokenzer that is used.
sp_model_file: The path to the sentencepiece model file, or None.
Exactly
one of vocab_file and sp_model_file must be set.
This determines the type
of tokenzer that is used.
do_lower_case: Whether to do lower case.
tokenize_with_offsets: Whether to include the .tokenize_with_offsets
subobject.
default_seq_length: The sequence length of preprocessing results from
root
callable. This is also the default sequence length for the
default_seq_length: The sequence length of preprocessing results from
root
callable. This is also the default sequence length for the
bert_pack_inputs subobject.
Returns:
...
...
@@ -378,7 +389,8 @@ def create_preprocessing(*,
def
_move_to_tmpdir
(
file_path
:
Optional
[
Text
],
tmpdir
:
Text
)
->
Optional
[
Text
]:
"""Returns new path with same basename and hash of original path."""
if
file_path
is
None
:
return
None
if
file_path
is
None
:
return
None
olddir
,
filename
=
os
.
path
.
split
(
file_path
)
hasher
=
hashlib
.
sha1
()
hasher
.
update
(
olddir
.
encode
(
"utf-8"
))
...
...
@@ -460,12 +472,17 @@ def _check_no_assert(saved_model_path):
assert_nodes
=
[]
graph_def
=
saved_model
.
meta_graphs
[
0
].
graph_def
assert_nodes
+=
[
"node '{}' in global graph"
.
format
(
n
.
name
)
for
n
in
graph_def
.
node
if
n
.
op
==
"Assert"
]
assert_nodes
+=
[
"node '{}' in global graph"
.
format
(
n
.
name
)
for
n
in
graph_def
.
node
if
n
.
op
==
"Assert"
]
for
fdef
in
graph_def
.
library
.
function
:
assert_nodes
+=
[
"node '{}' in function '{}'"
.
format
(
n
.
name
,
fdef
.
signature
.
name
)
for
n
in
fdef
.
node_def
if
n
.
op
==
"Assert"
]
for
n
in
fdef
.
node_def
if
n
.
op
==
"Assert"
]
if
assert_nodes
:
raise
AssertionError
(
"Internal tool error: "
...
...
official/nlp/tools/export_tfhub_lib_test.py
View file @
aa870ff4
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment