Unverified Commit 9f9ddcc2 authored by Ben Eyal, committed by GitHub

🚨 🚨 🚨 Fix Issue 15003: SentencePiece Tokenizers Not Adding Special Tokens in `convert_tokens_to_string` (#15775)

* Add test for SentencePiece not adding special tokens to strings

* Add SentencePieceStringConversionMixin to fix issue 15003

* Fix conversion from tokens to string for most SentencePiece tokenizers

Tokenizers fixed:
- AlbertTokenizer
- BarthezTokenizer
- CamembertTokenizer
- FNetTokenizer
- M2M100Tokenizer
- MBart50Tokenizer
- PegasusTokenizer
- Speech2TextTokenizer

* Fix MarianTokenizer, adjust SentencePiece test to accommodate vocab

* Fix DebertaV2Tokenizer

* Ignore LayoutXLMTokenizer in SentencePiece string conversion test

* Run 'make style' and 'make quality'

* Clean convert_tokens_to_string test

Instead of explicitly ignoring LayoutXLMTokenizer in the test,
override the test in LayoutXLMTokenizationTest and do nothing in it.

* Remove commented out code

* Improve robustness of convert_tokens_to_string test

Instead of comparing lengths of re-tokenized text and input_ids,
check that converting all special tokens to string yields a string
with all special tokens.

* Inline and remove SentencePieceStringConversionMixin

The convert_tokens_to_string method is now implemented
in each relevant SentencePiece tokenizer; the shared pattern
is sketched right after the commit notes below.

* Run 'make style' and 'make quality'

* Revert removal of space in convert_tokens_to_string

* Remove redundant import

* Revert test text to original

* Uncomment the lowercasing of the reverse_text variable

* Mimic Rust tokenizer behavior for tokenizers

- Albert
- Barthez
- Camembert
- MBart50
- T5

* Fix accidentally skipping test in wrong tokenizer

* Add test for equivalent Rust and slow tokenizer behavior

* Override _decode in BigBirdTokenizer to mimic Rust behavior

* Override _decode in FNetTokenizer to mimic Rust behavior

* Override _decode in XLNetTokenizer to mimic Rust behavior

* Remove unused 're' import

* Update DebertaV2Tokenizer to mimic Rust tokenizer

* Deberta tokenizer now behaves like Albert and its `convert_tokens_to_string` is not tested.

* Ignore problematic tests in Deberta V2

* Add comment on why the Deberta V2 tests are skipped
parent fb7cbe23
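For readers skimming the diffs below: every slow tokenizer touched by this commit gets the same special-token-aware `convert_tokens_to_string`. The following is a minimal, self-contained sketch of that shared pattern; the `decode_pieces` stub, the `SPIECE_UNDERLINE` constant, and the sample tokens are illustrative stand-ins for a real SentencePiece model, not code from this commit.

from typing import List, Set

SPIECE_UNDERLINE = "▁"


def decode_pieces(pieces: List[str]) -> str:
    # Stand-in for sp_model.decode(...): join the pieces and turn the SentencePiece
    # word-boundary marker back into spaces.
    return "".join(pieces).replace(SPIECE_UNDERLINE, " ").strip()


def convert_tokens_to_string(tokens: List[str], special_tokens: Set[str]) -> str:
    # Decode sub-word pieces with SentencePiece, but pass special tokens through
    # verbatim, since the SentencePiece model does not know them and would drop
    # or mangle them.
    current_sub_tokens: List[str] = []
    out_string = ""
    prev_is_special = False
    for token in tokens:
        if token in special_tokens:
            if not prev_is_special:
                out_string += " "
            out_string += decode_pieces(current_sub_tokens) + token
            prev_is_special = True
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
            prev_is_special = False
    out_string += decode_pieces(current_sub_tokens)
    return out_string.strip()


print(convert_tokens_to_string(["[CLS]", "▁this", "▁is", "▁a", "▁test", "[SEP]"], {"[CLS]", "[SEP]"}))
# -> "[CLS] this is a test[SEP]"  (the special tokens survive decoding)

The `prev_is_special` flag mirrors the Rust tokenizers' habit of putting a single space before a run of special tokens and none after; the tokenizers that do not need this (BertGeneration, Pegasus, Reformer, M2M100) use the simpler variant without the flag.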
@@ -250,7 +250,23 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return self.sp_model.IdToPiece(index)
 
     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
...
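As a quick usage sketch (not part of the diff), the effect of the Albert change can be checked by comparing the slow and fast tokenizers; this assumes the public `albert-base-v2` checkpoint is reachable.

from transformers import AlbertTokenizer, AlbertTokenizerFast

slow = AlbertTokenizer.from_pretrained("albert-base-v2")  # uses the patched slow path
fast = AlbertTokenizerFast.from_pretrained("albert-base-v2")

ids = slow("This is text to test the tokenizer.").input_ids
print(slow.decode(ids))
print(fast.decode(ids))
# After the fix, both decoded strings contain the [CLS] and [SEP] special tokens,
# which is what the new test_sentencepiece_tokenize_and_decode further below asserts.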
@@ -263,6 +263,25 @@ class BarthezTokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index)
 
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -278,10 +297,6 @@ class BarthezTokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
...
@@ -151,8 +151,17 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
...
@@ -16,6 +16,7 @@
 import os
+import re
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -182,8 +183,65 @@ class BigBirdTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space before [MASK] and [SEP]
+        if spaces_between_special_tokens:
+            text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
...
@@ -261,6 +261,25 @@ class CamembertTokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
 
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
@@ -276,10 +295,6 @@ class CamembertTokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
...
@@ -146,7 +146,9 @@ class DebertaV2Tokenizer(PreTrainedTokenizer):
         self.do_lower_case = do_lower_case
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
-        self._tokenizer = SPMTokenizer(vocab_file, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs)
+        self._tokenizer = SPMTokenizer(
+            vocab_file, self.all_special_tokens, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs
+        )
 
     @property
     def vocab_size(self):
@@ -291,7 +293,9 @@ class SPMTokenizer:
            BPE-dropout.
    """
 
-    def __init__(self, vocab_file, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None
+    ):
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
@@ -312,6 +316,7 @@ class SPMTokenizer:
         # self.vocab['[UNK]'] = 3
 
         self.spm = spm
+        self.special_tokens = special_tokens
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -339,7 +344,22 @@ class SPMTokenizer:
     def decode(self, tokens, start=-1, end=-1, raw_text=None):
         if raw_text is None:
-            return self.spm.decode_pieces([t for t in tokens])
+            current_sub_tokens = []
+            out_string = ""
+            prev_is_special = False
+            for token in tokens:
+                # make sure that special tokens are not decoded using sentencepiece model
+                if token in self.special_tokens:
+                    if not prev_is_special:
+                        out_string += " "
+                    out_string += self.spm.decode_pieces(current_sub_tokens) + token
+                    prev_is_special = True
+                    current_sub_tokens = []
+                else:
+                    current_sub_tokens.append(token)
+                    prev_is_special = False
+            out_string += self.spm.decode_pieces(current_sub_tokens)
+            return out_string.strip()
         else:
             words = self.split_to_words(raw_text)
             word_tokens = [self.tokenize(w) for w in words]
...
@@ -15,6 +15,7 @@
 """ Tokenization classes for FNet model."""
 import os
+import re
 import unicodedata
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -213,7 +214,66 @@ class FNetTokenizer(PreTrainedTokenizer):
         return self.sp_model.IdToPiece(index)
 
     def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # No space after <unk>
+        if spaces_between_special_tokens:
+            text = re.sub(r"(<unk>) ", r"\1", " ".join(sub_texts))
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
...
@@ -218,9 +218,19 @@ class M2M100Tokenizer(PreTrainedTokenizer):
             return self.id_to_lang_token[index]
         return self.decoder.get(index, self.unk_token)
 
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
...
@@ -265,10 +265,18 @@ class MarianTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
-        if self._decode_use_source_tokenizer:
-            return self.spm_source.DecodePieces(tokens)
-        else:
-            return self.spm_target.DecodePieces(tokens)
+        sp_model = self.spm_source if self._decode_use_source_tokenizer else self.spm_target
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += sp_model.decode_pieces(current_sub_tokens) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += sp_model.decode_pieces(current_sub_tokens)
+        return out_string.strip()
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
...
@@ -232,9 +232,24 @@ class MBart50Tokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
 
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        return self.sp_model.decode(tokens)
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
...
@@ -231,8 +231,17 @@ class PegasusTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def num_special_tokens_to_add(self, pair=False):
         """Just EOS"""
...
@@ -158,8 +158,17 @@ class ReformerTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
...
@@ -190,11 +190,19 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Converts a sequence of tokens (strings for sub-words) in a single string."""
-        out_string = self.sp_model.decode(tokens)
-
-        if self.do_upper_case:
-            out_string = out_string.upper()
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                decoded = self.sp_model.decode(current_sub_tokens)
+                out_string += (decoded.upper() if self.do_upper_case else decoded) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        decoded = self.sp_model.decode(current_sub_tokens)
+        out_string += decoded.upper() if self.do_upper_case else decoded
+        return out_string.strip()
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
...
@@ -311,14 +311,19 @@ class T5Tokenizer(PreTrainedTokenizer):
         """Converts a sequence of tokens (string) in a single string."""
         current_sub_tokens = []
         out_string = ""
+        prev_is_special = False
         for token in tokens:
             # make sure that special tokens are not decoded using sentencepiece model
             if token in self.all_special_tokens:
-                out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " "
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
                 current_sub_tokens = []
             else:
                 current_sub_tokens.append(token)
-        out_string += self.sp_model.decode_pieces(current_sub_tokens)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
         return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
...
@@ -250,6 +250,46 @@ class XLNetTokenizer(PreTrainedTokenizer):
         out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
         return out_string
 
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        spaces_between_special_tokens: bool = True,
+        **kwargs
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        # To avoid mixing byte-level and unicode for byte-level BPT
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        # Mimic the behavior of the Rust tokenizer:
+        # By default, there are no spaces between special tokens
+        text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
...
@@ -37,7 +37,7 @@ class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         super().setUp()
 
         # We have a SentencePiece fixture for testing
-        tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB)
+        tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB, unk_token="<unk>")
         tokenizer.save_pretrained(self.tmpdirname)
 
     def get_input_output_texts(self, tokenizer):
@@ -55,7 +55,6 @@ class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_get_vocab(self):
         vocab_keys = list(self.get_tokenizer().get_vocab().keys())
 
         self.assertEqual(vocab_keys[0], "<pad>")
         self.assertEqual(vocab_keys[1], "<unk>")
         self.assertEqual(vocab_keys[-1], "[PAD]")
@@ -80,6 +79,14 @@ class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(rust_tokens, tokens_target)
 
+    @unittest.skip("There is an inconsistency between slow and fast tokenizer due to a bug in the fast one.")
+    def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
+        pass
+
+    @unittest.skip("There is an inconsistency between slow and fast tokenizer due to a bug in the fast one.")
+    def test_sentencepiece_tokenize_and_decode(self):
+        pass
+
     def test_split_by_punct(self):
         # fmt: off
         sequence = "I was born in 92000, and this is falsé."
...
@@ -1946,3 +1946,11 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @unittest.skip("Doesn't support another framework than PyTorch")
     def test_np_encode_plus_sent_to_model(self):
         pass
+
+    @unittest.skip("Doesn't use SentencePiece")
+    def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
+        pass
+
+    @unittest.skip("Doesn't use SentencePiece")
+    def test_sentencepiece_tokenize_and_decode(self):
+        pass
@@ -385,6 +385,33 @@ class TokenizerTesterMixin:
         self.assertEqual(reverse_text, text)
 
+        special_tokens = tokenizer.all_special_tokens
+        special_tokens_string = tokenizer.convert_tokens_to_string(special_tokens)
+        for special_token in special_tokens:
+            self.assertIn(special_token, special_tokens_string)
+
+        if self.test_rust_tokenizer:
+            rust_tokenizer = self.get_rust_tokenizer()
+            special_tokens_string_rust = rust_tokenizer.convert_tokens_to_string(special_tokens)
+            self.assertEqual(special_tokens_string, special_tokens_string_rust)
+
+    def test_sentencepiece_tokenize_and_decode(self):
+        if not self.test_sentencepiece:
+            return
+
+        text = "This is text to test the tokenizer."
+        if self.test_rust_tokenizer:
+            tokenizer = self.get_tokenizer()
+            rust_tokenizer = self.get_rust_tokenizer()
+
+            slow_ids = tokenizer(text).input_ids
+            fast_ids = rust_tokenizer(text).input_ids
+            self.assertEqual(slow_ids, fast_ids)
+
+            slow_decoded = tokenizer.decode(slow_ids)
+            fast_decoded = rust_tokenizer.decode(slow_ids)
+            self.assertEqual(slow_decoded, fast_decoded)
+
     def test_subword_regularization_tokenizer(self) -> None:
         if not self.test_sentencepiece:
             return
...