Commit 4d6820e0 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 292594971
parent bde0f751
@@ -171,10 +171,11 @@ def whitespace_tokenize(text):
 class FullTokenizer(object):
   """Runs end-to-end tokenziation."""
 
-  def __init__(self, vocab_file, do_lower_case=True):
+  def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
     self.vocab = load_vocab(vocab_file)
     self.inv_vocab = {v: k for k, v in self.vocab.items()}
-    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+    self.basic_tokenizer = BasicTokenizer(
+        do_lower_case=do_lower_case, split_on_punc=split_on_punc)
     self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
 
   def tokenize(self, text):
@@ -195,13 +196,17 @@ class FullTokenizer(object):
 class BasicTokenizer(object):
   """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
 
-  def __init__(self, do_lower_case=True):
+  def __init__(self, do_lower_case=True, split_on_punc=True):
     """Constructs a BasicTokenizer.
 
     Args:
      do_lower_case: Whether to lower case the input.
+      split_on_punc: Whether to apply split on punctuations. By default BERT
+        starts a new token for punctuations. This makes detokenization difficult
+        for tasks like seq2seq decoding.
     """
     self.do_lower_case = do_lower_case
+    self.split_on_punc = split_on_punc
 
   def tokenize(self, text):
     """Tokenizes a piece of text."""
@@ -222,7 +227,10 @@ class BasicTokenizer(object):
       if self.do_lower_case:
         token = token.lower()
         token = self._run_strip_accents(token)
-      split_tokens.extend(self._run_split_on_punc(token))
+      if self.split_on_punc:
+        split_tokens.extend(self._run_split_on_punc(token))
+      else:
+        split_tokens.append(token)
 
     output_tokens = whitespace_tokenize(" ".join(split_tokens))
     return output_tokens
......
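
A minimal usage sketch of the new flag (the import path below is an assumption and may differ from where this module actually lives in the repo): with `split_on_punc=False`, `BasicTokenizer` keeps punctuation attached to the surrounding word instead of emitting it as a separate token.

```python
# Minimal sketch of the split_on_punc flag's effect; the import path is an
# assumption, not part of this change.
from official.nlp.bert import tokenization  # assumed import path

# Default behaviour: punctuation starts a new token.
default_tok = tokenization.BasicTokenizer(do_lower_case=True)
print(default_tok.tokenize(u" \tHeLLo!how \n Are yoU? "))
# -> ["hello", "!", "how", "are", "you", "?"]

# With split_on_punc=False, punctuation stays attached to the word, which
# keeps detokenization simple for seq2seq-style decoding.
no_split_tok = tokenization.BasicTokenizer(
    do_lower_case=True, split_on_punc=False)
print(no_split_tok.tokenize(u" \tHeLLo!how \n Are yoU? "))
# -> ["hello!how", "are", "you?"]
```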
@@ -77,10 +77,18 @@ class TokenizationTest(tf.test.TestCase):
         tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
         ["HeLLo", "!", "how", "Are", "yoU", "?"])
 
+  def test_basic_tokenizer_no_split_on_punc(self):
+    tokenizer = tokenization.BasicTokenizer(
+        do_lower_case=True, split_on_punc=False)
+
+    self.assertAllEqual(
+        tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+        ["hello!how", "are", "you?"])
+
   def test_wordpiece_tokenizer(self):
     vocab_tokens = [
         "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-        "##ing"
+        "##ing", "##!", "!"
     ]
 
     vocab = {}
@@ -94,6 +102,14 @@ class TokenizationTest(tf.test.TestCase):
     self.assertAllEqual(
         tokenizer.tokenize("unwanted running"),
         ["un", "##want", "##ed", "runn", "##ing"])
 
+    self.assertAllEqual(
+        tokenizer.tokenize("unwanted running !"),
+        ["un", "##want", "##ed", "runn", "##ing", "!"])
+
+    self.assertAllEqual(
+        tokenizer.tokenize("unwanted running!"),
+        ["un", "##want", "##ed", "runn", "##ing", "##!"])
+
     self.assertAllEqual(
         tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
......
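
The flag also threads through `FullTokenizer`, where wordpiece then handles the attached punctuation via suffix vocab entries such as `"##!"`. A hedged end-to-end sketch (the import path and the temp-file vocab setup are illustrative assumptions, not part of this change):

```python
# End-to-end sketch: FullTokenizer with split_on_punc=False. Assumes the
# tokenization module is importable at this path and that load_vocab can read
# a plain newline-delimited vocab file (both are assumptions here).
import tempfile

from official.nlp.bert import tokenization  # assumed import path

vocab_tokens = [
    "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
    "##ing", "##!", "!"
]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
  f.write("\n".join(vocab_tokens) + "\n")

tokenizer = tokenization.FullTokenizer(
    f.name, do_lower_case=True, split_on_punc=False)

# "running!" stays one basic token, so wordpiece emits the "##!" suffix piece.
print(tokenizer.tokenize("unwanted running!"))
# -> ["un", "##want", "##ed", "runn", "##ing", "##!"]
```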