Commit aa870ff4 authored by Yuexin Wu, committed by A. Unique TensorFlower

Fix export_tfhub module with BertV2.

PiperOrigin-RevId: 431236080
parent 34a93745
@@ -84,13 +84,13 @@ def _create_model(
   """Creates the model to export and the model to restore the checkpoint.
 
   Args:
-    bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
-      Exactly one of encoder_config and bert_config must be set.
+    bert_config: A legacy `BertConfig` to create a `BertEncoder` object. Exactly
+      one of encoder_config and bert_config must be set.
     encoder_config: An `EncoderConfig` to create an encoder of the configured
       type (`BertEncoder` or other).
-    with_mlm: A bool to control the second component of the result.
-      If True, will create a `BertPretrainerV2` object; otherwise, will
-      create a `BertEncoder` object.
+    with_mlm: A bool to control the second component of the result. If True,
+      will create a `BertPretrainerV2` object; otherwise, will create a
+      `BertEncoder` object.
 
   Returns:
     A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
@@ -110,7 +110,11 @@ def _create_model(
   # Convert from list of named inputs to dict of inputs keyed by name.
   # Only the latter accepts a dict of inputs after restoring from SavedModel.
-  encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+  if isinstance(encoder.inputs, list) or isinstance(encoder.inputs, tuple):
+    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
+  else:
+    # encoder.inputs by default is dict for BertEncoderV2.
+    encoder_inputs_dict = encoder.inputs
   encoder_output_dict = encoder(encoder_inputs_dict)
   # For interchangeability with other text representations,
   # add "default" as an alias for BERT's whole-input representations.
@@ -206,26 +210,28 @@ def export_model(export_path: Text,
     encoder_config: An optional `encoders.EncoderConfig` object.
     model_checkpoint_path: The path to the checkpoint.
     with_mlm: Whether to export the additional mlm sub-object.
-    copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer
-      used in the next sentence prediction task to the encoder.
+    copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer used
+      in the next sentence prediction task to the encoder.
     vocab_file: The path to the wordpiece vocab file, or None.
-    sp_model_file: The path to the sentencepiece model file, or None.
-      Exactly one of vocab_file and sp_model_file must be set.
+    sp_model_file: The path to the sentencepiece model file, or None. Exactly
+      one of vocab_file and sp_model_file must be set.
     do_lower_case: Whether to lower-case text before tokenization.
   """
   if with_mlm:
-    core_model, pretrainer = _create_model(bert_config=bert_config,
-                                           encoder_config=encoder_config,
-                                           with_mlm=with_mlm)
+    core_model, pretrainer = _create_model(
+        bert_config=bert_config,
+        encoder_config=encoder_config,
+        with_mlm=with_mlm)
     encoder = pretrainer.encoder_network
     # It supports both the new pretrainer checkpoint produced by TF-NLP and
     # the checkpoint converted from TF1 (original BERT, SmallBERTs).
     checkpoint_items = pretrainer.checkpoint_items
     checkpoint = tf.train.Checkpoint(**checkpoint_items)
   else:
-    core_model, encoder = _create_model(bert_config=bert_config,
-                                        encoder_config=encoder_config,
-                                        with_mlm=with_mlm)
+    core_model, encoder = _create_model(
+        bert_config=bert_config,
+        encoder_config=encoder_config,
+        with_mlm=with_mlm)
     checkpoint = tf.train.Checkpoint(
         model=encoder,  # Legacy checkpoints.
         encoder=encoder)
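For reference, a hedged sketch of how the two `tf.train.Checkpoint` layouts built above restore weights; `pretrainer`, `encoder`, and `model_checkpoint_path` refer to the names in this function:

```python
import tensorflow as tf

# With an MLM head: checkpoint_items maps names such as "encoder" and
# "masked_lm" to trackable objects, so ** expansion registers them all.
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)

# Without an MLM head: the encoder is registered twice so that both
# legacy checkpoints (saved as `model=...`) and current ones
# (saved as `encoder=...`) restore cleanly.
checkpoint = tf.train.Checkpoint(model=encoder, encoder=encoder)

status = checkpoint.restore(model_checkpoint_path)
status.assert_existing_objects_matched()  # tolerate extra checkpoint keys
```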
@@ -279,21 +285,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
     # overridable. Having this dynamically determined default argument
     # requires self.__call__ to be defined in this indirect way.
     default_seq_length = bert_pack_inputs.seq_length
 
     @tf.function(autograph=False)
     def call(inputs, seq_length=default_seq_length):
       return layers.BertPackInputs.bert_pack_inputs(
-          inputs, seq_length=seq_length,
+          inputs,
+          seq_length=seq_length,
           start_of_sequence_id=bert_pack_inputs.start_of_sequence_id,
           end_of_segment_id=bert_pack_inputs.end_of_segment_id,
           padding_id=bert_pack_inputs.padding_id)
 
     self.__call__ = call
 
     for ragged_rank in range(1, 3):
       for num_segments in range(1, 3):
-        _ = self.__call__.get_concrete_function(
-            [tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
-             for _ in range(num_segments)],
-            seq_length=tf.TensorSpec([], tf.int32))
+        _ = self.__call__.get_concrete_function([
+            tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
+            for _ in range(num_segments)
+        ],
+                                                seq_length=tf.TensorSpec(
+                                                    [], tf.int32))
 
 
 def create_preprocessing(*,
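The reformatted loop pre-traces concrete functions so the exported callable accepts one or two `RaggedTensor` segments of ragged rank 1 or 2 without retracing after the SavedModel is loaded. A hypothetical call, assuming `packer` is an instance of `BertPackInputsSavedModelWrapper`:

```python
import tensorflow as tf

# One ragged segment of token ids (ragged_rank=1, shape [None, None]).
segment = tf.ragged.constant([[101, 2023, 102], [101, 102]], dtype=tf.int32)
packed = packer([segment], seq_length=tf.constant(128, dtype=tf.int32))
# `packed` is a dict of dense Tensors: input_word_ids, input_mask,
# input_type_ids, each of shape [2, 128].
```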
@@ -311,14 +322,14 @@ def create_preprocessing(*,
   Args:
     vocab_file: The path to the wordpiece vocab file, or None.
-    sp_model_file: The path to the sentencepiece model file, or None.
-      Exactly one of vocab_file and sp_model_file must be set.
-      This determines the type of tokenizer that is used.
+    sp_model_file: The path to the sentencepiece model file, or None. Exactly
+      one of vocab_file and sp_model_file must be set. This determines the type
+      of tokenizer that is used.
     do_lower_case: Whether to do lower case.
     tokenize_with_offsets: Whether to include the .tokenize_with_offsets
       subobject.
-    default_seq_length: The sequence length of preprocessing results from
-      root callable. This is also the default sequence length for the
+    default_seq_length: The sequence length of preprocessing results from root
+      callable. This is also the default sequence length for the
       bert_pack_inputs subobject.
 
   Returns:
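For context, a usage sketch of the preprocessing SavedModel that `create_preprocessing` feeds into the export; the handle below is a placeholder, not a path from this commit:

```python
import tensorflow as tf
import tensorflow_hub as hub

preprocess = hub.load("/tmp/bert_preprocess")  # placeholder export path
sentences = tf.constant(["hello world", "the quick brown fox"])

# The root callable packs to default_seq_length.
encoder_inputs = preprocess(sentences)

# Tokenize once, then repack to a custom length via the subobject.
tokens = preprocess.tokenize(sentences)
custom = preprocess.bert_pack_inputs([tokens],
                                     seq_length=tf.constant(64, tf.int32))
```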
@@ -378,7 +389,8 @@ def create_preprocessing(*,
 def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]:
   """Returns new path with same basename and hash of original path."""
-  if file_path is None: return None
+  if file_path is None:
+    return None
   olddir, filename = os.path.split(file_path)
   hasher = hashlib.sha1()
   hasher.update(olddir.encode("utf-8"))
@@ -460,12 +472,17 @@ def _check_no_assert(saved_model_path):
   assert_nodes = []
   graph_def = saved_model.meta_graphs[0].graph_def
-  assert_nodes += ["node '{}' in global graph".format(n.name)
-                   for n in graph_def.node if n.op == "Assert"]
+  assert_nodes += [
+      "node '{}' in global graph".format(n.name)
+      for n in graph_def.node
+      if n.op == "Assert"
+  ]
   for fdef in graph_def.library.function:
     assert_nodes += [
         "node '{}' in function '{}'".format(n.name, fdef.signature.name)
-        for n in fdef.node_def if n.op == "Assert"]
+        for n in fdef.node_def
+        if n.op == "Assert"
+    ]
   if assert_nodes:
     raise AssertionError(
         "Internal tool error: "