Commit aa870ff4 authored by Yuexin Wu, committed by A. Unique TensorFlower
Browse files

Fix export_tfhub module with BertV2.

PiperOrigin-RevId: 431236080
parent 34a93745
...@@ -84,13 +84,13 @@ def _create_model( ...@@ -84,13 +84,13 @@ def _create_model(
"""Creates the model to export and the model to restore the checkpoint. """Creates the model to export and the model to restore the checkpoint.
Args: Args:
bert_config: A legacy `BertConfig` to create a `BertEncoder` object. bert_config: A legacy `BertConfig` to create a `BertEncoder` object. Exactly
Exactly one of encoder_config and bert_config must be set. one of encoder_config and bert_config must be set.
encoder_config: An `EncoderConfig` to create an encoder of the configured encoder_config: An `EncoderConfig` to create an encoder of the configured
type (`BertEncoder` or other). type (`BertEncoder` or other).
with_mlm: A bool to control the second component of the result. with_mlm: A bool to control the second component of the result. If True,
If True, will create a `BertPretrainerV2` object; otherwise, will will create a `BertPretrainerV2` object; otherwise, will create a
create a `BertEncoder` object. `BertEncoder` object.
Returns: Returns:
A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2` A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
...@@ -110,7 +110,11 @@ def _create_model( ...@@ -110,7 +110,11 @@ def _create_model(
# Convert from list of named inputs to dict of inputs keyed by name. # Convert from list of named inputs to dict of inputs keyed by name.
# Only the latter accepts a dict of inputs after restoring from SavedModel. # Only the latter accepts a dict of inputs after restoring from SavedModel.
if isinstance(encoder.inputs, list) or isinstance(encoder.inputs, tuple):
encoder_inputs_dict = {x.name: x for x in encoder.inputs} encoder_inputs_dict = {x.name: x for x in encoder.inputs}
else:
# encoder.inputs by default is dict for BertEncoderV2.
encoder_inputs_dict = encoder.inputs
encoder_output_dict = encoder(encoder_inputs_dict) encoder_output_dict = encoder(encoder_inputs_dict)
# For interchangeability with other text representations, # For interchangeability with other text representations,
# add "default" as an alias for BERT's whole-input representations. # add "default" as an alias for BERT's whole-input representations.
...@@ -206,15 +210,16 @@ def export_model(export_path: Text, ...@@ -206,15 +210,16 @@ def export_model(export_path: Text,
encoder_config: An optional `encoders.EncoderConfig` object. encoder_config: An optional `encoders.EncoderConfig` object.
model_checkpoint_path: The path to the checkpoint. model_checkpoint_path: The path to the checkpoint.
with_mlm: Whether to export the additional mlm sub-object. with_mlm: Whether to export the additional mlm sub-object.
copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer used
used in the next sentence prediction task to the encoder. in the next sentence prediction task to the encoder.
vocab_file: The path to the wordpiece vocab file, or None. vocab_file: The path to the wordpiece vocab file, or None.
sp_model_file: The path to the sentencepiece model file, or None. sp_model_file: The path to the sentencepiece model file, or None. Exactly
Exactly one of vocab_file and sp_model_file must be set. one of vocab_file and sp_model_file must be set.
do_lower_case: Whether to lower-case text before tokenization. do_lower_case: Whether to lower-case text before tokenization.
""" """
if with_mlm: if with_mlm:
core_model, pretrainer = _create_model(bert_config=bert_config, core_model, pretrainer = _create_model(
bert_config=bert_config,
encoder_config=encoder_config, encoder_config=encoder_config,
with_mlm=with_mlm) with_mlm=with_mlm)
encoder = pretrainer.encoder_network encoder = pretrainer.encoder_network
...@@ -223,7 +228,8 @@ def export_model(export_path: Text, ...@@ -223,7 +228,8 @@ def export_model(export_path: Text,
checkpoint_items = pretrainer.checkpoint_items checkpoint_items = pretrainer.checkpoint_items
checkpoint = tf.train.Checkpoint(**checkpoint_items) checkpoint = tf.train.Checkpoint(**checkpoint_items)
else: else:
core_model, encoder = _create_model(bert_config=bert_config, core_model, encoder = _create_model(
bert_config=bert_config,
encoder_config=encoder_config, encoder_config=encoder_config,
with_mlm=with_mlm) with_mlm=with_mlm)
checkpoint = tf.train.Checkpoint( checkpoint = tf.train.Checkpoint(
...@@ -279,21 +285,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint): ...@@ -279,21 +285,26 @@ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
# overridable. Having this dynamically determined default argument # overridable. Having this dynamically determined default argument
# requires self.__call__ to be defined in this indirect way. # requires self.__call__ to be defined in this indirect way.
default_seq_length = bert_pack_inputs.seq_length default_seq_length = bert_pack_inputs.seq_length
@tf.function(autograph=False) @tf.function(autograph=False)
def call(inputs, seq_length=default_seq_length): def call(inputs, seq_length=default_seq_length):
return layers.BertPackInputs.bert_pack_inputs( return layers.BertPackInputs.bert_pack_inputs(
inputs, seq_length=seq_length, inputs,
seq_length=seq_length,
start_of_sequence_id=bert_pack_inputs.start_of_sequence_id, start_of_sequence_id=bert_pack_inputs.start_of_sequence_id,
end_of_segment_id=bert_pack_inputs.end_of_segment_id, end_of_segment_id=bert_pack_inputs.end_of_segment_id,
padding_id=bert_pack_inputs.padding_id) padding_id=bert_pack_inputs.padding_id)
self.__call__ = call self.__call__ = call
for ragged_rank in range(1, 3): for ragged_rank in range(1, 3):
for num_segments in range(1, 3): for num_segments in range(1, 3):
_ = self.__call__.get_concrete_function( _ = self.__call__.get_concrete_function([
[tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32) tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
for _ in range(num_segments)], for _ in range(num_segments)
seq_length=tf.TensorSpec([], tf.int32)) ],
seq_length=tf.TensorSpec(
[], tf.int32))
def create_preprocessing(*, def create_preprocessing(*,
...@@ -311,14 +322,14 @@ def create_preprocessing(*, ...@@ -311,14 +322,14 @@ def create_preprocessing(*,
Args: Args:
vocab_file: The path to the wordpiece vocab file, or None. vocab_file: The path to the wordpiece vocab file, or None.
sp_model_file: The path to the sentencepiece model file, or None. sp_model_file: The path to the sentencepiece model file, or None. Exactly
Exactly one of vocab_file and sp_model_file must be set. one of vocab_file and sp_model_file must be set. This determines the type
This determines the type of tokenizer that is used. of tokenizer that is used.
do_lower_case: Whether to do lower case. do_lower_case: Whether to do lower case.
tokenize_with_offsets: Whether to include the .tokenize_with_offsets tokenize_with_offsets: Whether to include the .tokenize_with_offsets
subobject. subobject.
default_seq_length: The sequence length of preprocessing results from default_seq_length: The sequence length of preprocessing results from root
root callable. This is also the default sequence length for the callable. This is also the default sequence length for the
bert_pack_inputs subobject. bert_pack_inputs subobject.
Returns: Returns:
...@@ -378,7 +389,8 @@ def create_preprocessing(*, ...@@ -378,7 +389,8 @@ def create_preprocessing(*,
def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]: def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]:
"""Returns new path with same basename and hash of original path.""" """Returns new path with same basename and hash of original path."""
if file_path is None: return None if file_path is None:
return None
olddir, filename = os.path.split(file_path) olddir, filename = os.path.split(file_path)
hasher = hashlib.sha1() hasher = hashlib.sha1()
hasher.update(olddir.encode("utf-8")) hasher.update(olddir.encode("utf-8"))
...@@ -460,12 +472,17 @@ def _check_no_assert(saved_model_path): ...@@ -460,12 +472,17 @@ def _check_no_assert(saved_model_path):
assert_nodes = [] assert_nodes = []
graph_def = saved_model.meta_graphs[0].graph_def graph_def = saved_model.meta_graphs[0].graph_def
assert_nodes += ["node '{}' in global graph".format(n.name) assert_nodes += [
for n in graph_def.node if n.op == "Assert"] "node '{}' in global graph".format(n.name)
for n in graph_def.node
if n.op == "Assert"
]
for fdef in graph_def.library.function: for fdef in graph_def.library.function:
assert_nodes += [ assert_nodes += [
"node '{}' in function '{}'".format(n.name, fdef.signature.name) "node '{}' in function '{}'".format(n.name, fdef.signature.name)
for n in fdef.node_def if n.op == "Assert"] for n in fdef.node_def
if n.op == "Assert"
]
if assert_nodes: if assert_nodes:
raise AssertionError( raise AssertionError(
"Internal tool error: " "Internal tool error: "
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment