Unverified Commit 5c7789d4 authored by Nicolas Patry, committed by GitHub

Fixing by correctly raising UnicodeDecodeError. (#13449)

parent 79815090
@@ -237,14 +237,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
             else:
                 tok_string = bytes([ord(token)])
             bstring += tok_string
-        # XXX: This is most likely incorrect, we want utf-8 errors
-        # to be triggered. However transformers test suite will
-        # try to decode every ID within the tokenizer on their own
-        # meaning it will attempt to try and decode invalid utf-8.
-        # Ignoring errors means passing tests, meanwhile correctly
-        # raising the errors means editing the automated tests to
-        # support that behavior (decoding an arbitrary ID might be invalid).
-        string = bstring.decode("utf-8", errors="ignore")
+        string = bstring.decode("utf-8")
         return string
     # ByT5Tokenizer has no vocab file
...
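
A minimal sketch of the behaviour change above, assuming a default-constructed ByT5Tokenizer (it needs no vocab file) and its usual three special-token offset, under which token ID 131 maps to the raw byte 0x80; with errors="ignore" that byte was silently dropped, whereas decode now raises:

from transformers import ByT5Tokenizer

tokenizer = ByT5Tokenizer()

# Valid utf-8 round-trips are unaffected by the change.
ids = tokenizer.encode("hi", add_special_tokens=False)
assert tokenizer.decode(ids) == "hi"

# A lone continuation byte is not valid utf-8 on its own: previously this
# returned "" silently, now the UnicodeDecodeError reaches the caller.
try:
    tokenizer.decode([131], clean_up_tokenization_spaces=False)  # assumed ID for byte 0x80
except UnicodeDecodeError as exc:
    print(f"invalid utf-8 correctly raised: {exc}")
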
@@ -15,9 +15,11 @@
 import json
 import os
+import re
 import shutil
 import tempfile
 import unittest
+from typing import Tuple
 
 from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
 from transformers.file_utils import cached_property, is_tf_available, is_torch_available
@@ -50,6 +52,44 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def get_tokenizer(self, **kwargs) -> ByT5Tokenizer:
         return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
 
+    def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
+        # XXX The default common tokenizer tests assume that every ID is decodable on its own.
+        # This assumption is invalid for ByT5 because single bytes might not be
+        # valid utf-8 (byte 128 for instance).
+        # Here we're overriding the smallest possible method to provide
+        # a clean sequence without making the same assumption.
+        toks = []
+        for i in range(len(tokenizer)):
+            try:
+                tok = tokenizer.decode([i], clean_up_tokenization_spaces=False)
+            except UnicodeDecodeError:
+                # Skip IDs that are not valid utf-8 on their own, otherwise the
+                # previous iteration's token would be appended with the wrong ID.
+                continue
+            toks.append((i, tok))
+        toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
+        toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
+        if max_length is not None and len(toks) > max_length:
+            toks = toks[:max_length]
+        if min_length is not None and len(toks) < min_length and len(toks) > 0:
+            while len(toks) < min_length:
+                toks = toks + toks
+        # toks_str = [t[1] for t in toks]
+        toks_ids = [t[0] for t in toks]
+
+        # Ensure consistency
+        output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
+        if " " not in output_txt and len(toks_ids) > 1:
+            output_txt = (
+                tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+                + " "
+                + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
+            )
+        if with_prefix_space:
+            output_txt = " " + output_txt
+        output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
+        return output_txt, output_ids
 
     def test_eos_treatment(self):
         tokenizer = self.t5_base_tokenizer
         batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
...
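
The byte-128 case called out in the test comment can be reproduced with nothing but the standard library; a small sketch of why a single byte ID may be undecodable even though full sequences decode fine:

# 0x80 is a utf-8 continuation byte: in isolation it cannot start a character,
# so decoding it alone fails ...
try:
    bytes([128]).decode("utf-8")
except UnicodeDecodeError as exc:
    print(f"lone byte 128 is invalid utf-8: {exc}")

# ... while the very same byte is valid inside a complete multi-byte sequence,
# e.g. "À" (U+00C0) encodes to the two bytes 0xC3 0x80.
assert "À".encode("utf-8") == b"\xc3\x80"
assert b"\xc3\x80".decode("utf-8") == "À"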