Commit 2416dd9c authored by A. Unique TensorFlower

Merge pull request #8256 from stagedml:tokenizer-update

PiperOrigin-RevId: 301949311
parents 27207a27 30579e0f
@@ -28,6 +28,8 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf

+# pylint: disable=g-complex-comprehension
+
 PAD = "<pad>"
 PAD_ID = 0
 EOS = "<EOS>"
@@ -45,12 +47,18 @@ _UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
 _UNDEFINED_UNICODE = u"\u3013"

-# Set contains all letter and number characters.
-_ALPHANUMERIC_CHAR_SET = set(
-    six.unichr(i) for i in xrange(sys.maxunicode)
-    if (unicodedata.category(six.unichr(i)).startswith("L") or
-        unicodedata.category(six.unichr(i)).startswith("N")))
+
+def alphanumeric_char_set():
+  return set(
+      six.unichr(i)
+      for i in xrange(sys.maxunicode)
+      if (unicodedata.category(six.unichr(i)).startswith("L") or
+          unicodedata.category(six.unichr(i)).startswith("N")))
+
+
+# Set contains all letter and number characters.
+_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()

 # min_count is the minimum number of times a subtoken must appear in the data
 # before before it is added to the vocabulary. The value is found using binary
 # search to obtain the target vocabulary size.
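Note: the refactor above exposes the alphanumeric set through alphanumeric_char_set() so callers can substitute their own "master" character set. A minimal sketch of such a substitute, assuming an ASCII-only use case (latin_char_set below is illustrative and not part of this change):

    import string

    # Hypothetical alternative master set: ASCII letters and digits only.
    latin_char_set = set(string.ascii_letters + string.digits)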
@@ -61,11 +69,14 @@ _MAX_MIN_COUNT = 1000  # max value to use when binary searching for min_count
 class Subtokenizer(object):
   """Encodes and decodes strings to/from integer IDs."""

-  def __init__(self, vocab_file, reserved_tokens=None):
+  def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
     """Initializes class, creating a vocab file if data_files is provided."""
     tf.compat.v1.logging.info("Initializing Subtokenizer from file %s." %
                               vocab_file)

+    if master_char_set is None:
+      master_char_set = _ALPHANUMERIC_CHAR_SET
+
     if reserved_tokens is None:
       reserved_tokens = RESERVED_TOKENS
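With the new keyword argument, a Subtokenizer can be built from an existing vocabulary together with a custom master set; omitting it keeps the old alphanumeric behaviour. A hypothetical call (the vocab path and latin_char_set are placeholders):

    subtokenizer = Subtokenizer("vocab.ende.32768",
                                master_char_set=latin_char_set)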
@@ -78,13 +89,20 @@ class Subtokenizer(object):
         self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))

     # Create cache to speed up subtokenization
-    self._cache_size = 2 ** 20
+    self._cache_size = 2**20
     self._cache = [(None, None)] * self._cache_size
+    self._master_char_set = master_char_set

   @staticmethod
-  def init_from_files(
-      vocab_file, files, target_vocab_size, threshold, min_count=None,
-      file_byte_limit=1e6, reserved_tokens=None, correct_strip=True):
+  def init_from_files(vocab_file,
+                      files,
+                      target_vocab_size,
+                      threshold,
+                      min_count=None,
+                      file_byte_limit=1e6,
+                      reserved_tokens=None,
+                      correct_strip=True,
+                      master_char_set=None):
     """Create subtoken vocabulary based on files, and save vocab to file.

     Args:
@@ -101,10 +119,13 @@ class Subtokenizer(object):
       reserved_tokens: List of string tokens that are guaranteed to be at the
         beginning of the subtoken vocabulary list.
       correct_strip: Whether to convert text to unicode before strip.
+      master_char_set: the char set.

     Returns:
       Subtokenizer object
     """
+    if master_char_set is None:
+      master_char_set = _ALPHANUMERIC_CHAR_SET
     if reserved_tokens is None:
       reserved_tokens = RESERVED_TOKENS
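The factory method threads the same argument through to vocabulary generation. A sketch with placeholder file names and sizes, not a prescribed invocation:

    subtokenizer = Subtokenizer.init_from_files(
        "vocab.ende.32768", ["train.en", "train.de"],
        target_vocab_size=32768, threshold=327,
        master_char_set=alphanumeric_char_set())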
@@ -112,7 +133,8 @@ class Subtokenizer(object):
       tf.compat.v1.logging.info("Vocab file already exists (%s)" % vocab_file)
     else:
       tf.compat.v1.logging.info("Begin steps to create subtoken vocabulary...")
-      token_counts = _count_tokens(files, file_byte_limit, correct_strip)
+      token_counts = _count_tokens(files, file_byte_limit, correct_strip,
+                                   master_char_set)
       alphabet = _generate_alphabet_dict(token_counts)
       subtoken_list = _generate_subtokens_with_target_vocab_size(
           token_counts, alphabet, target_vocab_size, threshold, min_count,
@@ -120,15 +142,18 @@ class Subtokenizer(object):
       tf.compat.v1.logging.info("Generated vocabulary with %d subtokens." %
                                 len(subtoken_list))
       _save_vocab_file(vocab_file, subtoken_list)
-    return Subtokenizer(vocab_file)
+    return Subtokenizer(vocab_file, master_char_set=master_char_set)

   def encode(self, raw_string, add_eos=False):
     """Encodes a string into a list of int subtoken ids."""
     ret = []
-    tokens = _split_string_to_tokens(native_to_unicode(raw_string))
+    tokens = _split_string_to_tokens(
+        native_to_unicode(raw_string), self._master_char_set)
     for token in tokens:
       ret.extend(self._token_to_subtoken_ids(token))
     if add_eos:
+      assert EOS in self.subtoken_list, \
+          "Can't append 'EOS' because it is not in list of known subtokens."
       ret.append(EOS_ID)
     return ret
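The new assertion makes encoding fail loudly instead of appending EOS_ID when the loaded vocabulary has no EOS entry. Hypothetical usage, assuming the vocabulary does contain "<EOS>":

    ids = subtokenizer.encode("Hello world", add_eos=True)
    text = subtokenizer.decode(ids)  # expected to round-trip to "Hello world"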
@@ -161,13 +186,14 @@ class Subtokenizer(object):
           "Subtokens argument passed into decode() must be a list of integers.")

     return _unicode_to_native(
-        _join_tokens_to_string(self._subtoken_ids_to_tokens(subtokens)))
+        _join_tokens_to_string(
+            self._subtoken_ids_to_tokens(subtokens), self._master_char_set))

   def _subtoken_ids_to_tokens(self, subtokens):
     """Convert list of int subtoken ids to a list of string tokens."""
     escaped_tokens = "".join([
-        self.subtoken_list[s] for s in subtokens
-        if s < len(self.subtoken_list)])
+        self.subtoken_list[s] for s in subtokens if s < len(self.subtoken_list)
+    ])
     escaped_tokens = escaped_tokens.split("_")

     # All tokens in the vocabulary list have been escaped (see _escape_token())
@@ -218,16 +244,16 @@ def _unicode_to_native(s):
   return s


-def _split_string_to_tokens(text):
+def _split_string_to_tokens(text, master_char_set):
   """Splits text to a list of string tokens."""
   if not text:
     return []
   ret = []
   token_start = 0
   # Classify each character in the input string
-  is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
+  is_master = [c in master_char_set for c in text]
   for pos in xrange(1, len(text)):
-    if is_alnum[pos] != is_alnum[pos - 1]:
+    if is_master[pos] != is_master[pos - 1]:
       token = text[token_start:pos]
       if token != u" " or token_start == 0:
         ret.append(token)
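The split behaviour, as pinned down by the existing unit test: runs of master-set characters alternate with runs of everything else, and a bare space between two master-set runs is dropped (to be re-inserted by the join below):

    _split_string_to_tokens(u"test? testing 123.", _ALPHANUMERIC_CHAR_SET)
    # returns ["test", "? ", "testing", "123", "."]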
@@ -237,12 +263,12 @@ def _split_string_to_tokens(text):
   return ret


-def _join_tokens_to_string(tokens):
+def _join_tokens_to_string(tokens, master_char_set):
   """Join a list of string tokens into a single string."""
-  token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
+  token_is_master = [t[0] in master_char_set for t in tokens]
   ret = []
   for i, token in enumerate(tokens):
-    if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
+    if i > 0 and token_is_master[i - 1] and token_is_master[i]:
       ret.append(u" ")
     ret.append(token)
   return "".join(ret)
@@ -324,7 +350,10 @@ def _unescape_token(token):
   return _UNESCAPE_REGEX.sub(match, token)


-def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
+def _count_tokens(files,
+                  file_byte_limit=1e6,
+                  correct_strip=True,
+                  master_char_set=None):
   """Return token counts of words in the files.

   Samples file_byte_limit bytes from each file, and counts the words that appear
@@ -337,11 +366,15 @@ def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
      vocabulary generation for PY2. Sets correct_strip to False in PY2 to
      reproduce previous common public result. Sets correct_strip to True will
      let PY2 and PY3 get a consistent vocabulary.
+    master_char_set: the char set.

   Returns:
     Dictionary mapping tokens to the number of times they appear in the sampled
     lines from the files.
   """
+  if master_char_set is None:
+    master_char_set = _ALPHANUMERIC_CHAR_SET
+
   token_counts = collections.defaultdict(int)
   for filepath in files:
@@ -362,7 +395,8 @@ def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
           counter = 0

           # Add words to token counts
-          for token in _split_string_to_tokens(native_to_unicode(line)):
+          for token in _split_string_to_tokens(
+              native_to_unicode(line), master_char_set):
             token_counts[token] += 1
   return token_counts
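A standalone sketch of the counting helper with hypothetical corpus paths; leaving master_char_set unset falls back to the alphanumeric set as before:

    counts = _count_tokens(["train.en", "train.de"],
                           file_byte_limit=1e6,
                           master_char_set=_ALPHANUMERIC_CHAR_SET)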
@@ -394,8 +428,11 @@ def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
   return ret


-def _generate_subtokens_with_target_vocab_size(
-    token_counts, alphabet, target_size, threshold, min_count=None,
-    reserved_tokens=None):
+def _generate_subtokens_with_target_vocab_size(token_counts,
+                                               alphabet,
+                                               target_size,
+                                               threshold,
+                                               min_count=None,
+                                               reserved_tokens=None):
   """Generate subtoken vocabulary close to the target size."""
   if reserved_tokens is None:
@@ -449,8 +486,8 @@ def _generate_alphabet_dict(iterable, reserved_tokens=None):
   return alphabet


-def _count_and_gen_subtokens(
-    token_counts, alphabet, subtoken_dict, max_subtoken_length):
+def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
+                             max_subtoken_length):
   """Count number of times subtokens appear, and generate new subtokens.

   Args:
@@ -468,8 +505,8 @@ def _count_and_gen_subtokens(
   subtoken_counts = collections.defaultdict(int)
   for token, count in six.iteritems(token_counts):
     token = _escape_token(token, alphabet)
-    subtokens = _split_token_to_subtokens(
-        token, subtoken_dict, max_subtoken_length)
+    subtokens = _split_token_to_subtokens(token, subtoken_dict,
+                                          max_subtoken_length)

     # Generate new subtokens by taking substrings from token.
     start = 0
@@ -503,8 +540,10 @@ def _filter_and_bucket_subtokens(subtoken_counts, min_count):
   return subtoken_buckets


-def _gen_new_subtoken_list(
-    subtoken_counts, min_count, alphabet, reserved_tokens=None):
+def _gen_new_subtoken_list(subtoken_counts,
+                           min_count,
+                           alphabet,
+                           reserved_tokens=None):
   """Generate candidate subtokens ordered by count, and new max subtoken length.

   Add subtokens to the candiate list in order of length (longest subtokens
@@ -575,8 +614,10 @@ def _gen_new_subtoken_list(
   return subtoken_list, max_subtoken_length


-def _generate_subtokens(
-    token_counts, alphabet, min_count, num_iterations=4,
-    reserved_tokens=None):
+def _generate_subtokens(token_counts,
+                        alphabet,
+                        min_count,
+                        num_iterations=4,
+                        reserved_tokens=None):
   """Create a list of subtokens in decreasing order of frequency.
@@ -609,8 +650,9 @@ def _generate_subtokens(
     # Create dict mapping subtoken->count, with additional subtokens created
     # from substrings taken from the tokens.
-    subtoken_counts = _count_and_gen_subtokens(
-        token_counts, alphabet, subtoken_dict, max_subtoken_length)
+    subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
+                                               subtoken_dict,
+                                               max_subtoken_length)

     # Generate new list of subtokens sorted by subtoken count.
     subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
...
@@ -59,13 +59,15 @@ class StringHelperTest(tf.test.TestCase):

   def test_split_string_to_tokens(self):
     text = "test? testing 123."
-    tokens = tokenizer._split_string_to_tokens(text)
+    tokens = tokenizer._split_string_to_tokens(text,
+                                               tokenizer._ALPHANUMERIC_CHAR_SET)
     self.assertEqual(["test", "? ", "testing", "123", "."], tokens)

   def test_join_tokens_to_string(self):
     tokens = ["test", "? ", "testing", "123", "."]
-    s = tokenizer._join_tokens_to_string(tokens)
+    s = tokenizer._join_tokens_to_string(tokens,
+                                         tokenizer._ALPHANUMERIC_CHAR_SET)
     self.assertEqual("test? testing 123.", s)

   def test_escape_token(self):
@@ -79,8 +81,7 @@ class StringHelperTest(tf.test.TestCase):
     escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"

     unescaped_token = tokenizer._unescape_token(escaped_token)
-    self.assertEqual(
-        "Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
+    self.assertEqual("Underline: _, Backslash: \\, Unicode: 4", unescaped_token)

   def test_list_to_index_dict(self):
     lst = ["test", "strings"]
@@ -93,8 +94,8 @@ class StringHelperTest(tf.test.TestCase):
     subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
     max_subtoken_length = 2

-    subtokens = tokenizer._split_token_to_subtokens(
-        token, subtoken_dict, max_subtoken_length)
+    subtokens = tokenizer._split_token_to_subtokens(token, subtoken_dict,
+                                                    max_subtoken_length)
     self.assertEqual(["ab", "c"], subtokens)

   def test_generate_alphabet_dict(self):
@@ -124,12 +125,28 @@ class StringHelperTest(tf.test.TestCase):

     self.assertIsInstance(subtoken_counts, collections.defaultdict)
     self.assertDictEqual(
-        {"a": 5, "b": 5, "c": 5, "_": 5, "ab": 5, "bc": 5, "c_": 5,
-         "abc": 5, "bc_": 5, "abc_": 5}, subtoken_counts)
+        {
+            "a": 5,
+            "b": 5,
+            "c": 5,
+            "_": 5,
+            "ab": 5,
+            "bc": 5,
+            "c_": 5,
+            "abc": 5,
+            "bc_": 5,
+            "abc_": 5
+        }, subtoken_counts)

   def test_filter_and_bucket_subtokens(self):
-    subtoken_counts = collections.defaultdict(
-        int, {"a": 2, "b": 4, "c": 1, "ab": 6, "ac": 3, "abbc": 5})
+    subtoken_counts = collections.defaultdict(int, {
+        "a": 2,
+        "b": 4,
+        "c": 1,
+        "ab": 6,
+        "ac": 3,
+        "abbc": 5
+    })
     min_count = 3

     subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
@@ -142,8 +159,12 @@ class StringHelperTest(tf.test.TestCase):
     self.assertEqual(set(["abbc"]), subtoken_buckets[4])

   def test_gen_new_subtoken_list(self):
-    subtoken_counts = collections.defaultdict(
-        int, {"translate": 10, "t": 40, "tr": 16, "tra": 12})
+    subtoken_counts = collections.defaultdict(int, {
+        "translate": 10,
+        "t": 40,
+        "tr": 16,
+        "tra": 12
+    })
     min_count = 5
     alphabet = set("translate")
     reserved_tokens = ["reserved", "tokens"]
@@ -167,8 +188,9 @@ class StringHelperTest(tf.test.TestCase):
     num_iterations = 1
     reserved_tokens = ["reserved", "tokens"]

-    vocab_list = tokenizer._generate_subtokens(
-        token_counts, alphabet, min_count, num_iterations, reserved_tokens)
+    vocab_list = tokenizer._generate_subtokens(token_counts, alphabet,
+                                               min_count, num_iterations,
+                                               reserved_tokens)

     # Check that reserved tokens are at the front of the list
     self.assertEqual(vocab_list[:2], reserved_tokens)
...