Unverified Commit 315f464b authored by Thomas Wolf, committed by GitHub

[tokenizers] Several small improvements and bug fixes (#5287)

* avoid recursion in id checks for fast tokenizers

* better typings and fix #5232

* align slow and fast tokenizers behaviors for Roberta and GPT2

* style and quality

* fix tests - improve typings
parent 24f46ea3
@@ -23,7 +23,7 @@ from functools import lru_cache
import regex as re
from tokenizers import ByteLevelBPETokenizer
-from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils import AddedToken, PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding
from .tokenization_utils_fast import PreTrainedTokenizerFast
@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
add_prefix_space=False,
**kwargs
):
+bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
with open(vocab_file, encoding="utf-8") as vocab_handle:
......
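A minimal sketch (not part of the diff) of the wrapping the GPT2Tokenizer constructor above now applies; the lstrip and rstrip flags control whether whitespace adjacent to the special token is stripped into it when the token is matched in text.

from transformers.tokenization_utils import AddedToken

eos_token = "<|endoftext|>"
# Same pattern as the constructor above: wrap plain strings, leave AddedToken instances untouched.
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token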
@@ -21,7 +21,7 @@ from typing import List, Optional
from tokenizers.processors import RobertaProcessing
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
-from .tokenization_utils import AddedToken, PreTrainedTokenizer
+from .tokenization_utils import AddedToken
logger = logging.getLogger(__name__)
@@ -137,6 +137,16 @@ class RobertaTokenizer(GPT2Tokenizer):
add_prefix_space=False,
**kwargs
):
+bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+# Mask token behave like a normal word, i.e. include the space before it
+mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
@@ -152,13 +162,6 @@ class RobertaTokenizer(GPT2Tokenizer):
**kwargs,
)
-@PreTrainedTokenizer.mask_token.setter
-def mask_token(self, value):
-if not isinstance(value, AddedToken):
-value = AddedToken(value, lstrip=True)
-self._mask_token = value
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
@@ -309,6 +312,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
trim_offsets=True,
**kwargs
):
+# Mask token behave like a normal word, i.e. include the space before it
+mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
kwargs.setdefault("pad_token", pad_token)
kwargs.setdefault("sep_token", sep_token)
kwargs.setdefault("cls_token", cls_token)
@@ -325,6 +331,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
**kwargs,
)
+# This will add the necessary special tokens to the vocabulary if needed
+self.sanitize_special_tokens()
self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing(
sep=(sep_token, self.sep_token_id),
cls=(cls_token, self.cls_token_id),
@@ -332,15 +341,6 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
trim_offsets=trim_offsets,
)
-self.sanitize_special_tokens() # This will add the necessary special tokens to the vocabulary if needed.
-@PreTrainedTokenizer.mask_token.setter
-def mask_token(self, value):
-if not isinstance(value, AddedToken):
-value = AddedToken(value, lstrip=True)
-self._mask_token = value
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
if token_ids_1 is None:
......
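An illustrative sketch (not part of the diff, assuming the roberta-base checkpoint can be loaded) of the behavior the changes above align between the slow and fast tokenizers: the mask token is an AddedToken with lstrip=True, so the space before "<mask>" is absorbed into the mask token rather than emitted as a separate token.

from transformers import RobertaTokenizer, RobertaTokenizerFast

slow = RobertaTokenizer.from_pretrained("roberta-base")
fast = RobertaTokenizerFast.from_pretrained("roberta-base")
text = "Encode <mask> sequence"
# Both tokenizers should now place "<mask>" right after the "Encode" tokens, with no stray space token.
print(slow.tokenize(text))
print(fast.tokenize(text))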
@@ -607,7 +607,7 @@ class SpecialTokensMixin:
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
)
-def sanitize_special_tokens(self):
+def sanitize_special_tokens(self) -> int:
""" Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
are in the vocabulary. Add the missing ones to the vocabulary if needed.
@@ -616,7 +616,7 @@ class SpecialTokensMixin:
"""
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
-def add_special_tokens(self, special_tokens_dict):
+def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
"""
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If special tokens are NOT in the vocabulary, they are added
@@ -665,10 +665,14 @@ class SpecialTokensMixin:
setattr(self, key, value)
if key == "additional_special_tokens":
-assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+assert isinstance(value, (list, tuple)) and all(
+isinstance(t, (str, AddedToken)) for t in value
+), f"Tokens {value} for key {key} should all be str or AddedToken instances"
added_tokens += self.add_tokens(value, special_tokens=True)
else:
-assert isinstance(value, str)
+assert isinstance(
+value, (str, AddedToken)
+), f"Token {value} for key {key} should be a str or an AddedToken instance"
added_tokens += self.add_tokens([value], special_tokens=True)
return added_tokens
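A small usage sketch (not part of the diff, assuming the gpt2 checkpoint can be loaded) of the relaxed assertions above: add_special_tokens now accepts AddedToken values as well as plain strings and returns the number of tokens actually added to the vocabulary.

from transformers import GPT2Tokenizer
from transformers.tokenization_utils import AddedToken

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# "<pad>" is not in the GPT-2 vocabulary, so one token is added and linked to pad_token.
num_added = tokenizer.add_special_tokens({"pad_token": AddedToken("<pad>", lstrip=False, rstrip=False)})
print(num_added)            # 1
print(tokenizer.pad_token)  # <pad>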
@@ -809,26 +813,36 @@ class SpecialTokensMixin:
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
+if self._bos_token is None:
+return None
return self.convert_tokens_to_ids(self.bos_token)
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
+if self._eos_token is None:
+return None
return self.convert_tokens_to_ids(self.eos_token)
@property
def unk_token_id(self):
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
+if self._unk_token is None:
+return None
return self.convert_tokens_to_ids(self.unk_token)
@property
def sep_token_id(self):
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
+if self._sep_token is None:
+return None
return self.convert_tokens_to_ids(self.sep_token)
@property
def pad_token_id(self):
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
+if self._pad_token is None:
+return None
return self.convert_tokens_to_ids(self.pad_token)
@property
@@ -839,11 +853,15 @@ class SpecialTokensMixin:
@property
def cls_token_id(self):
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+if self._cls_token is None:
+return None
return self.convert_tokens_to_ids(self.cls_token)
@property
def mask_token_id(self):
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+if self._mask_token is None:
+return None
return self.convert_tokens_to_ids(self.mask_token)
@property
......
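An illustrative sketch (not part of the diff, assuming the gpt2 checkpoint can be loaded) of the early-return checks added above: asking for the id of a special token that was never set now simply yields None instead of going through convert_tokens_to_ids.

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
print(tokenizer.pad_token_id)  # None, GPT-2 defines no pad token by default
print(tokenizer.eos_token_id)  # 50256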
@@ -185,7 +185,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return encoding_dict
-def convert_tokens_to_ids(self, tokens):
+def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
""" Converts a token string (or a sequence of tokens) in a single integer id
(or a sequence of ids), using the vocabulary.
"""
@@ -200,7 +200,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
ids.append(self._convert_token_to_id_with_added_voc(token))
return ids
-def _convert_token_to_id_with_added_voc(self, token: int) -> str:
+def _convert_token_to_id_with_added_voc(self, token: str) -> int:
index = self._tokenizer.token_to_id(token)
if index is None:
return self.unk_token_id
@@ -209,9 +209,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def _convert_id_to_token(self, index: int) -> Optional[str]:
return self._tokenizer.id_to_token(int(index))
-def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
-return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
if special_tokens:
return self._tokenizer.add_special_tokens(new_tokens)
@@ -223,7 +220,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def convert_ids_to_tokens(
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
-) -> Union[int, List[int]]:
+) -> Union[str, List[str]]:
""" Converts a single index or a sequence of indices (integers) in a token "
(resp.) a sequence of tokens (str), using the vocabulary and added tokens.
@@ -240,9 +237,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
tokens.append(self._tokenizer.id_to_token(index))
return tokens
-def tokenize(
-self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False
-) -> List[str]:
+def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False) -> List[str]:
return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens
def set_truncation_and_padding(
......
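A short round-trip sketch (not part of the diff, assuming the gpt2 checkpoint can be loaded) matching the corrected type hints above for the fast tokenizer: tokens map to ids and back.

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokens = tokenizer.tokenize("Hello world")              # List[str]
ids = tokenizer.convert_tokens_to_ids(tokens)           # List[str] -> List[int]
print(tokenizer.convert_ids_to_tokens(ids) == tokens)   # True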
@@ -54,9 +54,10 @@ class CommonFastTokenizerTest(unittest.TestCase):
if tok_case.filter is None or (
tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
):
+kwargs = dict(t for t in tok_case.kwargs) if tok_case.kwargs else {}
with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
-tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
-tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
+tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
self.fast_only(tokenizer_r)
@@ -767,7 +768,16 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
TOKENIZERS_CLASSES = frozenset(
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)]
[
Tokenizer(
"Roberta",
RobertaTokenizerFast,
RobertaTokenizer,
"vocab_file",
filter_roberta_detectors,
(("cls_token", "<s>"),),
)
]
)
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
......
@@ -18,7 +18,7 @@ import json
import os
import unittest
-from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer, RobertaTokenizerFast
+from transformers.tokenization_roberta import VOCAB_FILES_NAMES, AddedToken, RobertaTokenizer, RobertaTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
@@ -139,7 +139,9 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Testing spaces after special tokens
mask = "<mask>"
tokenizer.add_special_tokens({"mask_token": mask})
tokenizer.add_special_tokens(
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
) # mask token has a left space
mask_ind = tokenizer.convert_tokens_to_ids(mask)
sequence = "Encode <mask> sequence"
......
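A condensed sketch (not part of the diff, assuming the roberta-base checkpoint can be loaded) of what the updated test above exercises: the mask token is registered as an AddedToken with lstrip=True and its id appears exactly once in the encoded sequence.

from transformers import RobertaTokenizer
from transformers.tokenization_utils import AddedToken

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokenizer.add_special_tokens({"mask_token": AddedToken("<mask>", lstrip=True, rstrip=False)})
mask_ind = tokenizer.convert_tokens_to_ids("<mask>")
encoded = tokenizer.encode("Encode <mask> sequence", add_special_tokens=False)
assert encoded.count(mask_ind) == 1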