Commit 78cf7b4a authored by Patrick Lewis

Added code to raise a ValueError in BertTokenizer.convert_tokens_to_ids when a token sequence is longer than the model's maximum length

parent 786cc412
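For readers skimming the diff, here is a minimal usage sketch of the behavior this commit introduces. It mirrors the unit test added below; only the import path is an assumption and should be adjusted to wherever `BertTokenizer` lives in this repo.

```python
from tokenization import BertTokenizer  # assumption: adjust to this repo's module layout

# Tiny vocab file, copied from the new unit test in this commit.
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed",
                "wa", "un", "runn", "##ing", ","]
with open("/tmp/bert_tokenizer_demo_vocab.txt", "w") as vocab_writer:
    vocab_writer.write("".join(tok + "\n" for tok in vocab_tokens))

tokenizer = BertTokenizer("/tmp/bert_tokenizer_demo_vocab.txt", max_len=10)

# 10 whitespace tokens, all unknown -> 10 ids, exactly at the limit: no error.
ids = tokenizer.convert_tokens_to_ids(
    tokenizer.tokenize(u"the cat sat on the mat in the summer time"))
assert ids == [0] * 10

# 11 tokens -> exceeds max_len=10 -> the new ValueError fires.
try:
    tokenizer.convert_tokens_to_ids(
        tokenizer.tokenize(u"the cat sat on the mat in the summer time ."))
except ValueError as err:
    print(err)  # "... (11 > 10) ..."
```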
...@@ -36,6 +36,15 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'bert-base-uncased': 512,
    'bert-large-uncased': 512,
    'bert-base-cased': 512,
    'bert-large-cased': 512,
    'bert-base-multilingual-uncased': 512,
    'bert-base-multilingual-cased': 512,
    'bert-base-chinese': 512,
}
VOCAB_NAME = 'vocab.txt'
...@@ -65,7 +74,8 @@ def whitespace_tokenize(text):
class BertTokenizer(object):
    """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
    def __init__(self, vocab_file, do_lower_case=True):
    def __init__(self, vocab_file, do_lower_case=True, max_len=None):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
...@@ -75,6 +85,7 @@ class BertTokenizer(object):
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
    def tokenize(self, text):
        split_tokens = []
...@@ -88,6 +99,12 @@ class BertTokenizer(object):
        ids = []
        for token in tokens:
            ids.append(self.vocab[token])
        if len(ids) > self.max_len:
            raise ValueError(
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this BERT model ({} > {}). Running this "
                "sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
            )
        return ids

    def convert_ids_to_tokens(self, ids):
...@@ -126,6 +143,11 @@ class BertTokenizer(object):
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
        return tokenizer
...
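The `from_pretrained` change above caps any user-supplied `max_len` at the model's positional-embedding budget. Below is a small, self-contained sketch of just that capping logic; the `capped_max_len` helper is hypothetical, written only to illustrate the `min(...)` expression from the diff.

```python
# Excerpt of the map from the diff; one entry is enough for the illustration.
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {'bert-base-uncased': 512}

def capped_max_len(pretrained_model_name, **kwargs):
    """Hypothetical helper mirroring the kwargs handling added to from_pretrained."""
    if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
        max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
        # Same expression as in the diff: never exceed the positional-embedding size.
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
    return kwargs.get('max_len')

print(capped_max_len('bert-base-uncased'))                # 512 (default cap)
print(capped_max_len('bert-base-uncased', max_len=128))   # 128 (smaller user value kept)
print(capped_max_len('bert-base-uncased', max_len=4096))  # 512 (over-budget value capped)
```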
...@@ -44,6 +44,24 @@ class TokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
    def test_full_tokenizer_raises_error_for_long_sequences(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ","
        ]
        with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name
        tokenizer = BertTokenizer(vocab_file, max_len=10)
        os.remove(vocab_file)
        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
        indices = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(indices, [0 for _ in range(10)])
        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
        self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
    def test_chinese(self):
        tokenizer = BasicTokenizer()
...
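Note that the new check raises rather than silently truncating, so callers feeding long documents must shorten their token lists themselves. A hedged, caller-side sketch (not part of this commit; `safe_convert` is a hypothetical helper):

```python
def safe_convert(tokenizer, tokens, max_len):
    # Naive truncation so the tokenizer's new length check never fires;
    # real pipelines usually reserve room for [CLS]/[SEP] instead.
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
    return tokenizer.convert_tokens_to_ids(tokens)
```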