Unverified Commit cf4eb8b3 authored by Nicolas Patry, committed by GitHub

Adding a test for multibytes unicode. (#13447)

* Adding a test for multibytes unicode.

* Adding some accents.

* Making sure decoding works.

* Make tests passing by being cheesy.
parent 607611f2
@@ -104,7 +104,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
         self._num_special_tokens = len(self.special_tokens_encoder)
         n = len(additional_special_tokens)
         for i, token in enumerate(additional_special_tokens):
-            self.special_tokens_encoder[token] = self.vocab_size + i - n - 1
+            self.special_tokens_encoder[token] = self.vocab_size + i - n
         self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}

     @property
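For context on the off-by-one fix above: with the stock ByT5 settings (3 special tokens `<pad>`/`</s>`/`<unk>`, 256 byte tokens shifted up by 3, and 125 `<extra_id_*>` sentinels, so `vocab_size == 384`), the old formula put `<extra_id_0>` at id 258, which byte `0xFF` already occupies. A minimal standalone sketch of the resulting id layout, assuming those defaults:

```python
# Sketch of the ByT5 id layout (plain Python, not the transformers source).
num_special = 3                           # <pad>=0, </s>=1, <unk>=2
num_bytes = 256                           # byte b gets id b + 3, i.e. 3..258
n = 125                                   # additional <extra_id_*> tokens
vocab_size = num_special + num_bytes + n  # 384

fixed = [vocab_size + i - n for i in range(n)]    # new formula: ids 259..383
old = [vocab_size + i - n - 1 for i in range(n)]  # old formula: ids 258..382

assert fixed[0] == num_special + num_bytes  # first id after the byte range
assert fixed[-1] == vocab_size - 1          # fills the vocab exactly
assert old[0] == 0xFF + num_special         # clashes with the id of byte 0xFF
```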
@@ -199,7 +199,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
     def _tokenize(self, text: str) -> List[str]:
         """Take as input a string and return a list of strings (tokens) for words/sub-words"""
-        tokens = list(text)
+        tokens = [chr(i) for i in text.encode("utf-8")]
         return tokens

     def _convert_token_to_id(self, token):
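The `_tokenize` change is what makes multibyte characters work: the text is UTF-8-encoded first, and every byte becomes its own single-character token, instead of one token per Unicode character. A standalone illustration (plain Python, no tokenizer needed):

```python
# "€" is a single character but three UTF-8 bytes (0xE2 0x82 0xAC).
text = "Unicode €."

old_tokens = list(text)                              # old behavior: 10 tokens
new_tokens = [chr(i) for i in text.encode("utf-8")]  # new behavior: 12 tokens

assert len(old_tokens) == 10
assert len(new_tokens) == 12
# With the offset of 3 for <pad>/</s>/<unk>, the euro sign's bytes map to
# ids 229, 133, 175, exactly the values asserted in the new test below.
assert [ord(t) + 3 for t in new_tokens[8:11]] == [229, 133, 175]
```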
@@ -224,15 +224,27 @@ class ByT5Tokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        string = ""
+        bstring = b""
         for token in tokens:
             if token in self.special_tokens_decoder:
-                tok_string = self.special_tokens_decoder[token]
+                tok_string = self.special_tokens_decoder[token].encode("utf-8")
             elif token in self.added_tokens_decoder:
-                tok_string = self.added_tokens_decoder[token]
+                tok_string = self.added_tokens_decoder[token].encode("utf-8")
+            elif token in self.special_tokens_encoder:
+                tok_string = token.encode("utf-8")
+            elif token in self.added_tokens_encoder:
+                tok_string = token.encode("utf-8")
             else:
-                tok_string = token
-            string += tok_string
+                tok_string = bytes([ord(token)])
+            bstring += tok_string
+        # XXX: This is most likely incorrect; ideally we would want utf-8
+        # errors to be triggered. However, the transformers test suite
+        # tries to decode every ID within the tokenizer on its own,
+        # meaning it will attempt to decode invalid utf-8. Ignoring
+        # errors makes the tests pass, while correctly raising them
+        # would mean editing the automated tests to support that
+        # behavior (decoding an arbitrary ID might produce invalid utf-8).
+        string = bstring.decode("utf-8", errors="ignore")
         return string

     # ByT5Tokenizer has no vocab file
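Decoding mirrors the byte-level encoding: the tokens are reassembled into a raw byte string and UTF-8-decoded once at the end, with `errors="ignore"` for the reasons given in the XXX comment. A small sketch of the happy path and the lossy edge case, in plain Python:

```python
# Round trip through the byte-token representation.
tokens = [chr(i) for i in "Unicode €.".encode("utf-8")]
bstring = b"".join(bytes([ord(t)]) for t in tokens)
assert bstring.decode("utf-8", errors="ignore") == "Unicode €."

# A lone lead byte is invalid UTF-8: strict decoding raises, while
# errors="ignore" silently drops it. This is why decoding a single
# arbitrary ID (as the generic tests do) cannot be allowed to crash.
assert b"\xe2".decode("utf-8", errors="ignore") == ""
try:
    b"\xe2".decode("utf-8")
except UnicodeDecodeError:
    pass  # strict mode would raise here
```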
@@ -56,6 +56,27 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
         self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])

+    def test_multibytes_char(self):
+        tokenizer = self.t5_base_tokenizer
+        src_text = "Unicode €."
+        encoded = tokenizer(src_text)
+        encoded_ids = [88, 113, 108, 102, 114, 103, 104, 35, 229, 133, 175, 49, 1]
+        self.assertEqual(encoded["input_ids"], encoded_ids)
+
+        # decoding
+        decoded = tokenizer.decode(encoded_ids)
+        self.assertEqual(decoded, "Unicode €.</s>")
+
+        encoded = tokenizer("e è é ê ë")
+        encoded_ids = [104, 35, 198, 171, 35, 198, 172, 35, 198, 173, 35, 198, 174, 1]
+        self.assertEqual(encoded["input_ids"], encoded_ids)
+
+        # decoding
+        decoded = tokenizer.decode(encoded_ids)
+        self.assertEqual(decoded, "e è é ê ë</s>")
+
+        # encode/decode, but with `encode` instead of `__call__`
+        self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "e è é ê ë</s>")
+
     def test_prepare_batch_integration(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
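The hardcoded id lists in `test_multibytes_char` are not magic numbers: each id is the UTF-8 byte value plus the offset of 3 for the `<pad>`/`</s>`/`<unk>` tokens, with `</s>` (id 1) appended at the end. A quick re-derivation, assuming only that offset:

```python
# Re-derive the expected ids from raw UTF-8 bytes: id = byte + 3, then eos = 1.
def expected_ids(text: str) -> list:
    return [b + 3 for b in text.encode("utf-8")] + [1]

assert expected_ids("Unicode €.") == [88, 113, 108, 102, 114, 103, 104, 35, 229, 133, 175, 49, 1]
assert expected_ids("e è é ê ë") == [104, 35, 198, 171, 35, 198, 172, 35, 198, 173, 35, 198, 174, 1]
```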