Unverified commit 716df5fb authored by Arthur, committed by GitHub

[`TokenizationUtils`] Fix `add_special_tokens` when the token is already there (#28520)



* fix adding special tokens when the token is already there.

* add a test

* add a test

* nit

* fix the test: make sure the order is preserved

* Update tests/test_tokenization_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 07ae53e6
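The user-visible effect of the fix, as a minimal sketch (the checkpoint name is illustrative; any pretrained tokenizer should behave the same):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# Before this fix, a token that was already registered was silently dropped
# from the replacement list, so the second call below would have left only
# ["<other>"]. After the fix, the full list is kept, in the order passed.
tokenizer.add_special_tokens({"additional_special_tokens": ["<tok>"]})
tokenizer.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
print(tokenizer.additional_special_tokens)  # ["<tok>", "<other>"]

# With replace_additional_special_tokens=False, already-registered tokens are
# skipped and only genuinely new ones are appended.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<tok>", "<new>"]},
    replace_additional_special_tokens=False,
)
print(tokenizer.additional_special_tokens)  # ["<tok>", "<other>", "<new>"]
```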
src/transformers/tokenization_utils_base.py
@@ -943,14 +943,15 @@ class SpecialTokensMixin:
                 isinstance(t, (str, AddedToken)) for t in value
             ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
 
-            to_add = set()
+            to_add = []
             for token in value:
                 if isinstance(token, str):
                     # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
                     token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
-                if str(token) not in self.additional_special_tokens:
-                    to_add.add(token)
-            if replace_additional_special_tokens:
+                if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
+                    continue
+                to_add.append(token)
+            if replace_additional_special_tokens and len(to_add) > 0:
                 setattr(self, key, list(to_add))
             else:
                 self._additional_special_tokens.extend(to_add)
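The `set()` → `[]` change is what makes the order guarantee possible: Python sets have no defined iteration order, so `setattr(self, key, list(to_add))` could hand the tokens back in a different order than the caller passed. A plain-Python sketch of the pattern the patch uses instead (not transformers code):

```python
# Membership-checked appends keep the caller's order; collecting into a
# set() and converting back to a list may not.
existing = ["<tok>"]  # tokens already registered
incoming = ["<tok>", "<other>", "<another>"]

to_add = []
for t in incoming:
    if t in existing:  # mirrors replace_additional_special_tokens=False
        continue
    to_add.append(t)

assert to_add == ["<other>", "<another>"]  # caller order preserved
```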
tests/test_tokenization_common.py
@@ -4139,3 +4139,28 @@ class TokenizerTesterMixin:
                 _test_added_vocab_and_eos(
                     EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
                 )
+
+    def test_special_token_addition(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                # Create tokenizer and add an additional special token
+                tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
+                tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
+                self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    tokenizer_1.save_pretrained(tmp_dir)
+                    # Load the above tokenizer and add the same special token a second time
+                    tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
+
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
+                    tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
+
+                    tokenizer_2.add_special_tokens(
+                        {"additional_special_tokens": ["<tok>"]},
+                        replace_additional_special_tokens=False,
+                    )
+                    self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])