Commit 751beb9e authored by WrRan

Never split special tokens: BasicTokenizer and BertTokenizer now take a never_split collection (defaulting to "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), and _run_split_on_punc returns such tokens unchanged instead of splitting them at punctuation.

parent 2e4db64c
@@ -75,7 +75,8 @@ def whitespace_tokenize(text):
 class BertTokenizer(object):
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""

-    def __init__(self, vocab_file, do_lower_case=True, max_len=None):
+    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
+                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@@ -83,7 +84,8 @@ class BertTokenizer(object):
         self.vocab = load_vocab(vocab_file)
         self.ids_to_tokens = collections.OrderedDict(
             [(ids, tok) for tok, ids in self.vocab.items()])
-        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+                                              never_split=never_split)
         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
         self.max_len = max_len if max_len is not None else int(1e12)
@@ -156,13 +158,16 @@ class BertTokenizer(object):
 class BasicTokenizer(object):
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

-    def __init__(self, do_lower_case=True):
+    def __init__(self,
+                 do_lower_case=True,
+                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
         """Constructs a BasicTokenizer.

         Args:
             do_lower_case: Whether to lower case the input.
         """
         self.do_lower_case = do_lower_case
+        self.never_split = never_split

     def tokenize(self, text):
         """Tokenizes a piece of text."""
@@ -198,6 +203,8 @@ class BasicTokenizer(object):
     def _run_split_on_punc(self, text):
         """Splits punctuation on a piece of text."""
+        if text in self.never_split:
+            return [text]
         chars = list(text)
         i = 0
         start_new_word = True
......
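A small usage sketch (not part of the commit) of what the new parameter does: any token listed in never_split now passes through _run_split_on_punc untouched instead of being split at its punctuation characters. The import path below is an assumption; adjust it to wherever this tokenization module lives in the repository.

    from tokenization import BasicTokenizer  # assumed module path

    # Keep the original casing so "[SEP]" matches the default never_split entries.
    tokenizer = BasicTokenizer(do_lower_case=False)

    print(tokenizer.tokenize("hello [SEP] world"))
    # without this change: ['hello', '[', 'SEP', ']', 'world']
    # with this change:    ['hello', '[SEP]', 'world']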