Commit 2670b0d6 authored by Michael Watkins, committed by Lysandre Debut

Fix bug which lowercases special tokens

parent 35401fe5
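For context, a rough sketch of the behaviour being fixed (illustrative only, not code from this commit): with `do_lower_case=True`, the whole input string used to be lowercased before tokenization, so special tokens such as `[CLS]` or `[SEP]` became `[cls]` / `[sep]` and no longer matched the tokenizer's registered special tokens.

```python
# Illustrative sketch of the bug, not code from this commit.
# Before the fix, the whole input was lowercased in one go:
text = "[CLS] Hello World [SEP]"
print(text.lower())
# -> "[cls] hello world [sep]"
# "[cls]" and "[sep]" no longer match the registered special tokens,
# so they get split into ordinary word pieces instead of single tokens.
```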
@@ -115,8 +115,10 @@ class CommonTestCases:
     def test_added_tokens_do_lower_case(self):
         tokenizer = self.get_tokenizer(do_lower_case=True)

-        text = "aaaaa bbbbbb low cccccccccdddddddd l"
-        text2 = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
+        special_token = tokenizer.all_special_tokens[0]
+
+        text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
+        text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token

         toks0 = tokenizer.tokenize(text)  # toks before adding new_toks

@@ -141,7 +143,7 @@ class CommonTestCases:
         self.assertEqual(len(toks), len(toks2))  # Length should still be the same
         self.assertNotEqual(len(toks), len(toks0))
-        self.assertNotEqual(toks[0], toks2[0])  # But at least the first tokens should differ
+        self.assertNotEqual(toks[1], toks2[1])  # But at least the first non-special tokens should differ

     def test_add_tokens_tokenizer(self):
         tokenizer = self.get_tokenizer()
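Because the test strings now begin and end with a special token, position 0 is the special token in both tokenizations, so the case-sensitivity check moves to index 1. A hypothetical illustration (the token lists below are made up, not real tokenizer output):

```python
# Hypothetical token lists, only to show why the assertion index moved from 0 to 1.
toks = ["[UNK]", "aaaaa", "bbbbbb"]   # tokens of the lower-case text
toks2 = ["[UNK]", "AAAAA", "BBBBBB"]  # tokens of the upper-case text, after adding cased tokens
assert toks[0] == toks2[0]            # the leading special token is preserved in both
assert toks[1] != toks2[1]            # the first non-special token differs by case
```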
@@ -22,6 +22,7 @@ import json
 import six
 import copy
 import itertools
+import re

 from io import open

 from .file_utils import cached_path, is_tf_available, is_torch_available

@@ -520,7 +521,7 @@ class PreTrainedTokenizer(object):
         to_add_tokens = []
         for token in new_tokens:
             assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
-            if self.init_kwargs.get('do_lower_case', False):
+            if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
                 token = token.lower()
             if token != self.unk_token and \
                     self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
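A minimal standalone sketch of the guard added in the hunk above (the token list and function name here are hypothetical, not part of the commit): under `do_lower_case`, newly added tokens are still lowercased, but tokens registered as special tokens keep their casing.

```python
# Standalone sketch of the new condition; inputs are hypothetical.
do_lower_case = True
all_special_tokens = ["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]

def normalize_new_token(token):
    # Special tokens keep their original casing; everything else is lowercased.
    if do_lower_case and token not in all_special_tokens:
        token = token.lower()
    return token

print(normalize_new_token("NewWord"))  # -> "newword"
print(normalize_new_token("[CLS]"))    # -> "[CLS]" (unchanged)
```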
@@ -615,8 +616,18 @@ class PreTrainedTokenizer(object):
             Take care of added tokens.
         """
+        def lowercase_text(t):
+            # convert non-special tokens to lowercase
+            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
+            pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \
+                      r'(.+?)'
+            return re.sub(
+                pattern,
+                lambda m: m.groups()[0] or m.groups()[1].lower(),
+                t)
+
         if self.init_kwargs.get('do_lower_case', False):
-            text = text.lower()
+            text = lowercase_text(text)

         def split_on_token(tok, text):
             result = []
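For reference, a self-contained sketch of how the regex in `lowercase_text` behaves, using a hypothetical special-token list in place of the tokenizer instance: each match is either a whole special token, which is kept as-is, or a single other character, which is lowercased.

```python
import re

# Self-contained sketch of the lowercase_text logic; the special-token list is hypothetical.
all_special_tokens = ["[CLS]", "[SEP]"]
escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + r'(.+?)'

def lowercase_text(t):
    # Group 1 matches a whole special token (kept unchanged); otherwise group 2
    # matches a single character, which gets lowercased.
    return re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), t)

print(lowercase_text("[CLS] Hello World [SEP]"))  # -> "[CLS] hello world [SEP]"
```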