"vscode:/vscode.git/clone" did not exist on "9adaa571e3f180c20cfe0134cfb6d9cf4370a3bb"
Commit 4d6820e0 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 292594971
parent bde0f751
@@ -171,10 +171,11 @@ def whitespace_tokenize(text):
 class FullTokenizer(object):
   """Runs end-to-end tokenziation."""

-  def __init__(self, vocab_file, do_lower_case=True):
+  def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
     self.vocab = load_vocab(vocab_file)
     self.inv_vocab = {v: k for k, v in self.vocab.items()}
-    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+    self.basic_tokenizer = BasicTokenizer(
+        do_lower_case=do_lower_case, split_on_punc=split_on_punc)
     self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

   def tokenize(self, text):
@@ -195,13 +196,17 @@ class FullTokenizer(object):
 class BasicTokenizer(object):
   """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

-  def __init__(self, do_lower_case=True):
+  def __init__(self, do_lower_case=True, split_on_punc=True):
     """Constructs a BasicTokenizer.

     Args:
       do_lower_case: Whether to lower case the input.
+      split_on_punc: Whether to apply split on punctuations. By default BERT
+        starts a new token for punctuations. This makes detokenization
+        difficult for tasks like seq2seq decoding.
     """
     self.do_lower_case = do_lower_case
+    self.split_on_punc = split_on_punc

   def tokenize(self, text):
     """Tokenizes a piece of text."""
@@ -222,7 +227,10 @@ class BasicTokenizer(object):
       if self.do_lower_case:
         token = token.lower()
         token = self._run_strip_accents(token)
-      split_tokens.extend(self._run_split_on_punc(token))
+      if self.split_on_punc:
+        split_tokens.extend(self._run_split_on_punc(token))
+      else:
+        split_tokens.append(token)

     output_tokens = whitespace_tokenize(" ".join(split_tokens))
     return output_tokens
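The change above threads a new split_on_punc flag from FullTokenizer into BasicTokenizer and gates the _run_split_on_punc call on it, so punctuation can stay attached to the surrounding word; as the new docstring notes, starting a separate token for every punctuation mark makes detokenization difficult for seq2seq decoding. A minimal sketch of the difference in behavior, assuming the module is importable simply as tokenization (the actual import path is not shown in this diff):

# Sketch: BasicTokenizer with and without the new flag.
# The import path below is an assumption; this diff does not show the
# module's package location.
import tokenization

default_tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
no_split_tokenizer = tokenization.BasicTokenizer(
    do_lower_case=True, split_on_punc=False)

text = u" \tHeLLo!how \n Are yoU? "
default_tokenizer.tokenize(text)   # ["hello", "!", "how", "are", "you", "?"]
no_split_tokenizer.tokenize(text)  # ["hello!how", "are", "you?"]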
@@ -77,10 +77,18 @@ class TokenizationTest(tf.test.TestCase):
         tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
         ["HeLLo", "!", "how", "Are", "yoU", "?"])

+  def test_basic_tokenizer_no_split_on_punc(self):
+    tokenizer = tokenization.BasicTokenizer(
+        do_lower_case=True, split_on_punc=False)
+
+    self.assertAllEqual(
+        tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+        ["hello!how", "are", "you?"])
+
   def test_wordpiece_tokenizer(self):
     vocab_tokens = [
         "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-        "##ing"
+        "##ing", "##!", "!"
     ]

     vocab = {}
@@ -94,6 +102,14 @@ class TokenizationTest(tf.test.TestCase):
         tokenizer.tokenize("unwanted running"),
         ["un", "##want", "##ed", "runn", "##ing"])

+    self.assertAllEqual(
+        tokenizer.tokenize("unwanted running !"),
+        ["un", "##want", "##ed", "runn", "##ing", "!"])
+
+    self.assertAllEqual(
+        tokenizer.tokenize("unwanted running!"),
+        ["un", "##want", "##ed", "runn", "##ing", "##!"])
+
     self.assertAllEqual(
         tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
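The test vocabulary now contains both "!" and the suffix piece "##!": with the default pipeline, punctuation reaches WordpieceTokenizer as a standalone token and matches "!", while with split_on_punc=False upstream it stays attached (e.g. "running!") and is covered by the "##!" piece. A small sketch mirroring the test above (the tokenization import path is again an assumption):

# Sketch: how the "##!" vocab entry covers unsplit punctuation.
# The import path is an assumption; the vocab mirrors the updated test vocab.
import tokenization

vocab_tokens = [
    "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
    "##ing", "##!", "!"
]
vocab = {token: i for i, token in enumerate(vocab_tokens)}
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

# Punctuation already split off upstream: "!" matches the standalone piece.
tokenizer.tokenize("unwanted running !")
# -> ["un", "##want", "##ed", "runn", "##ing", "!"]

# Punctuation left attached (split_on_punc=False upstream): "##!" covers it.
tokenizer.tokenize("unwanted running!")
# -> ["un", "##want", "##ed", "runn", "##ing", "##!"]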