Unverified Commit 827d6d6e authored by Thomas Wolf, committed by GitHub

Cleanup fast tokenizers integration (#3706)



* First pass on utility classes and python tokenizers

* finishing cleanup pass

* style and quality

* Fix tests

* Updating following @mfuntowicz comment

* style and quality

* Fix Roberta

* fix batch_size/seq_length in BatchEncoding

* add alignment methods + tests

* Fix OpenAI and Transfo-XL tokenizers

* adding trim_offsets=True default for GPT2 and RoBERTa

* style and quality

* fix tests

* add_prefix_space in roberta

* bump up tokenizers to rc7

* style

* unfortunately tensorflow doesn't like these - removing shape/seq_len for now

* Update src/transformers/tokenization_utils.py
Co-Authored-By: Stefan Schweter <stefan@schweter.it>

* Adding doc and docstrings

* making flake8 happy
Co-authored-by: Stefan Schweter <stefan@schweter.it>
parent 60a42ef1
......@@ -118,12 +118,6 @@ class T5Tokenizer(PreTrainedTokenizer):
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
try:
import sentencepiece as spm
......
......@@ -101,13 +101,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
if never_split is None:
never_split = self.all_special_tokens
if special is None:
......@@ -410,6 +403,16 @@ class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "Fast" Transformer-XL tokenizer (backed by HuggingFace's `tokenizers` library).
The Transformer-XL tokenizer is a word-level tokenizer (no sub-word tokenization).
Adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
"""
vocab_files_names = VOCAB_FILES_NAMES_FAST
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST
......
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
"""Tokenization classes for python and fast tokenizers. Fast tokenizers are provided by HuggingFace's tokenizers library."""
import copy
import functools
......@@ -24,11 +24,12 @@ import os
import re
from collections import UserDict, defaultdict
from contextlib import contextmanager
from typing import List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from tokenizers import AddedToken, Encoding
from tokenizers.decoders import Decoder
from tokenizers.implementations import BaseTokenizer
from tokenizers import AddedToken as AddedTokenFast
from tokenizers import Encoding as EncodingFast
from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast
from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available
......@@ -44,12 +45,40 @@ SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input
LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
# Define type aliases
# Define type aliases and NamedTuples
TextInput = str
TextPairInput = Tuple[str, str]
PreTokenizedInput = List[str]
EncodedInput = List[int]
TextInputPair = Tuple[str, str]
PreTokenizedInputPair = Tuple[List[str], List[str]]
EncodedInputPair = Tuple[List[int], List[int]]
class CharSpan(NamedTuple):
""" Character span in the original string
Args:
start: index of the first character in the original string
end: index of the character following the last character in the original string
"""
start: int
end: int
class TokenSpan(NamedTuple):
""" Token span in an encoded string (list of tokens)
Args:
start: index of the first token in the span
end: index of the token following the last token in the span
"""
start: int
end: int
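# Editor's note: a minimal usage sketch (not part of this diff). CharSpan and
# TokenSpan are plain NamedTuples, so fields can be read by name or by unpacking:
#
#     span = CharSpan(start=0, end=5)
#     start, end = span                              # tuple unpacking
#     covered = "Hello world"[span.start:span.end]   # -> "Hello"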
def flatten(x: Sequence):
......@@ -68,7 +97,7 @@ def flatten(x: Sequence):
@contextmanager
def truncate_and_pad(
tokenizer: BaseTokenizer,
tokenizer: BaseTokenizerFast,
max_length: int,
stride: int,
strategy: str,
......@@ -78,26 +107,23 @@ def truncate_and_pad(
pad_token_type_id: int,
pad_token: str,
):
"""
This contextmanager is in charge of defining the truncation and the padding strategies and then
restoring the tokenizer settings afterwards.
This contextmanager assumes the provided tokenizer has no padding / truncation strategy
before the managed section. If your tokenizer sets a padding / truncation strategy before,
then it will be reset to no padding/truncation when exiting the managed section.
""" This contextmanager is in charge of defining the truncation and the padding strategies for fast tokenizers
(provided by HuggingFace tokenizers library) and restore the tokenizer settings afterwards.
Args:
tokenizer (BaseTokenizer): The tokenizer which will be used
max_length (int): The maximum size of the sequence
stride (int): The stride to use when handling overflow
strategy (str): Overflowing logic to use
pad_to_max_length (bool): Boolean indicating if the output needs to be padded up to max_length
padding_side (str): "left" or "right" indicating the direction the output sequence will be padded
pad_token_id (int): The integer representation of the padding token to use
pad_token_type_id (int): The integer representation of the padding token type to use
pad_token (str): The string representation of the padding token to use
This contextmanager assumes the provided tokenizer has no padding / truncation strategy
before the managed section. If your tokenizer sets a padding / truncation strategy before,
then it will be reset to no padding/truncation when exiting the managed section.
Returns:
Args:
tokenizer (BaseTokenizerFast): The tokenizer which will be used
max_length (int): The maximum size of the sequence
stride (int): The stride to use when handling overflow
strategy (str): Overflowing logic to use
pad_to_max_length (bool): Boolean indicating if the output needs to be padded up to max_length
padding_side (str): "left" or "right" indicating the direction the output sequence will be padded
pad_token_id (int): The integer representation of the padding token to use
pad_token_type_id (int): The integer representation of the padding token type to use
pad_token (str): The string representation of the padding token to use
"""
......@@ -124,6 +150,9 @@ def truncate_and_pad(
yield
# TODO(morgan, anthony): once we have a simple way to serialize tokenizers maybe store and restore the state afterward
# to avoid destructing the padding / truncation strategy as we do now.
if max_length is not None:
tokenizer.no_truncation()
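# Editor's note: an illustrative (hypothetical) call showing how this context manager
# is meant to wrap a fast backend tokenizer; the argument values below are made up:
#
#     with truncate_and_pad(
#         tokenizer=backend_tokenizer,   # a tokenizers.implementations.BaseTokenizer instance
#         max_length=128,
#         stride=0,
#         strategy="longest_first",
#         pad_to_max_length=True,
#         padding_side="right",
#         pad_token_id=0,
#         pad_token_type_id=0,
#         pad_token="[PAD]",
#     ):
#         encodings = backend_tokenizer.encode_batch(["some text", "some other text"])
#     # on exit, the backend is reset to "no truncation / no padding" as noted above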
......@@ -132,117 +161,311 @@ def truncate_and_pad(
class BatchEncoding(UserDict):
"""
Data structure derived from Dictionary holding all the required information to forward through
a model.
""" BatchEncoding hold the output of the encode and batch_encode methods (tokens, attention_masks, etc).
This class is derived from a python Dictionary and can be used as a dictionnary.
In addition, this class expose utility methods to map from word/char space to token space.
Args:
data (:obj:`dict`): Dictionary of lists/arrays returned by the encode/batch_encode methods ('input_ids', 'attention_mask'...)
encoding (:obj:`EncodingFast`, :obj:`list(EncodingFast)`, `optional`, defaults to :obj:`None`):
If the tokenizer is a fast tokenizer which outputs additional informations like mapping from word/char space to token space
the `EncodingFast` instance or list of instance (for batches) hold these informations.
In addition, this structure expose utility methods to map from word/char space to token space.
"""
def __init__(self, data: dict, encoding: Optional[Union[Encoding, Sequence[Encoding]]] = None):
def __init__(self, data: Dict[str, Any], encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None):
super().__init__(data)
if isinstance(encoding, Encoding):
if isinstance(encoding, EncodingFast):
encoding = [encoding]
self._encodings = encoding
def __getitem__(self, item: Union[int, str]) -> Encoding:
def __getitem__(self, item: Union[int, str]) -> EncodingFast:
""" If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
If the key is an integer, get the EncodingFast for batch item with index `key`
"""
if isinstance(item, str):
return self.data[item]
elif self._encodings is not None:
return self._encodings[item]
else:
raise KeyError("int index is supported only on {} from a Rust tokenizer".format(type(self).__name__))
raise KeyError(
"Indexing with integers (to access backend Encoding for a given batch index) "
"is not available when using Python based tokenizers"
)
def __getattr__(self, item: str):
return self.data[item]
def keys(self):
return self.data.keys()
def values(self):
return self.data.values()
def items(self):
return self.data.items()
# After this point:
# Extended properties and methods only available for fast (Rust-based) tokenizers
# provided by HuggingFace tokenizers library.
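# Editor's note: usage sketch (not part of this diff), assuming a fast tokenizer
# such as BertTokenizerFast. String keys behave like a regular dict; integer keys
# return the backend EncodingFast of that batch item (fast tokenizers only):
#
#     encoding = tokenizer.batch_encode_plus(["Hello world", "How are you?"])
#     encoding["input_ids"]     # lists of token ids, dict-style access
#     encoding.input_ids        # same value, attribute access via __getattr__
#     encoding[0]               # EncodingFast of the first sequence
#     encoding.tokens(0)        # token strings of the first sequence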
@property
def encodings(self) -> Optional[List[Encoding]]:
def encodings(self) -> Optional[List[EncodingFast]]:
"""
Return the list of all encodings from the tokenization process
Returns: List[Encoding] or None if input was tokenized through Python tokenizer
Returns: List[EncodingFast] or None if input was tokenized through Python (i.e. not fast) tokenizer
"""
return self._encodings
def keys(self):
return self.data.keys()
def tokens(self, batch_index: int = 0) -> List[str]:
if not self._encodings:
raise ValueError("tokens() is not available when using Python based tokenizers")
return self._encodings[batch_index].tokens
def values(self):
return self.data.values()
def words(self, batch_index: int = 0) -> List[Optional[int]]:
if not self._encodings:
raise ValueError("words() is not available when using Python based tokenizers")
return self._encodings[batch_index].words
def items(self):
return self.data.items()
def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
""" Get the index of the word corresponding (i.e. comprising) to an encoded token
in a sequence of the batch.
Can be called as:
- self.token_to_word(token_index) if batch size is 1
- self.token_to_word(batch_index, token_index) if batch size is greater than 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it easy
to associate encoded tokens with the provided tokenized words.
Args:
batch_or_token_index (:obj:`int`):
Index of the sequence in the batch. If the batch only comprises one sequence,
this can be the index of the token in the sequence
token_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_token_index`, this can be the index
of the token in the sequence.
Returns:
word_index (:obj:`int`):
index of the word in the input sequence.
def char_to_token_offsets(self, sentence: int, char: int) -> Tuple[int, int]:
"""
Find the Offsets of the token containing the character at the specified position
if not self._encodings:
raise ValueError("token_to_word() is not available when using Python based tokenizers")
if token_index is not None:
batch_index = batch_or_token_index
else:
batch_index = 0
token_index = batch_or_token_index
if batch_index < 0:
batch_index = self._batch_size + batch_index
if token_index < 0:
token_index = self._seq_len + token_index
return self._encodings[batch_index].token_to_word(token_index)
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan:
""" Get the encoded token span corresponding to a word in the sequence of the batch.
Token spans are returned as a TokenSpan NamedTuple with:
start: index of the first token
end: index of the token following the last token
Can be called as:
- self.word_to_tokens(word_index) if batch size is 1
- self.word_to_tokens(batch_index, word_index) if batch size is greater than or equal to 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it easy
to associate encoded tokens with the provided tokenized words.
Args:
sentence: Index of the sentence relative to the batch provided to the tokenizer
char: Char index to get the relative token offsets
batch_or_word_index (:obj:`int`):
Index of the sequence in the batch. If the batch only comprises one sequence,
this can be the index of the word in the sequence
word_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_word_index`, this can be the index
of the word in the sequence.
Returns:
tuple: (token start, token end)
token_span (:obj:`TokenSpan`):
Span of tokens in the encoded sequence.
TokenSpan are NamedTuple with:
start: index of the first token
end: index of the token following the last token
"""
if not self._encodings:
raise ValueError("char_to_token_offsets() is not available when using Python based tokenizers")
return self[sentence].char_to_token_offsets(char)
raise ValueError("word_to_tokens() is not available when using Python based tokenizers")
if word_index is not None:
batch_index = batch_or_word_index
else:
batch_index = 0
word_index = batch_or_word_index
if batch_index < 0:
batch_index = self._batch_size + batch_index
if word_index < 0:
word_index = self._seq_len + word_index
return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index)))
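# Editor's note: an illustrative round trip (not part of this diff) between the two
# mappings above, assuming a fast tokenizer; word indices follow the backend's
# pre-tokenization, and the exact token indices depend on the vocabulary:
#
#     encoding = tokenizer.encode_plus("My name is Niels")
#     span = encoding.word_to_tokens(3)     # TokenSpan covering the 4th word, "Niels"
#     encoding.token_to_word(span.start)    # -> 3, back to the word index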
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
""" Get the character span corresponding to an encoded token in a sequence of the batch.
Character spans are returned as a CharSpan NamedTuple with:
start: index of the first character in the original string associated to the token
end: index of the character following the last character in the original string associated to the token
Can be called as:
- self.token_to_chars(token_index) if batch size is 1
- self.token_to_chars(batch_index, token_index) if batch size is greater than or equal to 1
Args:
batch_or_token_index (:obj:`int`):
Index of the sequence in the batch. If the batch only comprises one sequence,
this can be the index of the token in the sequence
token_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_token_index`, this can be the index
of the token or tokens in the sequence.
Returns:
char_span (:obj:`CharSpan`):
Span of characters in the original string.
def char_to_token(self, sentence: int, char: int) -> int:
CharSpan are NamedTuple with:
start: index of the first character in the original string
end: index of the character following the last character in the original string
"""
Return the index of the token at position of the given char.
if not self._encodings:
raise ValueError("token_to_chars() is not available when using Python based tokenizers")
if token_index is not None:
batch_index = batch_or_token_index
else:
batch_index = 0
token_index = batch_or_token_index
return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index)))
def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int:
""" Get the index of the token in the encoded output comprising a character
in the original string for a sequence of the batch.
Can be called as:
- self.char_to_token(char_index) if batch size is 1
- self.char_to_token(batch_index, char_index) if batch size is greater than or equal to 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it easy
to associate encoded tokens with the provided tokenized words.
Args:
sentence (int): Index of the sentence relative to the batch provided to the tokenizer
char (int): Char index to get the relative token offsets
batch_or_char_index (:obj:`int`):
Index of the sequence in the batch. If the batch only comprises one sequence,
this can be the index of the character in the original string
char_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_char_index`, this can be the index
of the character in the original string.
Returns:
int: Integer referring to the position of the token in the returned set of tokens for the sentence
token_index (:obj:`int`):
Index of the token.
"""
if not self._encodings:
raise ValueError("char_to_token() is not available when using Python based tokenizers")
return self[sentence].char_to_token(char)
if char_index is not None:
batch_index = batch_or_char_index
else:
batch_index = 0
char_index = batch_or_char_index
return self._encodings[batch_index].char_to_token(char_index)
def char_to_word_offsets(self, sentence: int, char: int) -> Tuple[int, int]:
"""
Find the Offsets of the word containing the character at the specified position
def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = None) -> CharSpan:
""" Get the character span in the original string corresponding to given word in a sequence
of the batch.
Character spans are returned as a CharSpan NamedTuple with:
start: index of the first character in the original string
end: index of the character following the last character in the original string
Can be called as:
- self.word_to_chars(word_index) if batch size is 1
- self.word_to_chars(batch_index, word_index) if batch size is greater than or equal to 1
Args:
sentence (int): Index of the sentence relative to the batch provided to the tokenizer
char (int): Char index to get the relative token offsets
batch_or_word_index (:obj:`int`):
Index of the sequence in the batch. If the batch only comprises one sequence,
this can be the index of the word in the sequence
word_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_word_index`, this can be the index
of the word in the sequence.
Returns:
tuple: (word start, word end) representing the first and last characters of the word
char_span (:obj:`CharSpan` or :obj:`List[CharSpan]`):
Span(s) of the associated character or characters in the string.
CharSpan are NamedTuple with:
start: index of the first character associated to the word in the original string
end: index of the character following the last character associated to the word in the original string
"""
if not self._encodings:
raise ValueError("char_to_word_offsets() is not available when using Python based tokenizers")
return self[sentence].char_to_word_offsets(char)
raise ValueError("word_to_chars() is not available when using Python based tokenizers")
if word_index is not None:
batch_index = batch_or_word_index
else:
batch_index = 0
word_index = batch_or_word_index
return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index)))
def token_to_word_offsets(self, sentence: int, index: int) -> Optional[Tuple[int, int]]:
"""
Find the Offsets of the word containing the token at the given index
def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int:
""" Get the word in the original string corresponding to a character in the original string of
a sequence of the batch.
Can be called as:
- self.char_to_word(char_index) if batch size is 1
- self.char_to_word(batch_index, char_index) if batch size is greater than 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it easy
to associate encoded tokens with the provided tokenized words.
Args:
sentence (int): Index of the sentence relative to the batch provided to the tokenizer
index (int): Index of the token to map to the original word offsets
batch_or_char_index (:obj:`int`):
Index of the sequence in the batch. If the batch only comprises one sequence,
this can be the index of the character in the original string.
char_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_char_index`, this can be the index
of the character in the original string.
Returns:
Optional[tuple]: (word start, word end) or None
word_index (:obj:`int` or :obj:`List[int]`):
Index or indices of the word(s) in the input sequence containing the character(s).
"""
if not self._encodings:
raise ValueError("token_to_word_offsets() is not available when using Python based tokenizers")
return self[sentence].token_to_word_offsets(index)
raise ValueError("char_to_word() is not available when using Python based tokenizers")
if char_index is not None:
batch_index = batch_or_char_index
else:
batch_index = 0
char_index = batch_or_char_index
return self._encodings[batch_index].char_to_word(char_index)
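# Editor's note: an illustrative sketch (not part of this diff) of the character-based
# mappings above, assuming a fast tokenizer; exact indices depend on the tokenizer:
#
#     text = "Transformers are great"
#     encoding = tokenizer.encode_plus(text)
#     tok = encoding.char_to_token(0)   # index of the token covering the character "T"
#     encoding.token_to_chars(tok)      # CharSpan pointing back into `text`
#     encoding.char_to_word(0)          # -> 0, the word "Transformers"
#     encoding.word_to_chars(0)         # -> CharSpan(start=0, end=12)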
class SpecialTokensMixin:
""" SpecialTokensMixin is derived by ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` and
handles specific behaviors related to special tokens. In particular, this class hold the
attributes which can be used to directly access to these special tokens in a
model-independant manner and allow to set and update the special tokens.
"""
SPECIAL_TOKENS_ATTRIBUTES = [
"bos_token",
"eos_token",
......@@ -255,7 +478,6 @@ class SpecialTokensMixin:
]
def __init__(self, **kwargs):
self._bos_token = None
self._eos_token = None
self._unk_token = None
......@@ -270,13 +492,13 @@ class SpecialTokensMixin:
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
elif isinstance(value, AddedToken):
elif isinstance(value, AddedTokenFast):
setattr(self, key, str(value))
elif isinstance(value, str):
setattr(self, key, value)
else:
raise TypeError(
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
"special token {} has to be either str or AddedTokenFast but got: {}".format(key, type(value))
)
@property
......@@ -335,33 +557,49 @@ class SpecialTokensMixin:
logger.error("Using additional_special_tokens, but it is not set yet.")
return self._additional_special_tokens
def _maybe_update_backend(self, value):
""" To be overriden by derived class if a backend tokenizer has to be updated. """
pass
@bos_token.setter
def bos_token(self, value):
self._bos_token = value
self._maybe_update_backend([value])
@eos_token.setter
def eos_token(self, value):
self._eos_token = value
self._maybe_update_backend([value])
@unk_token.setter
def unk_token(self, value):
self._unk_token = value
self._maybe_update_backend([value])
@sep_token.setter
def sep_token(self, value):
self._sep_token = value
self._maybe_update_backend([value])
@pad_token.setter
def pad_token(self, value):
self._pad_token = value
self._maybe_update_backend([value])
@cls_token.setter
def cls_token(self, value):
self._cls_token = value
self._maybe_update_backend([value])
@mask_token.setter
def mask_token(self, value):
self._mask_token = value
self._maybe_update_backend([value])
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
self._maybe_update_backend(value)
@property
def bos_token_id(self):
......@@ -441,50 +679,69 @@ class SpecialTokensMixin:
all_ids = self.convert_tokens_to_ids(all_toks)
return all_ids
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
class PreTrainedTokenizer(SpecialTokensMixin):
""" Base class for all tokenizers.
Handles all the shared methods for tokenization and special tokens, as well as methods for downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
Class attributes (overridden by derived classes):
Handles all the shared methods for tokenization and special tokens, as well as methods
for downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
This class also contains the added tokens in a unified way on top of all tokenizers so we don't
have to handle the specific vocabulary augmentation methods of the various underlying
dictionary structures (BPE, sentencepiece...).
Parameters:
- ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id``
- ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id``
- ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id``
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id``
- ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id``
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id``
Class attributes (overridden by derived classes):
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
`short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
model has no maximum input size.
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
pretrained models, and as associated values, a dictionary of specific arguments to pass to the
``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the
``from_pretrained()`` method.
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
Args:
- ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
model in ``max_model_input_sizes`` (see above). If no value is provided, or if no associated max_length can be
found in ``max_model_input_sizes``, it will default to VERY_LARGE_INTEGER (`int(1e30)`).
- ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
Should be selected between ['right', 'left']
- ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
model ("token_type_ids", "attention_mask"...).
- ``bos_token``: (`Optional`) string: a beginning of sentence token.
Will be associated to ``self.bos_token`` and ``self.bos_token_id``
- ``eos_token``: (`Optional`) string: an end of sentence token.
Will be associated to ``self.eos_token`` and ``self.eos_token_id``
- ``unk_token``: (`Optional`) string: an unknown token.
Will be associated to ``self.unk_token`` and ``self.unk_token_id``
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
Will be associated to ``self.sep_token`` and ``self.sep_token_id``
- ``pad_token``: (`Optional`) string: a padding token.
Will be associated to ``self.pad_token`` and ``self.pad_token_id``
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
leveraging self-attention along the full depth of the model).
Will be associated to ``self.cls_token`` and ``self.cls_token_id``
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
Adding all special tokens here ensures they won't be split by the tokenization process.
Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
pretrained_init_configuration = {}
max_model_input_sizes = {}
model_input_names = ["token_type_ids", "attention_mask"]
vocab_files_names: Dict[str, str] = {}
pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
max_model_input_sizes: Dict[str, int] = {}
model_input_names: List[str] = ["token_type_ids", "attention_mask"]
padding_side = "right"
padding_side: str = "right"
NO_PAD_TOKEN_FOR_BATCH_MSG = (
"No padding token is set for this model, therefore no batch can be made with uneven "
......@@ -507,18 +764,39 @@ class PreTrainedTokenizer(SpecialTokensMixin):
def is_fast(self):
return False
@property
def max_len(self):
""" Kept here for backward compatibility.
Now renamed to `model_max_length` to avoid ambiguity.
"""
return self.model_max_length
@property
def max_len_single_sentence(self):
return self.model_max_length - self.num_special_tokens_to_add(pair=False)
@property
def max_len_sentences_pair(self):
return self.model_max_length - self.num_special_tokens_to_add(pair=True)
def get_vocab(self):
""" Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
raise NotImplementedError()
def __init__(self, max_len=None, **kwargs):
def __init__(self, model_max_length=None, **kwargs):
super().__init__(**kwargs)
self.max_len = max_len if max_len is not None else int(1e12)
# For backward compatibility we fallback to set model_max_length from max_len if provided
model_max_length = model_max_length if model_max_length is not None else kwargs.pop("max_len", None)
self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
# Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
self.padding_side = kwargs.pop("padding_side", self.padding_side)
assert self.padding_side in [
"right",
"left",
], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
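# Editor's note: a sketch (not part of this diff) of the backward-compatibility path
# implemented above; `SomeTokenizer` is a hypothetical subclass, and the legacy
# `max_len` kwarg is mapped onto `model_max_length`:
#
#     tok_new = SomeTokenizer(vocab_file, model_max_length=512)   # preferred
#     tok_old = SomeTokenizer(vocab_file, max_len=512)            # legacy kwarg, still honored
#     assert tok_old.model_max_length == 512
#     assert tok_old.max_len == 512    # `max_len` kept as a read-only alias (see property above)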
# Added tokens
......@@ -719,9 +997,9 @@ class PreTrainedTokenizer(SpecialTokensMixin):
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
# won't index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
if max_len is not None and isinstance(max_len, (int, float)):
init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]
if model_max_length is not None and isinstance(model_max_length, (int, float)):
init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length)
# Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
......@@ -769,10 +1047,11 @@ class PreTrainedTokenizer(SpecialTokensMixin):
- special-tokens-to-class-attributes-mapping,
- tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert).
This won't save modifications (other than added tokens and special token mapping) you may have
applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation).
Warning: This won't save modifications you may have applied to the tokenizer after the instantiation
(e.g. modifying tokenizer.do_lower_case after creation).
This method makes sure the full tokenizer can then be re-loaded using the
This method make sure the full tokenizer can then be re-loaded using the
:func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
"""
if not os.path.isdir(save_directory):
logger.error("Saving directory ({}) should be a directory".format(save_directory))
......@@ -807,7 +1086,9 @@ class PreTrainedTokenizer(SpecialTokensMixin):
""" Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
and special token mappings.
Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full
Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
class method.
"""
raise NotImplementedError
......@@ -817,7 +1098,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
vocabulary, they are added to it with indices starting from the length of the current vocabulary.
Args:
new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Returns:
Number of tokens added to the vocabulary.
......@@ -939,14 +1221,14 @@ class PreTrainedTokenizer(SpecialTokensMixin):
Take care of added tokens.
text: The sequence to be encoded.
add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
**kwargs: passed to the `prepare_for_tokenization` preprocessing method.
Args:
text (:obj:`string`): The sequence to be encoded.
**kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
"""
all_special_tokens = self.all_special_tokens
text = self.prepare_for_tokenization(text, **kwargs)
# TODO: should this be in the base class?
def lowercase_text(t):
# convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
......@@ -1014,8 +1296,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
raise NotImplementedError
def convert_tokens_to_ids(self, tokens):
""" Converts a single token, or a sequence of tokens, (str) in a single integer id
(resp. a sequence of ids), using the vocabulary.
""" Converts a token string (or a sequence of tokens) in a single integer id
(or a sequence of ids), using the vocabulary.
"""
if tokens is None:
return None
......@@ -1041,8 +1323,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
def encode(
self,
text: TextInput,
text_pair: Optional[TextInput] = None,
text: Union[TextInput, PreTokenizedInput, EncodedInput],
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
stride: int = 0,
......@@ -1057,11 +1339,11 @@ class PreTrainedTokenizer(SpecialTokensMixin):
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
Args:
text (:obj:`str` or :obj:`List[str]`):
text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`):
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
method)
text_pair (:obj:`str` or :obj:`List[str]`, `optional`, defaults to :obj:`None`):
text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
string using the `tokenize` method) or a list of integers (tokenized string ids using the
`convert_tokens_to_ids` method)
......@@ -1070,7 +1352,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
to their model.
max_length (:obj:`int`, `optional`, defaults to :obj:`None`):
If set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
If there are overflowing tokens, those will be added to the returned dictionary.
You can set it to the maximal input size of the model with `max_length = tokenizer.model_max_length`.
stride (:obj:`int`, `optional`, defaults to ``0``):
If set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
......@@ -1112,8 +1395,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
def encode_plus(
self,
text: TextInput,
text_pair: Optional[TextInput] = None,
text: Union[TextInput, PreTokenizedInput, EncodedInput],
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
stride: int = 0,
......@@ -1133,11 +1416,11 @@ class PreTrainedTokenizer(SpecialTokensMixin):
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
Args:
text (:obj:`str` or :obj:`List[str]`):
text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the latter only for non-fast tokenizers)):
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
method)
text_pair (:obj:`str` or :obj:`List[str]`, `optional`, defaults to :obj:`None`):
text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
string using the `tokenize` method) or a list of integers (tokenized string ids using the
`convert_tokens_to_ids` method)
......@@ -1147,6 +1430,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
max_length (:obj:`int`, `optional`, defaults to :obj:`None`):
If set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
You can set it to the maximal input size of the model with `max_length = tokenizer.model_max_length`.
stride (:obj:`int`, `optional`, defaults to ``0``):
If set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
......@@ -1188,8 +1472,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
Set to True to return special tokens mask information (default False).
return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
Set to True to return (char_start, char_end) for each token (default False).
If using Python's tokenizer, this method will raise NotImplementedError. This one is only available on
Rust-based tokenizers inheriting from PreTrainedTokenizerFast.
If using Python's tokenizer, this method will raise NotImplementedError.
This one is only available on fast tokenizers inheriting from PreTrainedTokenizerFast.
**kwargs: passed to the `self.tokenize()` method
Return:
......@@ -1201,7 +1485,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
attention_mask: list[int] if return_attention_mask is True (default)
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
special_tokens_mask: list[int] if ``add_special_tokens`` is set to ``True`` and return_special_tokens_mask is True
special_tokens_mask: list[int] if ``add_special_tokens`` is set to ``True``
and return_special_tokens_mask is True
}
With the fields:
......@@ -1240,7 +1525,9 @@ class PreTrainedTokenizer(SpecialTokensMixin):
# Throw an error if padding is requested but there is no padding token
if pad_to_max_length and self.pad_token_id is None:
raise ValueError(
"Unable to set proper padding strategy as the tokenizer does not have a padding token. In this case please set the `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via the function add_special_tokens if you want to use a padding strategy"
"Unable to set proper padding strategy as the tokenizer does not have a padding token. "
"In this case please set the `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
"or add a new pad token via the function add_special_tokens if you want to use a padding strategy"
)
first_ids = get_input_ids(text)
......@@ -1264,7 +1551,12 @@ class PreTrainedTokenizer(SpecialTokensMixin):
def batch_encode_plus(
self,
batch_text_or_text_pairs: Union[
List[TextInput], List[TextPairInput], List[PreTokenizedInput], List[PreTokenizedInputPair]
List[TextInput],
List[TextInputPair],
List[PreTokenizedInput],
List[PreTokenizedInputPair],
List[EncodedInput],
List[EncodedInputPair],
],
add_special_tokens: bool = True,
max_length: Optional[int] = None,
......@@ -1278,7 +1570,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return_overflowing_tokens: bool = False,
return_special_tokens_masks: bool = False,
return_offsets_mapping: bool = False,
return_input_lengths: bool = False,
return_lengths: bool = False,
**kwargs
) -> BatchEncoding:
"""
......@@ -1286,7 +1578,10 @@ class PreTrainedTokenizer(SpecialTokensMixin):
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
Args:
batch_text_or_text_pairs (:obj:`List[str]` or :obj:`List[List[str]]`):
batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`,
:obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`,
and, for non-fast tokenizers, also:
:obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`):
Batch of sequences or pair of sequences to be encoded.
This can be a list of strings/string-sequences/int-sequences or a list of pairs of
strings/string-sequences/int-sequences (see details in encode_plus)
......@@ -1339,8 +1634,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
Set to True to return (char_start, char_end) for each token (default False).
If using Python's tokenizer, this method will raise NotImplementedError. This one is only available on
Rust-based tokenizers inheriting from PreTrainedTokenizerFast.
return_input_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`):
If set the resulting dictionary will include the length of each sample
return_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`):
If set, the resulting dictionary will include the length of each encoded input
**kwargs: passed to the `self.tokenize()` method
Return:
......@@ -1434,12 +1729,10 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_masks,
return_lengths=return_lengths,
return_tensors=None, # We will convert the whole batch to tensors at the end
)
# Append the non-padded length to the output
if return_input_lengths:
outputs["input_len"] = len(outputs["input_ids"])
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
......@@ -1493,12 +1786,11 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
):
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates
sequences if overflowing while taking into account the special tokens and manages a window stride for
overflowing tokens
return_lengths: bool = False,
) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
manages a moving window (with user defined stride) for overflowing tokens
Args:
ids: list of tokenized input ids. Can be obtained from a string by chaining the
......@@ -1508,8 +1800,8 @@ class PreTrainedTokenizer(SpecialTokensMixin):
max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
list of inputs.
stride: window stride for overflowing tokens. Can be useful to remove edge effects when using sequential
lists of inputs. The overflowing tokens will contain a part of the previous window of tokens.
truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduces the input sequences until the input is under max_length,
removing a token from the longest sequence at each step (when there is a pair of input sequences)
......@@ -1524,10 +1816,12 @@ class PreTrainedTokenizer(SpecialTokensMixin):
Defaults to False: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default: set to model specifics).
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
return_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`):
If set, the resulting dictionary will include the length of each encoded input
Return:
A Dictionary of shape::
......@@ -1538,21 +1832,24 @@ class PreTrainedTokenizer(SpecialTokensMixin):
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
special_tokens_mask: list[int] if ``add_special_tokens`` is set to ``True`` and return_special_tokens_mask is True
length: int if return_lengths is True
}
With the fields:
``input_ids``: list of token ids to be fed to a model
``token_type_ids``: list of token type ids to be fed to a model
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
``num_truncated_tokens``: number of overflowing tokens when a ``max_length`` is specified
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
tokens and 1 specifying sequence tokens.
- ``input_ids``: list of token ids to be fed to a model
- ``token_type_ids``: list of token type ids to be fed to a model
- ``overflowing_tokens``: list of overflowing tokens if a max length is specified.
- ``num_truncated_tokens``: number of overflowing tokens when a ``max_length`` is specified
- ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
tokens and 1 specifying sequence tokens.
- ``length``: this is the length of ``input_ids``
"""
pair = bool(pair_ids is not None)
len_ids = len(ids)
len_pair_ids = len(pair_ids) if pair else 0
# Load from model defaults
if return_token_type_ids is None:
return_token_type_ids = "token_type_ids" in self.model_input_names
if return_attention_mask is None:
......@@ -1560,7 +1857,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
encoded_inputs = {}
# Handle max sequence length
# Truncation: Handle max sequence length
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
......@@ -1574,7 +1871,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
# Handle special_tokens
# Add special tokens
if add_special_tokens:
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
......@@ -1582,46 +1879,43 @@ class PreTrainedTokenizer(SpecialTokensMixin):
sequence = ids + pair_ids if pair else ids
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
# Build output dictionary
encoded_inputs["input_ids"] = sequence
if return_token_type_ids:
encoded_inputs["token_type_ids"] = token_type_ids
if return_special_tokens_mask:
if add_special_tokens:
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
else:
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
encoded_inputs["input_ids"] = sequence
if return_token_type_ids:
encoded_inputs["token_type_ids"] = token_type_ids
if max_length and len(encoded_inputs["input_ids"]) > max_length:
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
if return_token_type_ids:
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
# Check lengths
assert max_length is None or len(encoded_inputs["input_ids"]) <= max_length
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length:
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len)
"indexing errors".format(len(ids), self.model_max_length)
)
# Padding
needs_to_be_padded = pad_to_max_length and (
max_length
and len(encoded_inputs["input_ids"]) < max_length
or max_length is None
and len(encoded_inputs["input_ids"]) < self.max_len
and self.max_len <= 10000
and len(encoded_inputs["input_ids"]) < self.model_max_length
and self.model_max_length <= LARGE_INTEGER
)
if pad_to_max_length and max_length is None and self.max_len > 10000:
if pad_to_max_length and max_length is None and self.model_max_length > LARGE_INTEGER:
logger.warning(
"Sequence can't be padded as no maximum length is specified and the model maximum length is too high."
)
if needs_to_be_padded:
difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"])
difference = (max_length if max_length is not None else self.model_max_length) - len(
encoded_inputs["input_ids"]
)
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
......@@ -1642,14 +1936,16 @@ class PreTrainedTokenizer(SpecialTokensMixin):
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
elif return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
if return_lengths:
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
# Prepare inputs as tensors if asked
# Prepare model inputs as tensors if asked
if return_tensors == "tf" and is_tf_available():
encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
......@@ -1676,14 +1972,27 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return BatchEncoding(encoded_inputs)
def prepare_for_tokenization(self, text, **kwargs):
def prepare_for_tokenization(self, text: str, **kwargs) -> str:
""" Performs any necessary transformations before tokenization """
return text
def truncate_sequences(
self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy="longest_first", stride=0
):
"""Truncates a sequence pair in place to the maximum length.
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
num_tokens_to_remove: int = 0,
truncation_strategy: str = "longest_first",
stride: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
""" Truncates a sequence pair in place to the maximum length.
Args:
ids: list of tokenized input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
number of tokens to remove using the truncation strategy
truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduces the input sequences until the input is under max_length,
removing a token from the longest sequence at each step (when there is a pair of input sequences).
......@@ -1691,6 +2000,9 @@ class PreTrainedTokenizer(SpecialTokensMixin):
- 'only_first': Only truncate the first sequence. Raise an error if the first sequence is shorter than or equal to num_tokens_to_remove.
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
stride (:obj:`int`, `optional`, defaults to ``0``):
If set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
"""
if num_tokens_to_remove <= 0:
return ids, pair_ids, []
......@@ -1724,12 +2036,12 @@ class PreTrainedTokenizer(SpecialTokensMixin):
)
return (ids, pair_ids, overflowing_tokens)
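# Hedged usage sketch (not part of the patch): truncate_sequences works on plain id
# lists; the model name, texts and 512-token budget below are illustrative only.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("a possibly very long question", add_special_tokens=False)
pair_ids = tokenizer.encode("a possibly very long context", add_special_tokens=False)
total_len = len(ids) + len(pair_ids) + tokenizer.num_special_tokens_to_add(pair=True)
ids, pair_ids, overflowing = tokenizer.truncate_sequences(
    ids,
    pair_ids=pair_ids,
    num_tokens_to_remove=max(0, total_len - 512),
    truncation_strategy="only_second",  # only the second sequence (the context) is shortened
    stride=32,  # overflowing tokens also keep some overlap with the part that was kept
)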
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]:
if token_ids_1 is None:
return len(token_ids_0) * [0]
return [0] * len(token_ids_0) + [1] * len(token_ids_1)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
by concatenating and adding special tokens.
......@@ -1741,7 +2053,9 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return token_ids_0
return token_ids_0 + token_ids_1
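# Hedged sketch (illustrative, not part of the patch) of how a model-specific subclass
# typically overrides these two base defaults; the BERT-style [CLS]/[SEP] layout below
# is only one example of such an override.
from transformers import PreTrainedTokenizer

class MyBertLikeTokenizer(PreTrainedTokenizer):
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        cls, sep = [self.cls_token_id], [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        cls, sep = [self.cls_token_id], [self.sep_token_id]
        if token_ids_1 is None:
            return [0] * len(cls + token_ids_0 + sep)
        return [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)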
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
def get_special_tokens_mask(
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
......@@ -1758,7 +2072,9 @@ class PreTrainedTokenizer(SpecialTokensMixin):
"""
return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
def convert_ids_to_tokens(
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
) -> Union[str, List[str]]:
""" Converts a single index or a sequence of indices (integers) in a token "
(resp.) a sequence of tokens (str), using the vocabulary and added tokens.
......@@ -1781,17 +2097,19 @@ class PreTrainedTokenizer(SpecialTokensMixin):
tokens.append(self._convert_id_to_token(index))
return tokens
def _convert_id_to_token(self, index):
def _convert_id_to_token(self, index: int) -> str:
raise NotImplementedError
def convert_tokens_to_string(self, tokens):
def convert_tokens_to_string(self, tokens: List[str]) -> str:
""" Converts a sequence of tokens (string) in a single string.
The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
but we often want to remove sub-word tokenization artifacts at the same time.
"""
return " ".join(self.convert_ids_to_tokens(tokens))
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
def decode(
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
) -> str:
"""
Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
......@@ -1830,7 +2148,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return text
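# Minimal round-trip sketch (model name illustrative, not part of the patch): encode
# adds the special tokens, decode strips them again and cleans the detokenization spaces.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("Hello, how are you?", add_special_tokens=True)
text = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(text)  # expected to round-trip to roughly the original sentence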
@staticmethod
def clean_up_tokenization(out_string):
def clean_up_tokenization(out_string: str) -> str:
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
"""
out_string = (
......@@ -1850,28 +2168,79 @@ class PreTrainedTokenizer(SpecialTokensMixin):
class PreTrainedTokenizerFast(PreTrainedTokenizer):
""" Base class for all fast tokenizers (wrapping HuggingFace tokenizers library).
Inherits from PreTrainedTokenizer.
model_input_names = ["token_type_ids", "attention_mask"]
Handles all the shared methods for tokenization and special tokens, as well as methods for
downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
def __init__(self, tokenizer: BaseTokenizer, **kwargs):
if tokenizer is None:
raise ValueError("Provided tokenizer cannot be None")
self._tokenizer = tokenizer
This class also contains the added tokens in a unified way on top of all tokenizers so we don't
have to handle the specific vocabulary augmentation methods of the various underlying
dictionary structures (BPE, sentencepiece...).
Class attributes (overridden by derived classes):
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict``, with the high-level keys
being the ``__init__`` keyword name of each vocabulary file required by the model, and the low-level keys being the
`short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
model has no maximum input size.
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
pretrained models, and as associated values, a dictionary of specific arguments to pass to the
``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the
``from_pretrained()`` method.
Args:
- ``tokenizer`` (`BaseTokenizerFast`): A Fast tokenizer from the HuggingFace tokenizers library (backed by a low-level Rust implementation)
- ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
model in ``max_model_input_sizes`` (see above). If no value is provided, or if no
associated max_length can be found in ``max_model_input_sizes``, this will default to VERY_LARGE_INTEGER (`int(1e30)`).
- ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
Should be selected between ['right', 'left']
- ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
model ("token_type_ids", "attention_mask"...).
- ``bos_token``: (`Optional`) string: a beginning of sentence token.
Will be associated to ``self.bos_token`` and ``self.bos_token_id``
- ``eos_token``: (`Optional`) string: an end of sentence token.
Will be associated to ``self.eos_token`` and ``self.eos_token_id``
- ``unk_token``: (`Optional`) string: an unknown token.
Will be associated to ``self.unk_token`` and ``self.unk_token_id``
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
Will be associated to ``self.sep_token`` and ``self.sep_token_id``
- ``pad_token``: (`Optional`) string: a padding token.
Will be associated to ``self.pad_token`` and ``self.pad_token_id``
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
leveraging self-attention along the full depth of the model).
Will be associated to ``self.cls_token`` and ``self.cls_token_id``
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
Adding all special tokens here ensures they won't be split by the tokenization process.
Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
"""
def __init__(self, tokenizer: BaseTokenizerFast, **kwargs):
if not isinstance(tokenizer, BaseTokenizerFast):
raise ValueError(
"Tokenizer should be an instance of a Tokenizer " "provided by HuggingFace tokenizers library."
)
self._tokenizer: BaseTokenizerFast = tokenizer
# Initialize all the rest of the kwargs
super().__init__(**kwargs)
self.max_len_single_sentence = self.max_len - self.num_special_tokens_to_add(
False
) # take into account special tokens
self.max_len_sentences_pair = self.max_len - self.num_special_tokens_to_add(
True
) # take into account special tokens
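# Usage sketch (model name illustrative, not part of the patch): concrete subclasses
# such as BertTokenizerFast build the Rust backend and pass it to this __init__;
# character-level alignment (offsets) then comes with every encoding.
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
enc = tokenizer.encode_plus("Fast tokenizers align tokens to characters.", return_offsets_mapping=True)
print(enc.tokens(0))          # tokens of the first (and only) encoding
print(enc["offset_mapping"])  # one (start, end) character span per token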
@property
def tokenizer(self) -> BaseTokenizer:
def backend_tokenizer(self) -> BaseTokenizerFast:
return self._tokenizer
@property
def decoder(self) -> Decoder:
def decoder(self) -> DecoderFast:
return self._tokenizer._tokenizer.decoder
@property
......@@ -1885,56 +2254,30 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
def __len__(self) -> int:
return self._tokenizer.get_vocab_size(with_added_tokens=True)
@PreTrainedTokenizer.bos_token.setter
def bos_token(self, value):
self._bos_token = value
self._tokenizer.add_special_tokens([self._bos_token])
@PreTrainedTokenizer.eos_token.setter
def eos_token(self, value):
self._eos_token = value
self._tokenizer.add_special_tokens([self._eos_token])
@PreTrainedTokenizer.unk_token.setter
def unk_token(self, value):
self._unk_token = value
self._tokenizer.add_special_tokens([self._unk_token])
@PreTrainedTokenizer.sep_token.setter
def sep_token(self, value):
self._sep_token = value
self._tokenizer.add_special_tokens([self._sep_token])
@PreTrainedTokenizer.pad_token.setter
def pad_token(self, value):
self._pad_token = value
self._tokenizer.add_special_tokens([self._pad_token])
@PreTrainedTokenizer.cls_token.setter
def cls_token(self, value):
self._cls_token = value
self._tokenizer.add_special_tokens([self._cls_token])
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
self._mask_token = value
self._tokenizer.add_special_tokens([self._mask_token])
@PreTrainedTokenizer.additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
self._tokenizer.add_special_tokens(self.all_special_tokens)
def _maybe_update_backend(self, value):
""" Update the backend fast tokenizer.
Override method from base class SpecialTokensMixin """
self._tokenizer.add_special_tokens(value)
def _convert_encoding(
self,
encoding,
return_tensors=None,
return_token_type_ids=None,
return_attention_mask=None,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
return_offsets_mapping=False,
):
encoding: EncodingFast,
return_tensors: Optional[bool] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
) -> Dict[str, Any]:
""" Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict.
Overflowing tokens are converted to additional examples (like batches) so the output values of
the dict are lists (overflows) of lists (tokens).
If return_tensors is not None, these lists of lists are converted to 2-D tensors
for input_ids, token_type_ids and attention_mask.
Output shape: (overflows, sequence length)
"""
if return_token_type_ids is None:
return_token_type_ids = "token_type_ids" in self.model_input_names
if return_attention_mask is None:
......@@ -1958,75 +2301,86 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
if return_offsets_mapping:
encoding_dict["offset_mapping"].append(e.offsets)
# Prepare inputs as tensors if asked
if return_tensors == "tf" and is_tf_available():
encoding_dict["input_ids"] = tf.constant(encoding_dict["input_ids"])
if "token_type_ids" in encoding_dict:
encoding_dict["token_type_ids"] = tf.constant(encoding_dict["token_type_ids"])
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = tf.constant(encoding_dict["attention_mask"])
elif return_tensors == "pt" and is_torch_available():
encoding_dict["input_ids"] = torch.tensor(encoding_dict["input_ids"])
if "token_type_ids" in encoding_dict:
encoding_dict["token_type_ids"] = torch.tensor(encoding_dict["token_type_ids"])
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = torch.tensor(encoding_dict["attention_mask"])
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
if return_tensors is not None:
for key, value in encoding_dict.items():
if return_tensors == "tf" and is_tf_available():
encoding_dict[key] = tf.constant(value)
elif return_tensors == "pt" and is_torch_available():
encoding_dict[key] = torch.tensor(value)
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, "
"PyTorch or TensorFlow is not available.".format(return_tensors)
)
return encoding_dict
def _convert_token_to_id_with_added_voc(self, token):
id = self._tokenizer.token_to_id(token)
if id is None:
def _convert_token_to_id_with_added_voc(self, token: str) -> int:
index = self._tokenizer.token_to_id(token)
if index is None:
return self.unk_token_id
return id
return index
def _convert_id_to_token(self, index: int) -> str:
def _convert_id_to_token(self, index: int) -> Optional[str]:
return self._tokenizer.id_to_token(int(index))
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens)
def add_tokens(self, new_tokens: List[Union[str, AddedToken]]) -> int:
def add_tokens(self, new_tokens: List[Union[str, AddedTokenFast]]) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from the length of the current vocabulary.
Args:
new_tokens: string or list of string or AddedTokenFast. Each string is a token to add.
Tokens are only added if they are not already in the vocabulary. AddedTokenFast wraps a string token to let you personalize its behavior (whether this token should only match against a single word, whether this token should strip all potential whitespace on the left side, whether this token should strip all potential whitespace on the right side, ...).
See the details for AddedToken in the HuggingFace tokenizers library.
Returns:
Number of tokens added to the vocabulary.
Examples::
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
"""
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
return self._tokenizer.add_tokens(new_tokens)
def add_special_tokens(self, special_tokens_dict: dict) -> int:
added = super().add_special_tokens(special_tokens_dict)
# Map special tokens to class attributes (self.pad_token...)
num_added_tokens = super().add_special_tokens(special_tokens_dict)
# For the backend tokenizer, the only specificities of special tokens are that
# - they will never be processed by the model, and
# - they will be removed while decoding.
# But they are not mapped to special attributes in the backend so we can just
# send a list.
tokens = flatten(special_tokens_dict.values())
self._tokenizer.add_special_tokens(tokens)
return added
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
if token_ids_1 is None:
return token_ids_0
else:
return token_ids_0 + token_ids_1
return num_added_tokens
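# Hedged usage sketch (not part of the patch): registering a missing special token
# updates both the python-side attribute (via super().add_special_tokens) and the
# Rust backend; GPT-2 is used here only because it ships without a pad token.
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
num_added = tokenizer.add_special_tokens({"pad_token": "<PAD>"})
print(num_added, tokenizer.pad_token, tokenizer.pad_token_id)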
def num_special_tokens_to_add(self, pair: bool = False) -> int:
return self.tokenizer.num_special_tokens_to_add(pair)
return self._tokenizer.num_special_tokens_to_add(pair)
def tokenize(
self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False
) -> List[str]:
return self.tokenizer.encode(text, pair, add_special_tokens).tokens
return self._tokenizer.encode(text, pair, add_special_tokens).tokens
def batch_encode_plus(
self,
batch_text_or_text_pairs: Union[
List[TextInput], List[TextPairInput], List[PreTokenizedInput], List[PreTokenizedInputPair]
] = None,
List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair]
],
add_special_tokens: bool = True,
max_length: Optional[int] = None,
stride: int = 0,
......@@ -2039,15 +2393,13 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_lengths: bool = False,
**kwargs
) -> BatchEncoding:
if batch_text_or_text_pairs is None:
if not isinstance(batch_text_or_text_pairs, list):
raise ValueError(
"None is not a valid input. "
"Should be a list/tuple of strings, "
"a list/tuple of integers, "
"A list of list of strings or tuple of strings."
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
)
# Needed if we have to return a tensor
......@@ -2070,11 +2422,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
pad_token=self._pad_token,
):
if not isinstance(batch_text_or_text_pairs, list):
raise TypeError(
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
)
# Check for the pretokenized path
if is_pretokenized:
encodings = []
......@@ -2089,35 +2436,27 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
"index {} is of type {}".format(i, type(sample))
)
# Convert to tuple for convenience
if isinstance(sample, list):
sample = (sample,)
# Test if we have a pair of sentences by checking the depth of nesting
is_pair = bool(len(sample) > 0 and isinstance(sample[0], (list, tuple)))
encodings_text = Encoding.merge(self._tokenizer.encode_batch(sample[0], False), True)
# Take care of the first sequence - we multi-thread over the words
encodings_text = EncodingFast.merge(
self._tokenizer.encode_batch(sample[0] if is_pair else sample, add_special_tokens=False),
growing_offsets=True,
)
# Check if we have pairs
if len(sample) == 2:
encodings_pair = Encoding.merge(
self._tokenizer.encode_batch([("", s) for s in sample[1]], False), True
# Take care of the second sequence if we have a pair
if is_pair:
encodings_pair = EncodingFast.merge(
self._tokenizer.encode_batch([("", s) for s in sample[1]], add_special_tokens=False),
growing_offsets=True,
)
# No pair, default to None
elif len(sample) == 1:
encodings_pair = None
# Something else is invalid
else:
raise ValueError(
"batch_encode_plus(..., is_pretokenized=True) requires batch_text_or_text_pairs "
"to be either List[List[str]] or List[Tuple[List[str], List[str]]] but sample at "
"index {} has too much dimensions (required 1 or 2, got: {}, type {})".format(
i, len(sample), type(sample)
)
)
encodings_pair = None
# Post-process
# Post-process - truncate/pad and add special tokens
encoding = self._tokenizer.post_process(encodings_text, encodings_pair, add_special_tokens)
encodings += [encoding]
encodings.append(encoding)
# Classical path with strings input
else:
......@@ -2138,6 +2477,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
)
# Convert encoding to dict
# `Tokens` has type: List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]]
# with nested dimensions corresponding to batch, overflows, sequence length
tokens = [
self._convert_encoding(
encoding=encoding,
......@@ -2154,6 +2495,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
# Sanitize the output to have dict[list] from list[dict]
sanitized = {}
for key in tokens[0].keys():
# To List[List[List[int]]] of shape (batch, overflows, sequence length)
stack = [e for item in tokens for e in item[key]]
if return_tensors == "tf":
stack = tf.stack(stack, axis=0)
......@@ -2167,9 +2509,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
# If returning overflowing tokens, we need to return a mapping
# from the batch idx to the original sample
if return_overflowing_tokens:
overflow_to_sample_mapping = [
i if len(item["input_ids"]) == 1 else [i] * len(item["input_ids"]) for i, item in enumerate(tokens)
]
overflow_to_sample_mapping = flatten([[i] * len(enc["input_ids"]) for i, enc in enumerate(tokens)])
sanitized["overflow_to_sample_mapping"] = overflow_to_sample_mapping
return BatchEncoding(sanitized, encodings)
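# Hedged usage sketches for the two paths above (model name and inputs illustrative).
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Pre-tokenized path: each sample is a list of words or a (words, pair_words) tuple.
pretok = tokenizer.batch_encode_plus(
    [["Hello", "world"], (["How", "are", "you"], ["Fine", "thanks"])],
    is_pretokenized=True,
)

# Overflowing tokens: every produced window maps back to its original sample index.
windows = tokenizer.batch_encode_plus(
    ["a fairly long first sentence that will overflow the tiny budget", "short"],
    max_length=8,
    stride=2,
    return_overflowing_tokens=True,
)
for window_idx, sample_idx in enumerate(windows["overflow_to_sample_mapping"]):
    print("window", window_idx, "comes from sample", sample_idx)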
......@@ -2199,7 +2539,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
# Encode through encode_batch with sequence of only one word which will be merged after hand
encoding = self._tokenizer.encode_batch(text, add_special_tokens=False)
encoding = Encoding.merge(encoding, True)
encoding = EncodingFast.merge(encoding, growing_offsets=True)
# Let's do the same for pairs if provided
if isinstance(text_pair, list):
......@@ -2207,7 +2547,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
encoding_pair = self._tokenizer.encode_batch(
[("", p) for p in text_pair], add_special_tokens=False
)
encoding_pair = Encoding.merge(encoding_pair, True)
encoding_pair = EncodingFast.merge(encoding_pair, growing_offsets=True)
elif text_pair is None:
encoding_pair = None
else:
......@@ -2268,8 +2608,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
def decode(
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
):
text = self.tokenizer.decode(token_ids, skip_special_tokens)
) -> str:
text = self._tokenizer.decode(token_ids, skip_special_tokens)
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
......
......@@ -629,9 +629,6 @@ class XLMTokenizer(PreTrainedTokenizer):
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
# cache of sm.MosesPunctNormalizer instance
self.cache_moses_punct_normalizer = dict()
# cache of sm.MosesTokenizer instance
......
......@@ -128,8 +128,6 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
mask_token=mask_token,
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
try:
import sentencepiece as spm
......
......@@ -138,8 +138,6 @@ class XLNetTokenizer(PreTrainedTokenizer):
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
self._pad_token_type_id = 3
try:
......
......@@ -117,8 +117,6 @@ class XxxTokenizer(PreTrainedTokenizer):
mask_token=mask_token,
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file):
raise ValueError(
......
import logging
import unittest
from collections import namedtuple
from itertools import takewhile
......@@ -21,6 +22,10 @@ from transformers.tokenization_roberta import RobertaTokenizerFast
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
......@@ -83,6 +88,85 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assert_add_tokens(tokenizer_r)
self.assert_offsets_mapping(tokenizer_r)
self.assert_add_special_tokens(tokenizer_r)
self.assert_alignement_methods(tokenizer_r)
def assert_alignement_methods(self, tokenizer_r):
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
text = " ".join(words)
batch_size = 3
encoding = tokenizer_r.encode_plus(text, add_special_tokens=False)
batch_encoding = tokenizer_r.batch_encode_plus([text] * batch_size, add_special_tokens=False)
num_tokens = len(encoding["input_ids"])
last_word_index = len(words) - 1
last_token_index = num_tokens - 1
last_batch_index = batch_size - 1
last_char_index = len(text) - 1
# words, tokens
self.assertEqual(len(encoding.words(0)), num_tokens)
self.assertEqual(max(encoding.words(0)), last_word_index)
self.assertEqual(min(encoding.words(0)), 0)
self.assertEqual(len(batch_encoding.words(last_batch_index)), num_tokens)
self.assertEqual(max(batch_encoding.words(last_batch_index)), last_word_index)
self.assertEqual(min(batch_encoding.words(last_batch_index)), 0)
self.assertEqual(len(encoding.tokens(0)), num_tokens)
# Assert token_to_word
self.assertEqual(encoding.token_to_word(0), 0)
self.assertEqual(encoding.token_to_word(0, 0), 0)
self.assertEqual(encoding.token_to_word(last_token_index), last_word_index)
self.assertEqual(encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(1, 0), 0)
self.assertEqual(batch_encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(last_batch_index, last_token_index), last_word_index)
# Assert word_to_tokens
self.assertEqual(encoding.word_to_tokens(0).start, 0)
self.assertEqual(encoding.word_to_tokens(0, 0).start, 0)
self.assertEqual(encoding.word_to_tokens(last_word_index).end, last_token_index + 1)
self.assertEqual(encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(batch_encoding.word_to_tokens(1, 0).start, 0)
self.assertEqual(batch_encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(batch_encoding.word_to_tokens(last_batch_index, last_word_index).end, last_token_index + 1)
# Assert token_to_chars
self.assertEqual(encoding.token_to_chars(0).start, 0)
self.assertEqual(encoding.token_to_chars(0, 0).start, 0)
self.assertEqual(encoding.token_to_chars(last_token_index).end, last_char_index + 1)
self.assertEqual(encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.token_to_chars(1, 0).start, 0)
self.assertEqual(batch_encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.token_to_chars(last_batch_index, last_token_index).end, last_char_index + 1)
# Assert char_to_token
self.assertEqual(encoding.char_to_token(0), 0)
self.assertEqual(encoding.char_to_token(0, 0), 0)
self.assertEqual(encoding.char_to_token(last_char_index), last_token_index)
self.assertEqual(encoding.char_to_token(0, last_char_index), last_token_index)
self.assertEqual(batch_encoding.char_to_token(1, 0), 0)
self.assertEqual(batch_encoding.char_to_token(0, last_char_index), last_token_index)
self.assertEqual(batch_encoding.char_to_token(last_batch_index, last_char_index), last_token_index)
# Assert char_to_word
self.assertEqual(encoding.char_to_word(0), 0)
self.assertEqual(encoding.char_to_word(0, 0), 0)
self.assertEqual(encoding.char_to_word(last_char_index), last_word_index)
self.assertEqual(encoding.char_to_word(0, last_char_index), last_word_index)
self.assertEqual(batch_encoding.char_to_word(1, 0), 0)
self.assertEqual(batch_encoding.char_to_word(0, last_char_index), last_word_index)
self.assertEqual(batch_encoding.char_to_word(last_batch_index, last_char_index), last_word_index)
# Assert word_to_chars
self.assertEqual(encoding.word_to_chars(0).start, 0)
self.assertEqual(encoding.word_to_chars(0, 0).start, 0)
self.assertEqual(encoding.word_to_chars(last_word_index).end, last_char_index + 1)
self.assertEqual(encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.word_to_chars(1, 0).start, 0)
self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1)
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
# Ensure basic input match
......@@ -306,7 +390,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
# Simple input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
)
......@@ -316,7 +399,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
assert_batch_padded_input_match(input_r, input_p)
# Pair input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
[
("This is a simple input 1", "This is a simple input 2"),
......