Unverified Commit 315f464b authored by Thomas Wolf, committed by GitHub

[tokenizers] Several small improvements and bug fixes (#5287)

* avoid recursion in id checks for fast tokenizers

* better typings and fix #5232

* align slow and fast tokenizers behaviors for Roberta and GPT2

* style and quality

* fix tests - improve typings
parent 24f46ea3
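
The main user-visible change (the fix for #5232) is that the slow and fast GPT2/Roberta tokenizers now handle special tokens, in particular Roberta's `<mask>`, the same way. A rough sketch of the intended behavior follows; the checkpoint name and sample text are illustrative only, not taken from the commit:

from transformers import RobertaTokenizer, RobertaTokenizerFast

slow = RobertaTokenizer.from_pretrained("roberta-base")
fast = RobertaTokenizerFast.from_pretrained("roberta-base")

text = "Encode <mask> sequence"
# After this commit both tokenizers register "<mask>" with lstrip=True, so the
# space before the mask is absorbed by the mask token itself and the two
# encodings are expected to match.
print(slow.encode(text))
print(fast.encode(text))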
src/transformers/tokenization_gpt2.py

@@ -23,7 +23,7 @@ from functools import lru_cache
 import regex as re
 from tokenizers import ByteLevelBPETokenizer

-from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils import AddedToken, PreTrainedTokenizer
 from .tokenization_utils_base import BatchEncoding
 from .tokenization_utils_fast import PreTrainedTokenizerFast
@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         add_prefix_space=False,
         **kwargs
     ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
         super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)

         with open(vocab_file, encoding="utf-8") as vocab_handle:
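
For GPT2Tokenizer, string special tokens are now normalized to AddedToken objects with explicit stripping behavior. A minimal sketch of what this allows, assuming the "gpt2" checkpoint purely for illustration:

from tokenizers import AddedToken
from transformers import GPT2Tokenizer

# Passing a plain string or an explicit AddedToken now ends up in the same place:
# the token is stored as an AddedToken with lstrip=False, rstrip=False.
tok_default = GPT2Tokenizer.from_pretrained("gpt2")
tok_explicit = GPT2Tokenizer.from_pretrained(
    "gpt2", eos_token=AddedToken("<|endoftext|>", lstrip=False, rstrip=False)
)
print(tok_default.eos_token)   # "<|endoftext|>"
print(tok_explicit.eos_token)  # same string, obtained via str() on the AddedToken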
src/transformers/tokenization_roberta.py

@@ -21,7 +21,7 @@ from typing import List, Optional
 from tokenizers.processors import RobertaProcessing

 from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
-from .tokenization_utils import AddedToken, PreTrainedTokenizer
+from .tokenization_utils import AddedToken

 logger = logging.getLogger(__name__)
@@ -137,6 +137,16 @@ class RobertaTokenizer(GPT2Tokenizer):
         add_prefix_space=False,
         **kwargs
     ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+        # The mask token behaves like a normal word, i.e. it includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file=vocab_file,
             merges_file=merges_file,
@@ -152,13 +162,6 @@ class RobertaTokenizer(GPT2Tokenizer):
             **kwargs,
         )

-    @PreTrainedTokenizer.mask_token.setter
-    def mask_token(self, value):
-        if not isinstance(value, AddedToken):
-            value = AddedToken(value, lstrip=True)
-        self._mask_token = value
-
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
@@ -309,6 +312,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
         trim_offsets=True,
         **kwargs
     ):
+        # The mask token behaves like a normal word, i.e. it includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         kwargs.setdefault("pad_token", pad_token)
         kwargs.setdefault("sep_token", sep_token)
         kwargs.setdefault("cls_token", cls_token)
@@ -325,6 +331,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
             **kwargs,
         )

+        # This will add the necessary special tokens to the vocabulary if needed
+        self.sanitize_special_tokens()
+
         self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing(
             sep=(sep_token, self.sep_token_id),
             cls=(cls_token, self.cls_token_id),
@@ -332,15 +341,6 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
             trim_offsets=trim_offsets,
         )

-        self.sanitize_special_tokens()  # This will add the necessary special tokens to the vocabulary if needed.
-
-    @PreTrainedTokenizer.mask_token.setter
-    def mask_token(self, value):
-        if not isinstance(value, AddedToken):
-            value = AddedToken(value, lstrip=True)
-        self._mask_token = value
-
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
         if token_ids_1 is None:
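
The mask_token setters that used to force lstrip=True on every assignment are removed; the behavior is now decided once in __init__, and a caller who wants something different can pass an AddedToken explicitly. A hedged sketch, with the checkpoint name used only for illustration:

from tokenizers import AddedToken
from transformers import RobertaTokenizer

# Equivalent to the new default: the mask absorbs the space to its left
# (lstrip=True) but not to its right (rstrip=False).
tok = RobertaTokenizer.from_pretrained(
    "roberta-base",
    mask_token=AddedToken("<mask>", lstrip=True, rstrip=False),
)
print(tok.mask_token)  # "<mask>"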
src/transformers/tokenization_utils_base.py

@@ -607,7 +607,7 @@ class SpecialTokensMixin:
                     "special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
                 )

-    def sanitize_special_tokens(self):
+    def sanitize_special_tokens(self) -> int:
         """ Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
             are in the vocabulary. Add the missing ones to the vocabulary if needed.
@@ -616,7 +616,7 @@ class SpecialTokensMixin:
         """
         return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)

-    def add_special_tokens(self, special_tokens_dict):
+    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
         to class attributes. If special tokens are NOT in the vocabulary, they are added
@@ -665,10 +665,14 @@ class SpecialTokensMixin:
             setattr(self, key, value)

             if key == "additional_special_tokens":
-                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+                assert isinstance(value, (list, tuple)) and all(
+                    isinstance(t, (str, AddedToken)) for t in value
+                ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
                 added_tokens += self.add_tokens(value, special_tokens=True)
             else:
-                assert isinstance(value, str)
+                assert isinstance(
+                    value, (str, AddedToken)
+                ), f"Token {value} for key {key} should be a str or an AddedToken instance"
                 added_tokens += self.add_tokens([value], special_tokens=True)

         return added_tokens
@@ -809,26 +813,36 @@ class SpecialTokensMixin:
     @property
     def bos_token_id(self):
         """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
+        if self._bos_token is None:
+            return None
         return self.convert_tokens_to_ids(self.bos_token)

     @property
     def eos_token_id(self):
         """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
+        if self._eos_token is None:
+            return None
         return self.convert_tokens_to_ids(self.eos_token)

     @property
     def unk_token_id(self):
         """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
+        if self._unk_token is None:
+            return None
         return self.convert_tokens_to_ids(self.unk_token)

     @property
     def sep_token_id(self):
         """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
+        if self._sep_token is None:
+            return None
         return self.convert_tokens_to_ids(self.sep_token)

     @property
     def pad_token_id(self):
         """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
+        if self._pad_token is None:
+            return None
         return self.convert_tokens_to_ids(self.pad_token)

     @property
@@ -839,11 +853,15 @@ class SpecialTokensMixin:
     @property
     def cls_token_id(self):
         """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+        if self._cls_token is None:
+            return None
         return self.convert_tokens_to_ids(self.cls_token)

     @property
     def mask_token_id(self):
         """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+        if self._mask_token is None:
+            return None
         return self.convert_tokens_to_ids(self.mask_token)

     @property
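
The *_token_id properties now short-circuit to None when the corresponding special token was never set, instead of attempting to convert a missing token, and add_special_tokens / sanitize_special_tokens report how many tokens were actually added. A small sketch using GPT-2, which ships without a padding token; the checkpoint name is for illustration:

from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")

print(tok.pad_token_id)  # None: no pad token is defined, so no conversion is attempted

# add_special_tokens returns the number of tokens actually added to the vocabulary.
num_added = tok.add_special_tokens({"pad_token": "[PAD]"})
print(num_added)         # expected 1, since "[PAD]" is not in the gpt2 vocabulary
print(tok.pad_token_id)  # now an integer id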
src/transformers/tokenization_utils_fast.py

@@ -185,7 +185,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):

         return encoding_dict

-    def convert_tokens_to_ids(self, tokens):
+    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
         """ Converts a token string (or a sequence of tokens) in a single integer id
             (or a sequence of ids), using the vocabulary.
         """
@@ -200,7 +200,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
             ids.append(self._convert_token_to_id_with_added_voc(token))
         return ids

-    def _convert_token_to_id_with_added_voc(self, token: int) -> str:
+    def _convert_token_to_id_with_added_voc(self, token: str) -> int:
         index = self._tokenizer.token_to_id(token)
         if index is None:
             return self.unk_token_id
@@ -209,9 +209,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def _convert_id_to_token(self, index: int) -> Optional[str]:
         return self._tokenizer.id_to_token(int(index))

-    def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
-        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
-
     def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
         if special_tokens:
             return self._tokenizer.add_special_tokens(new_tokens)
@@ -223,7 +220,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def convert_ids_to_tokens(
         self, ids: Union[int, List[int]], skip_special_tokens: bool = False
-    ) -> Union[int, List[int]]:
+    ) -> Union[str, List[str]]:
         """ Converts a single index or a sequence of indices (integers) in a token "
             (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
@@ -240,9 +237,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
             tokens.append(self._tokenizer.id_to_token(index))
         return tokens

-    def tokenize(
-        self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False
-    ) -> List[str]:
+    def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False) -> List[str]:
         return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens

     def set_truncation_and_padding(
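
The fast-tokenizer annotations now reflect what these methods actually take and return: convert_tokens_to_ids maps strings to ids, convert_ids_to_tokens maps ids back to strings, and tokenize takes plain str inputs. A brief round-trip sketch; the checkpoint name and sample text are for illustration:

from transformers import GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")

tokens = tok.tokenize("Hello world")        # List[str]
ids = tok.convert_tokens_to_ids(tokens)     # List[int]
back = tok.convert_ids_to_tokens(ids)       # List[str]
print(back == tokens)                       # expected True for a lossless round trip

# A single str maps to a single int, per the corrected signatures.
print(tok.convert_tokens_to_ids(tok.unk_token))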
tests/test_tokenization_fast.py

@@ -54,9 +54,10 @@ class CommonFastTokenizerTest(unittest.TestCase):
             if tok_case.filter is None or (
                 tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
             ):
+                kwargs = dict(t for t in tok_case.kwargs) if tok_case.kwargs else {}
                 with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
-                    tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
-                    tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
+                    tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+                    tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)

                     self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
                     self.fast_only(tokenizer_r)
@@ -767,7 +768,16 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):

 class RobertaFastTokenizerTest(CommonFastTokenizerTest):
     TOKENIZERS_CLASSES = frozenset(
-        [Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)]
+        [
+            Tokenizer(
+                "Roberta",
+                RobertaTokenizerFast,
+                RobertaTokenizer,
+                "vocab_file",
+                filter_roberta_detectors,
+                (("cls_token", "<s>"),),
+            )
+        ]
     )

     def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
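
The Tokenizer test tuple can now carry extra kwargs (here (("cls_token", "<s>"),)) that are converted to a dict and forwarded to from_pretrained. Outside the test harness this is roughly equivalent to the following sketch; the checkpoint name is for illustration:

from transformers import RobertaTokenizerFast

# Extra keyword arguments to from_pretrained override the tokenizer's special tokens.
tok = RobertaTokenizerFast.from_pretrained("roberta-base", cls_token="<s>")
print(tok.cls_token)  # "<s>"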
tests/test_tokenization_roberta.py

@@ -18,7 +18,7 @@ import json
 import os
 import unittest

-from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer, RobertaTokenizerFast
+from transformers.tokenization_roberta import VOCAB_FILES_NAMES, AddedToken, RobertaTokenizer, RobertaTokenizerFast

 from .test_tokenization_common import TokenizerTesterMixin
 from .utils import slow
@@ -139,7 +139,9 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # Testing spaces after special tokens
         mask = "<mask>"
-        tokenizer.add_special_tokens({"mask_token": mask})
+        tokenizer.add_special_tokens(
+            {"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
+        )  # mask token has a left space
         mask_ind = tokenizer.convert_tokens_to_ids(mask)
         sequence = "Encode <mask> sequence"
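
The updated test registers the mask as an AddedToken with explicit left-space behavior and then looks it up by its plain string form. Roughly the same pattern outside the test, with the checkpoint name used only for illustration:

from tokenizers import AddedToken
from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("roberta-base")

# The stored special token is an AddedToken, but convert_tokens_to_ids still
# accepts the plain string content.
tok.add_special_tokens({"mask_token": AddedToken("<mask>", lstrip=True, rstrip=False)})
print(tok.convert_tokens_to_ids("<mask>") == tok.mask_token_id)  # expected True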