Unverified Commit 5c7789d4 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing by correctly raising UnicodeDecodeError. (#13449)

parent 79815090
......@@ -237,14 +237,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
else:
tok_string = bytes([ord(token)])
bstring += tok_string
# XXX: This is most likely incorrect, we want utf-8 errors
# to be triggered. However transformers test suite will
# try to decode every ID within the tokenizer on their own
# meaning it will attempt to try and decode invalid utf-8.
# Ignoring errors means passing tests, meanwhile correctly
# raising the errors means editing the automated tests to
# support that behavior (decoding an arbitrary ID might be invalid).
string = bstring.decode("utf-8", errors="ignore")
string = bstring.decode("utf-8")
return string
# ByT5Tokenizer has no vocab file
......
......@@ -15,9 +15,11 @@
import json
import os
import re
import shutil
import tempfile
import unittest
from typing import Tuple
from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
......@@ -50,6 +52,44 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs) -> ByT5Tokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
# XXX The default common tokenizer tests assume that every ID is decodable on its own.
# This assumption is invalid for ByT5 because single bytes might not be
# valid utf-8 (byte 128 for instance).
# Here we're overriding the smallest possible method to provide
# a clean sequence without making the same assumption.
toks = []
for i in range(len(tokenizer)):
try:
tok = tokenizer.decode([i], clean_up_tokenization_spaces=False)
except UnicodeDecodeError:
pass
toks.append((i, tok))
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
if max_length is not None and len(toks) > max_length:
toks = toks[:max_length]
if min_length is not None and len(toks) < min_length and len(toks) > 0:
while len(toks) < min_length:
toks = toks + toks
# toks_str = [t[1] for t in toks]
toks_ids = [t[0] for t in toks]
# Ensure consistency
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
if " " not in output_txt and len(toks_ids) > 1:
output_txt = (
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+ " "
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
)
if with_prefix_space:
output_txt = " " + output_txt
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
return output_txt, output_ids
def test_eos_treatment(self):
tokenizer = self.t5_base_tokenizer
batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment