Commit d0393056 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

Let BertTokenizer avoid the bad asset path after repeated save/load

This works around the issue reported in https://github.com/tensorflow/tensorflow/issues/46293.

A test case is added to export_tfhub_lib_test.

PiperOrigin-RevId: 351741572
parent 80993c41
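
For context, a minimal sketch of the failure mode being worked around, assuming the layer is imported from this repository's modeling library; the /tmp paths and vocab file are placeholders:

import tensorflow as tf
from official.nlp.modeling import layers  # import path assumed for illustration

# "/tmp/vocab.txt" stands in for any WordPiece vocab file, one token per line.
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
tokenize = layers.BertTokenizer(vocab_file="/tmp/vocab.txt", lower_case=True)
model = tf.keras.Model(inputs, tokenize(inputs))
tf.saved_model.save(model, "/tmp/export1")

# A TF Hub user loads the export and saves it again as part of a final model.
# Before this change, the reloaded vocab table had lost its trackable link to
# its initializer, so the second save could record a bad vocab asset path.
restored = tf.saved_model.load("/tmp/export1")
tf.saved_model.save(restored, "/tmp/export2")
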
@@ -133,7 +133,16 @@ class BertTokenizer(tf.keras.layers.Layer):
     _check_if_tf_text_installed()
     self.tokenize_with_offsets = tokenize_with_offsets
-    self._vocab_table = self._create_vocab_table(vocab_file)
+    # TODO(b/177326279): Stop storing the vocab table initializer as an
+    # attribute when https://github.com/tensorflow/tensorflow/issues/46293
+    # has been fixed in the TensorFlow versions of the TF Hub users that load
+    # a SavedModel created from this layer. Due to that issue, loading such a
+    # SavedModel forgets to add .vocab_table._initializer as a trackable
+    # dependency of .vocab_table, so that saving it again to a second SavedModel
+    # (e.g., the final model built using TF Hub) does not properly track
+    # the ._vocab_table._initializer._filename as an Asset.
+    self._vocab_table, self._vocab_initializer_donotuse = (
+        self._create_vocab_table_and_initializer(vocab_file))
     self._special_tokens_dict = self._create_special_tokens_dict(
         self._vocab_table, vocab_file)
     super().__init__(**kwargs)
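
The workaround leans on TensorFlow's object-based tracking: any trackable assigned to an attribute becomes a dependency that tf.saved_model.save walks. A minimal sketch of the same pattern in isolation (the class and attribute names are illustrative):

import tensorflow as tf

class VocabModule(tf.Module):

  def __init__(self, vocab_file):
    initializer = tf.lookup.TextFileInitializer(
        vocab_file,
        key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
    self.table = tf.lookup.StaticHashTable(initializer, default_value=-1)
    # Holding a direct reference keeps the initializer (and the filename it
    # tracks as an Asset) reachable from this object, even if a restored
    # self.table has dropped its own dependency on the initializer.
    self._initializer_donotuse = initializer
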
@@ -144,12 +153,13 @@ class BertTokenizer(tf.keras.layers.Layer):
   def vocab_size(self):
     return self._vocab_table.size()
 
-  def _create_vocab_table(self, vocab_file):
+  def _create_vocab_table_and_initializer(self, vocab_file):
     vocab_initializer = tf.lookup.TextFileInitializer(
         vocab_file,
         key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
         value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
-    return tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
+    vocab_table = tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
+    return vocab_table, vocab_initializer
 
   def call(self, inputs: tf.Tensor):
     """Calls text.BertTokenizer on inputs.
@@ -230,7 +240,8 @@ class BertTokenizer(tf.keras.layers.Layer):
           "Non-eager init context; computing "
           "BertTokenizer's special_tokens_dict in tf.compat.v1.Session")
       with tf.Graph().as_default():
-        local_vocab_table = self._create_vocab_table(vocab_file)
+        local_vocab_table, _ = self._create_vocab_table_and_initializer(
+            vocab_file)
         special_token_ids_tensor = local_vocab_table.lookup(
             tf.constant(list(special_tokens.values()), tf.string))
         init_ops = [tf.compat.v1.initialize_all_tables()]
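
In this non-eager branch, lookup tables are not initialized automatically; the session has to run the table initializers before any lookup. A stripped-down sketch of the pattern used above (the vocab path is a placeholder):

import tensorflow as tf

with tf.Graph().as_default():
  initializer = tf.lookup.TextFileInitializer(
      "/tmp/vocab.txt",
      key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
      value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  table = tf.lookup.StaticHashTable(initializer, default_value=-1)
  ids = table.lookup(tf.constant(["[CLS]", "[SEP]"], tf.string))
  with tf.compat.v1.Session() as sess:
    # initialize_all_tables() is the deprecated alias the layer itself uses;
    # tf.compat.v1.tables_initializer() is the current equivalent.
    sess.run(tf.compat.v1.initialize_all_tables())
    print(sess.run(ids))
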