Unverified Commit 30b3c46f authored by Arthur, committed by GitHub

[`split_special_tokens`] Add support for `split_special_tokens` argument to encode (#25081)

* draft changes

* update and add tests

* styling for now

* move test

* path to usable model

* update test

* small update

* update bert-based tokenizers

* don't use kwargs for _tokenize

* don't use kwargs for _tokenize

* fix copies

* update

* update test for special tokenizers

* fixup

* skip two tests

* remove pdb breakpoint()

* wowo

* rewrite custom tests

* nits

* revert change in target keys

* fix markup lm

* update documentation of the argument
parent 9d7afd25
@@ -238,10 +238,12 @@ class BertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
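For the BERT-style slow tokenizers patched in this commit, the effect of the new argument can be illustrated with a short, hedged sketch. It assumes the `bert-base-uncased` checkpoint (the one the common tests below also use); the exact word pieces depend on that vocabulary.

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Default behaviour: special tokens stay in `never_split`, so the marker
# survives tokenization as a single token.
print(tokenizer.tokenize("[SEP]"))  # ['[SEP]']

# With split_special_tokens=True the basic tokenizer no longer protects it,
# so the string is treated as ordinary text and broken into word pieces
# (something like ['[', 'sep', ']'], depending on the vocabulary).
print(tokenizer.tokenize("[SEP]", split_special_tokens=True))
```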
@@ -177,10 +177,12 @@ class ConvBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -178,10 +178,12 @@ class RetriBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -195,10 +195,12 @@ class DistilBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -194,10 +194,12 @@ class ElectraTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -205,10 +205,12 @@ class FunnelTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -176,10 +176,12 @@ class LayoutLMTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -168,10 +168,12 @@ class LxmertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -166,10 +166,12 @@ class MobileBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -210,10 +210,12 @@ class RoCBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -180,10 +180,12 @@ class SqueezeBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
...
@@ -498,6 +498,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         all_special_tokens_extended = {
             str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
         }
 
+        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
 
         text, kwargs = self.prepare_for_tokenization(text, **kwargs)
@@ -513,8 +514,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
             text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
 
-        no_split_token = set(self.unique_no_split_tokens)
-        tokens = self.tokens_trie.split(text)
+        # split_special_tokens: empty `no_split_token`
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(self.unique_no_split_tokens)
+            tokens = self.tokens_trie.split(text)
+
         # ["This is something", "<special_token_1>", " else"]
         for i, token in enumerate(tokens):
             if token in no_split_token:
...
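The branch above is the heart of the feature: when `split_special_tokens` is set, `no_split_token` is empty and the whole input reaches the model-specific `_tokenize` in one piece instead of being pre-split by the token trie. A small illustrative sketch of that difference, reusing the `Trie` helper from `tokenization_utils` purely for demonstration:

```python
from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("<special_token_1>")

text = "This is something<special_token_1> else"

# Default path: the trie carves the registered special token out of the text
# before `_tokenize` ever sees it.
print(trie.split(text))  # ['This is something', '<special_token_1>', ' else']

# split_special_tokens=True path: the trie is bypassed, so `_tokenize`
# receives the untouched string and is free to split the special token.
print([text])  # ['This is something<special_token_1> else']
```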
@@ -1492,6 +1492,11 @@ INIT_TOKENIZER_DOCSTRING = r"""
         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
             Whether or not the model should cleanup the spaces that were added when splitting the input text during the
             tokenization process.
+        split_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not the special tokens should be split during the tokenization process. The default behavior is
+            to not split special tokens. This means that if `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>")`
+            gives `['<s>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will give
+            `['<', 's', '>']`. This argument is only supported for `slow` tokenizers for the moment.
 """
@@ -1546,6 +1551,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # By default, cleaning tokenization spaces for both fast and slow tokenizers
         self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
 
+        # By default, do not split special tokens for both fast and slow tokenizers
+        self.split_special_tokens = kwargs.pop("split_special_tokens", False)
+
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
...
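Given the two changes above, the flag can be supplied either once at load time (it is popped from the `__init__` kwargs and stored on the instance) or per call (popped again inside `tokenize`). A hedged usage sketch, again using `bert-base-uncased` purely for illustration:

```python
from transformers import BertTokenizer

# 1) Instance-level default: every subsequent call splits special tokens.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", split_special_tokens=True)

# 2) Per-call override on an ordinarily configured tokenizer.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("[SEP]", add_special_tokens=False, split_special_tokens=True)
# `ids` is expected to contain several ordinary word-piece ids rather than the
# single special-token id returned by the default code path.
```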
@@ -384,6 +384,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_right_and_left_truncation(self):
         pass
 
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
...
@@ -264,6 +264,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_right_and_left_truncation(self):
         pass
 
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
...
@@ -144,6 +144,19 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
 
+    def test_split_special_tokens(self):
+        tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
+        _, _, boxes = self.get_question_words_and_boxes()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
...
@@ -1344,6 +1344,19 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertTrue(special_token_id in p_output)
         self.assertTrue(special_token_id in cr_output)
 
+    def test_split_special_tokens(self):
+        # TODO: this is only possible for slow tokenizers currently
+        tokenizer = self.get_tokenizer()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
+
     def test_training_new_tokenizer(self):
         # This feature only exists for fast tokenizers
         if not self.test_rust_tokenizer:
...
@@ -3909,6 +3909,7 @@ class TokenizerTesterMixin:
             # Should not raise an error
             self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
 
+    # TODO: this is run for all models but only tests BERT...
     def test_clean_up_tokenization_spaces(self):
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         assert tokenizer.clean_up_tokenization_spaces is True
@@ -3953,3 +3954,29 @@ class TokenizerTesterMixin:
             tokenizer.clean_up_tokenization_spaces = True
             decoded = tokenizer.decode(tokens)
             assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
+    def test_split_special_tokens(self):
+        if not self.test_slow_tokenizer:
+            return
+
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            special_token = "[SPECIAL_TOKEN]"
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                if not tokenizer.is_fast:
+                    # bloom, gptneox etc. only have a fast tokenizer
+                    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+                    encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
+                    self.assertEqual(len(encoded_special_token), 1)
+
+                    encoded_split_special_token = tokenizer.encode(
+                        special_token, add_special_tokens=False, split_special_tokens=True
+                    )
+
+                    if len(encoded_split_special_token) == 1:
+                        # if we have subword tokenization or special vocab
+                        self.assertTrue(
+                            encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
+                        )
+                    else:
+                        self.assertTrue(len(encoded_split_special_token) > 1)