Unverified Commit deba7655 authored by Ita Zaporozhets's avatar Ita Zaporozhets Committed by GitHub
Browse files

Add split special tokens (#30772)



* seems like `split_special_tokens` is used here

* split special token

* add new line at end of file

* moving split special token test to common tests

* added assertions

* test

* fixup

* add co-author

* passing rest of args to gptsan_japanese, fixing tests

* removing direct comparison of fast and slow models

* adding test support for UDOP and LayoutXLM

* ruff fix

* readd check if slow tokenizer

* modify test to handle bos tokens

* removing commented function

* trigger build

* applying review feedback - updated docstrings, var names, and simplified tests

* ruff fixes

* Update tests/test_tokenization_common.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* applying feedback, comments

* shutil temp directory fix

---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Ita Zaporozhets <itazaporozhets@Itas-MBP.localdomain>
Co-authored-by: itazap <itazap@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Ita Zaporozhets <itazaporozhets@Itas-MacBook-Pro.local>
parent e5103a76
...@@ -353,6 +353,7 @@ class GPTSanJapaneseTokenizer(PreTrainedTokenizer): ...@@ -353,6 +353,7 @@ class GPTSanJapaneseTokenizer(PreTrainedTokenizer):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
**kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
# This tokenizer converts input text pairs into Prefix input and subsequent input # This tokenizer converts input text pairs into Prefix input and subsequent input
if isinstance(batch_text_or_text_pairs[0], tuple) or isinstance(tuple(batch_text_or_text_pairs[0]), list): if isinstance(batch_text_or_text_pairs[0], tuple) or isinstance(tuple(batch_text_or_text_pairs[0]), list):
...@@ -379,6 +380,7 @@ class GPTSanJapaneseTokenizer(PreTrainedTokenizer): ...@@ -379,6 +380,7 @@ class GPTSanJapaneseTokenizer(PreTrainedTokenizer):
return_offsets_mapping, return_offsets_mapping,
return_length, return_length,
verbose, verbose,
**kwargs,
) )
......
...@@ -415,6 +415,11 @@ class LayoutXLMTokenizerFast(PreTrainedTokenizerFast): ...@@ -415,6 +415,11 @@ class LayoutXLMTokenizerFast(PreTrainedTokenizerFast):
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
batched_input = [(text, pair)] if pair else [text] batched_input = [(text, pair)] if pair else [text]
self._tokenizer.encode_special_tokens = kwargs.pop(
"split_special_tokens", self._tokenizer.encode_special_tokens
)
encodings = self._tokenizer.encode_batch( encodings = self._tokenizer.encode_batch(
batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
) )
......
...@@ -425,6 +425,11 @@ class UdopTokenizerFast(PreTrainedTokenizerFast): ...@@ -425,6 +425,11 @@ class UdopTokenizerFast(PreTrainedTokenizerFast):
# Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast.tokenize # Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast.tokenize
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
batched_input = [(text, pair)] if pair else [text] batched_input = [(text, pair)] if pair else [text]
self._tokenizer.encode_special_tokens = kwargs.pop(
"split_special_tokens", self._tokenizer.encode_special_tokens
)
encodings = self._tokenizer.encode_batch( encodings = self._tokenizer.encode_batch(
batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
) )
......
...@@ -764,6 +764,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -764,6 +764,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
def get_input_ids(text): def get_input_ids(text):
...@@ -820,6 +821,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -820,6 +821,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_length=return_length, return_length=return_length,
return_tensors=return_tensors, return_tensors=return_tensors,
verbose=verbose, verbose=verbose,
split_special_tokens=split_special_tokens,
) )
return BatchEncoding(batch_outputs) return BatchEncoding(batch_outputs)
...@@ -841,6 +843,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -841,6 +843,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_special_tokens_mask: bool = False, return_special_tokens_mask: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
) -> BatchEncoding: ) -> BatchEncoding:
""" """
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
...@@ -870,6 +873,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -870,6 +873,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_tensors=None, # We convert the whole batch to tensors at the end return_tensors=None, # We convert the whole batch to tensors at the end
prepend_batch_axis=False, prepend_batch_axis=False,
verbose=verbose, verbose=verbose,
split_special_tokens=split_special_tokens,
) )
for key, value in outputs.items(): for key, value in outputs.items():
......
...@@ -1538,10 +1538,10 @@ INIT_TOKENIZER_DOCSTRING = r""" ...@@ -1538,10 +1538,10 @@ INIT_TOKENIZER_DOCSTRING = r"""
Whether or not the model should cleanup the spaces that were added when splitting the input text during the Whether or not the model should cleanup the spaces that were added when splitting the input text during the
tokenization process. tokenization process.
split_special_tokens (`bool`, *optional*, defaults to `False`): split_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the special tokens should be split during the tokenization process. The default behavior is Whether or not the special tokens should be split during the tokenization process. Passing will affect the
to not split special tokens. This means that if `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") = internal state of the tokenizer. The default behavior is to not split special tokens. This means that if
['<s>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will give `['<', `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") = ['<s>`]. Otherwise, if
's', '>']`. This argument is only supported for `slow` tokenizers for the moment. `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will give `['<','s', '>']`.
""" """
...@@ -2876,6 +2876,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -2876,6 +2876,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"return_special_tokens_mask": return_special_tokens_mask, "return_special_tokens_mask": return_special_tokens_mask,
"return_offsets_mapping": return_offsets_mapping, "return_offsets_mapping": return_offsets_mapping,
"return_length": return_length, "return_length": return_length,
"split_special_tokens": kwargs.pop("split_special_tokens", self.split_special_tokens),
"verbose": verbose, "verbose": verbose,
} }
all_kwargs.update(kwargs) all_kwargs.update(kwargs)
...@@ -2920,6 +2921,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -2920,6 +2921,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
# Input type checking for clearer error # Input type checking for clearer error
...@@ -2989,6 +2991,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -2989,6 +2991,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping=return_offsets_mapping, return_offsets_mapping=return_offsets_mapping,
return_length=return_length, return_length=return_length,
verbose=verbose, verbose=verbose,
split_special_tokens=split_special_tokens,
**kwargs, **kwargs,
) )
else: else:
...@@ -3010,6 +3013,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -3010,6 +3013,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping=return_offsets_mapping, return_offsets_mapping=return_offsets_mapping,
return_length=return_length, return_length=return_length,
verbose=verbose, verbose=verbose,
split_special_tokens=split_special_tokens,
**kwargs, **kwargs,
) )
...@@ -3083,6 +3087,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -3083,6 +3087,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping=return_offsets_mapping, return_offsets_mapping=return_offsets_mapping,
return_length=return_length, return_length=return_length,
verbose=verbose, verbose=verbose,
split_special_tokens=kwargs.pop("split_special_tokens", self.split_special_tokens),
**kwargs, **kwargs,
) )
...@@ -3105,6 +3110,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -3105,6 +3110,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
raise NotImplementedError raise NotImplementedError
...@@ -3135,6 +3141,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -3135,6 +3141,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
""" """
...@@ -3180,6 +3187,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -3180,6 +3187,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping=return_offsets_mapping, return_offsets_mapping=return_offsets_mapping,
return_length=return_length, return_length=return_length,
verbose=verbose, verbose=verbose,
split_special_tokens=split_special_tokens,
**kwargs, **kwargs,
) )
...@@ -3208,6 +3216,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -3208,6 +3216,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
raise NotImplementedError raise NotImplementedError
......
...@@ -163,6 +163,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -163,6 +163,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
# We call this after having initialized the backend tokenizer because we update it. # We call this after having initialized the backend tokenizer because we update it.
super().__init__(**kwargs) super().__init__(**kwargs)
# Set the splitting mode for special tokens for the tokenizer to be used throughout the class.
self._tokenizer.encode_special_tokens = self.split_special_tokens
# The following logic will be replace with a single add_tokens once a fix is pushed to tokenizers # The following logic will be replace with a single add_tokens once a fix is pushed to tokenizers
# allows converting a slow -> fast, non-legacy: if the `tokenizer.json` does not have all the added tokens # allows converting a slow -> fast, non-legacy: if the `tokenizer.json` does not have all the added tokens
# uses the information stored in `added_tokens_decoder`. # uses the information stored in `added_tokens_decoder`.
...@@ -494,6 +497,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -494,6 +497,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
) -> BatchEncoding: ) -> BatchEncoding:
if not isinstance(batch_text_or_text_pairs, (tuple, list)): if not isinstance(batch_text_or_text_pairs, (tuple, list)):
raise TypeError( raise TypeError(
...@@ -509,6 +513,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -509,6 +513,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
) )
if self._tokenizer.encode_special_tokens != split_special_tokens:
self._tokenizer.encode_special_tokens = split_special_tokens
encodings = self._tokenizer.encode_batch( encodings = self._tokenizer.encode_batch(
batch_text_or_text_pairs, batch_text_or_text_pairs,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
...@@ -578,6 +585,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -578,6 +585,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
split_special_tokens: bool = False,
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
batched_input = [(text, text_pair)] if text_pair else [text] batched_input = [(text, text_pair)] if text_pair else [text]
...@@ -598,6 +606,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -598,6 +606,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_offsets_mapping=return_offsets_mapping, return_offsets_mapping=return_offsets_mapping,
return_length=return_length, return_length=return_length,
verbose=verbose, verbose=verbose,
split_special_tokens=split_special_tokens,
**kwargs, **kwargs,
) )
......
...@@ -150,17 +150,40 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -150,17 +150,40 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3) self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
def test_split_special_tokens(self): def test_split_special_tokens(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base") for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
_, _, boxes = self.get_question_words_and_boxes() special_token = "<my_new_token>"
special_token = "[SPECIAL_TOKEN]" special_sentence = f"Hey this is a {special_token} token"
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]}) _, _, boxes = self.get_question_words_and_boxes()
encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False)
self.assertEqual(len(encoded_special_token), 1) with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
tokenizer_rust = self.rust_tokenizer_class.from_pretrained(
encoded_split_special_token = tokenizer.tokenize( pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes )
) tokenizer_py = self.tokenizer_class.from_pretrained(
self.assertTrue(len(encoded_split_special_token) > 1) pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
)
py_tokens_output = tokenizer_py.tokenize(special_sentence)
rust_tokens_output = tokenizer_rust.tokenize(special_sentence)
self.assertTrue(special_token not in py_tokens_output)
self.assertTrue(special_token not in rust_tokens_output)
py_tokens_output_unsplit = tokenizer_py.tokenize(special_sentence, split_special_tokens=False)
rust_tokens_output_unsplit = tokenizer_rust.tokenize(special_sentence, split_special_tokens=False)
self.assertTrue(special_token in py_tokens_output_unsplit)
self.assertTrue(special_token in rust_tokens_output_unsplit)
tmpdirname = tempfile.mkdtemp()
tokenizer_py.save_pretrained(tmpdirname)
fast_from_saved = self.tokenizer_class.from_pretrained(tmpdirname)
output_tokens_reloaded_split = fast_from_saved.tokenize(special_sentence)
self.assertTrue(special_token not in output_tokens_reloaded_split)
output_tokens_reloaded_unsplit = fast_from_saved.tokenize(special_sentence, split_special_tokens=False)
self.assertTrue(special_token in output_tokens_reloaded_unsplit)
@slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
......
...@@ -1921,3 +1921,48 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -1921,3 +1921,48 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
excepted_decoding = "<pad> paragraph<loc_58><loc_34><loc_446><loc_449></s>" excepted_decoding = "<pad> paragraph<loc_58><loc_34><loc_446><loc_449></s>"
assert decoding == excepted_decoding assert decoding == excepted_decoding
def test_split_special_tokens(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
special_token = "<my_new_token>"
special_sentence = f"Hey this is a {special_token} token"
_, _, boxes = self.get_question_words_and_boxes()
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
tokenizer_rust = self.rust_tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
)
tokenizer_py = self.tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
)
special_token_id = tokenizer_py.convert_tokens_to_ids(special_token)
encoded_special_token_unsplit = tokenizer_py.encode(
special_token, add_special_tokens=False, split_special_tokens=False
)
self.assertTrue(special_token_id in encoded_special_token_unsplit)
encoded_special_token_split = tokenizer_py.encode(special_token, add_special_tokens=False)
self.assertTrue(special_token_id not in encoded_special_token_split)
py_tokens_output = tokenizer_py.tokenize(special_sentence)
rust_tokens_output = tokenizer_rust.tokenize(special_sentence)
self.assertTrue(special_token not in py_tokens_output)
self.assertTrue(special_token not in rust_tokens_output)
py_tokens_output_unsplit = tokenizer_py.tokenize(special_sentence, split_special_tokens=False)
rust_tokens_output_unsplit = tokenizer_rust.tokenize(special_sentence, split_special_tokens=False)
self.assertTrue(special_token in py_tokens_output_unsplit)
self.assertTrue(special_token in rust_tokens_output_unsplit)
tmpdirname = tempfile.mkdtemp()
tokenizer_py.save_pretrained(tmpdirname)
fast_from_saved = self.tokenizer_class.from_pretrained(tmpdirname)
output_tokens_reloaded_split = fast_from_saved.tokenize(special_sentence)
self.assertTrue(special_token not in output_tokens_reloaded_split)
output_tokens_reloaded_unsplit = fast_from_saved.tokenize(special_sentence, split_special_tokens=False)
self.assertTrue(special_token in output_tokens_reloaded_unsplit)
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
import itertools import itertools
import json import json
...@@ -4168,34 +4167,59 @@ class TokenizerTesterMixin: ...@@ -4168,34 +4167,59 @@ class TokenizerTesterMixin:
def test_split_special_tokens(self): def test_split_special_tokens(self):
if not self.test_slow_tokenizer: if not self.test_slow_tokenizer:
return return
# Tests the expected appearance (or absence) of special token in encoded output,
# explicit values are not tested because tokenization is model dependent and can change
for tokenizer, pretrained_name, kwargs in self.tokenizers_list: for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
special_token = "[SPECIAL_TOKEN]" special_token = "<my_new_token>"
special_sentence = f"Hey this is a {special_token} token"
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) tokenizer_rust = self.rust_tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
if not tokenizer.is_fast: )
# bloom, gptneox etc only have a fast tokenizer_py = self.tokenizer_class.from_pretrained(
tokenizer.add_special_tokens( pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
{ )
"additional_special_tokens": [
AddedToken(special_token, rstrip=True, lstrip=True, normalized=True, special=True)
]
}
)
encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
self.assertEqual(len(encoded_special_token), 1)
encoded_split_special_token = tokenizer.encode( special_token_id = tokenizer_py.convert_tokens_to_ids(special_token)
special_token, add_special_tokens=False, split_special_tokens=True encoded_special_token_unsplit = tokenizer_py.encode(
) special_token, add_special_tokens=False, split_special_tokens=False
if len(encoded_split_special_token) == 1: )
# if we have subword tokenization or special vocab self.assertTrue(special_token_id in encoded_special_token_unsplit)
self.assertTrue(
encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token) encoded_special_token_split = tokenizer_py.encode(special_token, add_special_tokens=False)
) self.assertTrue(special_token_id not in encoded_special_token_split)
else:
self.assertTrue(len(encoded_split_special_token) > 1) py_tokens_output = tokenizer_py.tokenize(special_sentence)
rust_tokens_output = tokenizer_rust.tokenize(special_sentence)
self.assertTrue(special_token not in py_tokens_output)
self.assertTrue(special_token not in rust_tokens_output)
py_tokens_output_unsplit = tokenizer_py.tokenize(special_sentence, split_special_tokens=False)
rust_tokens_output_unsplit = tokenizer_rust.tokenize(special_sentence, split_special_tokens=False)
self.assertTrue(special_token in py_tokens_output_unsplit)
self.assertTrue(special_token in rust_tokens_output_unsplit)
py_tokens_output = tokenizer_py(special_sentence)
rust_tokens_output = tokenizer_rust(special_sentence)
self.assertTrue(special_token_id not in py_tokens_output)
self.assertTrue(special_token_id not in rust_tokens_output)
tmp_dir = tempfile.mkdtemp()
try:
tokenizer_py.save_pretrained(tmp_dir)
fast_from_saved = self.tokenizer_class.from_pretrained(tmp_dir)
finally:
shutil.rmtree(tmp_dir)
output_tokens_reloaded_split = fast_from_saved.tokenize(special_sentence)
self.assertTrue(special_token not in output_tokens_reloaded_split)
output_tokens_reloaded_unsplit = fast_from_saved.tokenize(special_sentence, split_special_tokens=False)
self.assertTrue(special_token in output_tokens_reloaded_unsplit)
def test_added_tokens_serialization(self): def test_added_tokens_serialization(self):
# Utility to test the added vocab # Utility to test the added vocab
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment