"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "fa4f442952955acf8fe9fcfb98b600f6ca6081b6"
Commit 1fb4c559 authored by A. Unique TensorFlower

Add vocab_size to Tokenizer.get_special_tokens_dict()

PiperOrigin-RevId: 351758632
parent d0393056
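
Note (not part of the commit): get_special_tokens_dict() now reports the tokenizer's vocab size alongside the special token ids. A minimal usage sketch, where `tokenizer` is assumed to be an already-constructed BertTokenizer or SentencepieceTokenizer layer as modified below, and the embedding width is hypothetical:

  import tensorflow as tf

  # `tokenizer` is assumed to be a BertTokenizer or SentencepieceTokenizer
  # layer from this module.
  special = tokenizer.get_special_tokens_dict()
  pad_id = special["padding_id"]
  vocab_size = special["vocab_size"]  # New key added by this commit.

  # One plausible use: sizing an embedding over the tokenizer's id space
  # without re-reading the vocab file.
  embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=128)
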
@@ -221,6 +221,7 @@ class BertTokenizer(tf.keras.layers.Layer):
       * end_of_segment_id: looked up from "[SEP]"
       * padding_id: looked up from "[PAD]"
       * mask_id: looked up from "[MASK]"
+      * vocab_size: one past the largest token id used
     """
     return self._special_tokens_dict
@@ -233,6 +234,7 @@ class BertTokenizer(tf.keras.layers.Layer):
       if tf.executing_eagerly():
         special_token_ids = vocab_table.lookup(
             tf.constant(list(special_tokens.values()), tf.string))
+        vocab_size = vocab_table.size()
       else:
         # A blast from the past: non-eager init context while building Model.
         # This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
@@ -244,13 +246,17 @@ class BertTokenizer(tf.keras.layers.Layer):
             vocab_file)
         special_token_ids_tensor = local_vocab_table.lookup(
             tf.constant(list(special_tokens.values()), tf.string))
+        vocab_size_tensor = local_vocab_table.size()
         init_ops = [tf.compat.v1.initialize_all_tables()]
         with tf.compat.v1.Session() as sess:
           sess.run(init_ops)
-          special_token_ids = sess.run(special_token_ids_tensor)
-      result = dict()
+          special_token_ids, vocab_size = sess.run(
+              [special_token_ids_tensor, vocab_size_tensor])
+      result = dict(
+          vocab_size=int(vocab_size)  # Numpy to Python.
+      )
       for k, v in zip(special_tokens, special_token_ids):
-        v = int(v)  # Numpy to Python.
+        v = int(v)
         if v >= 0:
           result[k] = v
         else:
@@ -425,6 +431,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
       * end_of_segment_id: looked up from "[SEP]"
       * padding_id: looked up from "<pad>"
       * mask_id: looked up from "[MASK]"
+      * vocab_size: one past the largest token id used
     """
     return self._special_tokens_dict
@@ -439,6 +446,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
         special_token_ids = self._tokenizer.string_to_id(
             tf.constant(list(special_tokens.values()), tf.string))
         inverse_tokens = self._tokenizer.id_to_string(special_token_ids)
+        vocab_size = self._tokenizer.vocab_size()
       else:
         # A blast from the past: non-eager init context while building Model.
         # This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
@@ -451,15 +459,19 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
             tf.constant(list(special_tokens.values()), tf.string))
         inverse_tokens_tensor = local_tokenizer.id_to_string(
             special_token_ids_tensor)
+        vocab_size_tensor = local_tokenizer.vocab_size()
         with tf.compat.v1.Session() as sess:
-          special_token_ids, inverse_tokens = sess.run(
-              [special_token_ids_tensor, inverse_tokens_tensor])
-      result = dict()
+          special_token_ids, inverse_tokens, vocab_size = sess.run(
+              [special_token_ids_tensor, inverse_tokens_tensor,
+               vocab_size_tensor])
+      result = dict(
+          vocab_size=int(vocab_size)  # Numpy to Python.
+      )
       for name, token_id, inverse_token in zip(special_tokens,
                                                special_token_ids,
                                                inverse_tokens):
         if special_tokens[name] == inverse_token:
-          result[name] = int(token_id)  # Numpy to Python.
+          result[name] = int(token_id)
         else:
           logging.warning(
               "Could not find %s as token \"%s\" in sentencepiece model, "
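
For readers outside this codebase, a self-contained sketch of the eager-vs-Session pattern used above for reading a lookup table's size; the table construction here is hypothetical and is not the helper used in these files:

  import tensorflow as tf

  # Hypothetical in-memory vocab table; the real code builds its table from a
  # vocab file or a sentencepiece model.
  keys = tf.constant(["[PAD]", "[CLS]", "[SEP]", "[MASK]", "abc"])
  values = tf.range(5, dtype=tf.int64)
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

  if tf.executing_eagerly():
    # Eager init context: tensors convert to Python ints directly.
    vocab_size = int(table.size())
  else:
    # Non-eager init context (Estimator, tf.compat.v1.disable_v2_behavior()):
    # initialize the table and fetch the size through a v1 Session.
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.tables_initializer())
      vocab_size = int(sess.run(table.size()))

  print(vocab_size)  # 5
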
@@ -130,7 +130,8 @@ class BertTokenizerTest(tf.test.TestCase):
                          dict(padding_id=1,
                               start_of_sequence_id=3,
                               end_of_segment_id=4,
-                              mask_id=5))
+                              mask_id=5,
+                              vocab_size=7))

   def test_special_tokens_partial(self):
     vocab_file = self._make_vocab_file(
@@ -140,7 +141,8 @@ class BertTokenizerTest(tf.test.TestCase):
     self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
                          dict(padding_id=0,
                               start_of_sequence_id=1,
-                              end_of_segment_id=2))  # No mask_id,
+                              end_of_segment_id=2,
+                              vocab_size=3))  # No mask_id,

   def test_special_tokens_in_estimator(self):
     """Tests getting special tokens without an Eager init context."""
@@ -252,7 +254,8 @@ class SentencepieceTokenizerTest(tf.test.TestCase):
                          dict(padding_id=0,
                               start_of_sequence_id=2,
                               end_of_segment_id=3,
-                              mask_id=4))
+                              mask_id=4,
+                              vocab_size=16))

   def test_special_tokens_in_estimator(self):
     """Tests getting special tokens without an Eager init context."""
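
To connect the docstring wording ("one past the largest token id used") to the BertTokenizerTest expectation above: for a WordPiece vocab file, token ids are the 0-based line numbers, so the table size is one past the largest id. A hypothetical 7-line vocab file consistent with the asserted ids (the filler tokens are invented; only the bracketed tokens' positions are implied by the test):

  vocab_lines = [
      "filler-a",  # id 0 (hypothetical filler)
      "[PAD]",     # id 1 -> padding_id
      "filler-b",  # id 2 (hypothetical filler)
      "[CLS]",     # id 3 -> start_of_sequence_id
      "[SEP]",     # id 4 -> end_of_segment_id
      "[MASK]",    # id 5 -> mask_id
      "filler-c",  # id 6 (hypothetical filler)
  ]
  vocab_size = len(vocab_lines)  # 7: one past the largest id used (6).
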