Unverified Commit cdf1f7ed authored by Junyuan Zheng, committed by GitHub

Fix tokenizer saving and loading error (#6026)



* Fix tokenizer saving and loading bugs when an AddedToken is added to additional_special_tokens

* Add tokenizer test

* Style

* Style 2
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent 83984a61
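
For context, a minimal sketch of the scenario this commit fixes, assuming a standard checkpoint such as bert-base-uncased is available (the import path for AddedToken may differ between transformers versions). Before this change, the save/load round-trip below could fail because the AddedToken inside the list was not converted to or from a plain dict:

import tempfile

from transformers import BertTokenizer
from transformers.tokenization_utils_base import AddedToken

# Load any slow tokenizer; "bert-base-uncased" is just an example checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Register an AddedToken object (not a plain string) as an additional special token.
new_token = AddedToken("new_token", lstrip=True)
tokenizer.add_special_tokens({"additional_special_tokens": [new_token]})

# Round-trip through disk: this is the path fixed by the changes below.
with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)
    reloaded = BertTokenizer.from_pretrained(tmp_dir)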
@@ -1562,6 +1562,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             for key, value in special_tokens_map.items():
                 if isinstance(value, dict):
                     value = AddedToken(**value)
+                elif isinstance(value, list):
+                    value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
                 setattr(tokenizer, key, value)
 
         # Add supplementary tokens.
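
The new list branch covers additional_special_tokens, whose value in special_tokens_map.json is a list that can mix plain strings with serialized AddedToken dicts. A hedged sketch of what the branch rebuilds (the exact fields stored per token depend on the AddedToken class in use):

from transformers.tokenization_utils_base import AddedToken

# Hypothetical value read back from special_tokens_map.json; a real file would
# contain the full AddedToken state (single_word, rstrip, ... fields may vary).
value = [
    {"content": "new_token", "lstrip": True},
    "<extra_id_0>",  # plain strings in the list are left untouched
]

# Same conversion as the new elif branch: only dict entries become AddedToken objects.
value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]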
@@ -1633,6 +1635,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             for key, value in self.special_tokens_map_extended.items():
                 if isinstance(value, AddedToken):
                     write_dict[key] = value.__getstate__()
+                elif isinstance(value, list):
+                    write_dict[key] = [
+                        token.__getstate__() if isinstance(token, AddedToken) else token for token in value
+                    ]
                 else:
                     write_dict[key] = value
             f.write(json.dumps(write_dict, ensure_ascii=False))
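
On the save side, the mirror-image branch turns each AddedToken in such a list back into a plain dict via __getstate__ so that json.dumps can write it. A rough sketch (the exact keys in the emitted dict depend on the AddedToken implementation in use):

import json

from transformers.tokenization_utils_base import AddedToken

# A mixed list as it may appear in special_tokens_map_extended["additional_special_tokens"].
value = [AddedToken("new_token", lstrip=True), "<extra_id_0>"]

# Per-element serialization, matching the new branch above.
write_value = [token.__getstate__() if isinstance(token, AddedToken) else token for token in value]
print(json.dumps(write_value, ensure_ascii=False))
# e.g. [{"content": "new_token", "lstrip": true, ...}, "<extra_id_0>"]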
@@ -1165,6 +1165,16 @@ class TokenizerTesterMixin:
                     encoded_sequences_batch_padded_1[key], encoded_sequences_batch_padded_2[key],
                 )
 
+    def test_added_token_serializable(self):
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            new_token = AddedToken("new_token", lstrip=True)
+            tokenizer.add_special_tokens({"additional_special_tokens": [new_token]})
+
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                tokenizer.save_pretrained(tmp_dir_name)
+                tokenizer.from_pretrained(tmp_dir_name)
+
     def test_batch_encode_plus_padding(self):
         # Test that padded sequences are equivalent between batch_encode_plus and encode_plus
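
The new test is inherited by every tokenizer test class built on TokenizerTesterMixin. Assuming a source checkout of the repository, one way to run just this check across the suite is:

python -m pytest tests/ -k test_added_token_serializable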