Unverified Commit 11fdde02 authored by Thomas Wolf, committed by GitHub

Tokenizers API developments (#5103)



* Add return lengths

* make pad a bit more flexible so it can be used as collate_fn (see the sketch below the message)

* check all kwargs sent to encoding method are known

* fixing kwargs in encodings

* New AddedToken class in python

This class lets you specify specific tokenization behaviors for some special tokens. It is used in particular for GPT2 and RoBERTa, to control how whitespace is stripped around special tokens.

* style and quality

* switched to the huggingface tokenizers library for AddedToken

* up to tokenizers 0.8.0-rc3 - update API to use AddedToken state

* style and quality

* do not raise an error on additional or unused kwargs for tokenize(), only emit a warning

* transfo-xl pretrained model requires torch

* Update src/transformers/tokenization_utils.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 1ae132a0
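As the message above notes, `pad` can now be used directly as a `collate_fn`. A minimal sketch, assuming the examples are pre-encoded one by one and that `pad` accepts a list of encodings together with a `return_tensors` argument; the checkpoint name, texts and batch size are only illustrative:

```python
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Encode examples individually, without padding; each item is a dict of lists.
texts = ["a short example", "a slightly longer example sentence"]
features = [tokenizer.encode_plus(t) for t in texts]

# tokenizer.pad pads a list of encodings to a common length, so it can serve as collate_fn.
loader = DataLoader(
    features,
    batch_size=2,
    collate_fn=lambda batch: tokenizer.pad(batch, return_tensors="pt"),
)

batch = next(iter(loader))
print(batch["input_ids"].shape)  # e.g. torch.Size([2, longest_sequence_in_batch])
```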
...@@ -109,7 +109,7 @@ setup( ...@@ -109,7 +109,7 @@ setup(
packages=find_packages("src"), packages=find_packages("src"),
install_requires=[ install_requires=[
"numpy", "numpy",
"tokenizers == 0.8.0-rc1", "tokenizers == 0.8.0-rc3",
# dataclasses for Python versions that don't have it # dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'", "dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions # utilities from PyPA to e.g. compare versions
......
...@@ -23,7 +23,7 @@ from typing import List, Optional ...@@ -23,7 +23,7 @@ from typing import List, Optional
from tokenizers import BertWordPieceTokenizer from tokenizers import BertWordPieceTokenizer
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from .tokenization_utils_fast import PreTrainedTokenizerFast from .tokenization_utils_fast import PreTrainedTokenizerFast
...@@ -547,45 +547,6 @@ class WordpieceTokenizer(object): ...@@ -547,45 +547,6 @@ class WordpieceTokenizer(object):
return output_tokens return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
class BertTokenizerFast(PreTrainedTokenizerFast): class BertTokenizerFast(PreTrainedTokenizerFast):
r""" r"""
Constructs a "Fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library). Constructs a "Fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library).
......
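These three character-class helpers are not dropped, they move into `tokenization_utils.py` (see the hunk for that file further down). A quick illustration of what they classify, assuming the private helpers stay importable from `transformers.tokenization_utils` where this diff places them:

```python
import unicodedata

from transformers.tokenization_utils import _is_control, _is_punctuation, _is_whitespace

# Tab is treated as whitespace, not control, even though Unicode puts it in "Cc".
assert unicodedata.category("\t") == "Cc"
assert _is_whitespace("\t") and not _is_control("\t")

# ASCII symbols such as "$" or "^" count as punctuation here although they carry no
# Unicode P* category; this keeps BERT-style word splitting consistent.
assert _is_punctuation("$") and _is_punctuation("^") and _is_punctuation(",")

# A genuine control character (BEL) is detected as control.
assert _is_control("\x07")
```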
...@@ -146,6 +146,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -146,6 +146,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
unk_token="<|endoftext|>", unk_token="<|endoftext|>",
bos_token="<|endoftext|>", bos_token="<|endoftext|>",
eos_token="<|endoftext|>", eos_token="<|endoftext|>",
add_prefix_space=False,
**kwargs **kwargs
): ):
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
...@@ -161,6 +162,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -161,6 +162,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
bpe_merges = [tuple(merge.split()) for merge in bpe_merges] bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {} self.cache = {}
self.add_prefix_space = add_prefix_space
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
...@@ -273,10 +275,11 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -273,10 +275,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return vocab_file, merge_file return vocab_file, merge_file
def prepare_for_tokenization(self, text, **kwargs): def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
if "add_prefix_space" in kwargs and kwargs["add_prefix_space"]: add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
return " " + text if is_pretokenized or add_prefix_space:
return text text = " " + text
return (text, kwargs)
class GPT2TokenizerFast(PreTrainedTokenizerFast): class GPT2TokenizerFast(PreTrainedTokenizerFast):
...@@ -354,7 +357,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast): ...@@ -354,7 +357,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
is_pretokenized = kwargs.get("is_pretokenized", False) is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, ( assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False " f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs." "to use it with pretokenized inputs."
) )
...@@ -364,7 +367,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast): ...@@ -364,7 +367,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
is_pretokenized = kwargs.get("is_pretokenized", False) is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, ( assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False " f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs." "to use it with pretokenized inputs."
) )
......
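With this change `add_prefix_space` becomes a constructor argument of the slow GPT-2 tokenizer (default `False`), and `prepare_for_tokenization` now pops it from the kwargs and returns the leftovers. A minimal usage sketch; the printed token lists are indicative rather than taken from the diff:

```python
from transformers import GPT2Tokenizer

# Default: no prefix space, so the first word is encoded without a leading "Ġ".
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print(tokenizer.tokenize("Hello world"))            # e.g. ['Hello', 'Ġworld']

# Constructor-level opt-in: a space is prepended before byte-level BPE runs.
tokenizer_prefixed = GPT2Tokenizer.from_pretrained("gpt2", add_prefix_space=True)
print(tokenizer_prefixed.tokenize("Hello world"))   # e.g. ['ĠHello', 'Ġworld']

# Per-call override still works, since prepare_for_tokenization pops it from kwargs.
print(tokenizer.tokenize("Hello world", add_prefix_space=True))
```

The fast tokenizer exposes the same constructor flag; the updated assertion above now tells you to instantiate it with `add_prefix_space=True` before feeding it pretokenized inputs.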
...@@ -18,11 +18,10 @@ ...@@ -18,11 +18,10 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from tokenizers import AddedToken
from tokenizers.processors import RobertaProcessing from tokenizers.processors import RobertaProcessing
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import AddedToken, PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -135,6 +134,7 @@ class RobertaTokenizer(GPT2Tokenizer): ...@@ -135,6 +134,7 @@ class RobertaTokenizer(GPT2Tokenizer):
unk_token="<unk>", unk_token="<unk>",
pad_token="<pad>", pad_token="<pad>",
mask_token="<mask>", mask_token="<mask>",
add_prefix_space=False,
**kwargs **kwargs
): ):
super().__init__( super().__init__(
...@@ -148,9 +148,17 @@ class RobertaTokenizer(GPT2Tokenizer): ...@@ -148,9 +148,17 @@ class RobertaTokenizer(GPT2Tokenizer):
cls_token=cls_token, cls_token=cls_token,
pad_token=pad_token, pad_token=pad_token,
mask_token=mask_token, mask_token=mask_token,
add_prefix_space=add_prefix_space,
**kwargs, **kwargs,
) )
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
if not isinstance(value, AddedToken):
value = AddedToken(value, lstrip=True)
self._mask_token = value
def build_inputs_with_special_tokens( def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]: ) -> List[int]:
...@@ -231,14 +239,11 @@ class RobertaTokenizer(GPT2Tokenizer): ...@@ -231,14 +239,11 @@ class RobertaTokenizer(GPT2Tokenizer):
return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs): def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
if "add_prefix_space" in kwargs: add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
add_prefix_space = kwargs["add_prefix_space"] if (is_pretokenized or add_prefix_space) and text:
else:
add_prefix_space = add_special_tokens
if add_prefix_space and len(text) > 0 and not text[0].isspace():
text = " " + text text = " " + text
return text return (text, kwargs)
class RobertaTokenizerFast(GPT2TokenizerFast): class RobertaTokenizerFast(GPT2TokenizerFast):
...@@ -300,7 +305,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast): ...@@ -300,7 +305,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
unk_token="<unk>", unk_token="<unk>",
pad_token="<pad>", pad_token="<pad>",
mask_token="<mask>", mask_token="<mask>",
add_prefix_space=True, add_prefix_space=False,
trim_offsets=True, trim_offsets=True,
**kwargs **kwargs
): ):
...@@ -327,15 +332,14 @@ class RobertaTokenizerFast(GPT2TokenizerFast): ...@@ -327,15 +332,14 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
trim_offsets=trim_offsets, trim_offsets=trim_offsets,
) )
self.backend_tokenizer.add_special_tokens([kwargs["mask_token"]]) self.sanitize_special_tokens() # This will add the necessary special tokens to the vocabulary if needed.
@PreTrainedTokenizer.mask_token.setter @PreTrainedTokenizer.mask_token.setter
def mask_token(self, value): def mask_token(self, value):
if not isinstance(value, AddedToken): if not isinstance(value, AddedToken):
value = AddedToken(value, lstrip=True) value = AddedToken(value, lstrip=True)
self._mask_token = str(value) self._mask_token = value
self._maybe_update_backend([value])
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
......
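Both the slow and the fast RoBERTa tokenizer now wrap `<mask>` in an `AddedToken` with `lstrip=True`, so the space in front of `<mask>` is absorbed by the token itself instead of becoming a stray `Ġ` token. A sketch of the resulting behavior; the expected token sequence is the one asserted in the updated fast-tokenizer test further down, and `<ent>` is a made-up token name:

```python
from tokenizers import AddedToken
from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

ids = tokenizer.encode("A, <mask> AllenNLP sentence.")
print(tokenizer.convert_ids_to_tokens(ids))
# ['<s>', 'A', ',', '<mask>', 'ĠAllen', 'N', 'LP', 'Ġsentence', '.', '</s>']

# The same mechanism is available for user-defined tokens, e.g. a hypothetical
# entity marker that should also swallow the whitespace on its left:
tokenizer.add_tokens([AddedToken("<ent>", lstrip=True)])
```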
...@@ -355,10 +355,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -355,10 +355,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else: else:
return symbols return symbols
def prepare_for_tokenization(self, text, **kwargs): def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
# add spaces before punctuation symbols as should be done in transfo-xl # add spaces before punctuation symbols as should be done in transfo-xl
add_space_before_punct_symbol = kwargs.pop("add_space_before_punct_symbol", False)
if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]: if add_space_before_punct_symbol:
text = self.punctuation_with_space_around_pattern.sub(r" ", text) text = self.punctuation_with_space_around_pattern.sub(r" ", text)
elif self.punction_without_space_before_pattern.search(text): elif self.punction_without_space_before_pattern.search(text):
# searches until the first occurrence of a punctuation symbol without surrounding spaces # searches until the first occurrence of a punctuation symbol without surrounding spaces
...@@ -366,7 +366,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -366,7 +366,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token" "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
) )
return text return (text, kwargs)
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer): class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
......
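`add_space_before_punct_symbol` is now popped explicitly from the kwargs and `prepare_for_tokenization` returns `(text, kwargs)` like the other slow tokenizers. A sketch of the encode-time flag; note that loading this pretrained tokenizer requires torch, as the commit message says:

```python
from transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

# Without the flag, words glued to punctuation ("world!") tend to end up as <unk>,
# and the tokenizer logs the hint quoted in the hunk above.
ids_default = tokenizer.encode("Hello world! How are you?")

# With the flag, spaces are inserted around punctuation before tokenization.
ids_spaced = tokenizer.encode(
    "Hello world! How are you?", add_space_before_punct_symbol=True
)
```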
...@@ -19,12 +19,14 @@ ...@@ -19,12 +19,14 @@
import itertools import itertools
import logging import logging
import re import re
import unicodedata
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from .file_utils import add_end_docstrings from .file_utils import add_end_docstrings
from .tokenization_utils_base import ( from .tokenization_utils_base import (
ENCODE_KWARGS_DOCSTRING, ENCODE_KWARGS_DOCSTRING,
ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
AddedToken,
BatchEncoding, BatchEncoding,
EncodedInput, EncodedInput,
EncodedInputPair, EncodedInputPair,
...@@ -42,6 +44,57 @@ from .tokenization_utils_base import ( ...@@ -42,6 +44,57 @@ from .tokenization_utils_base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def _is_end_of_word(text):
"""Checks whether the last character in text is one of a punctuation, control or whitespace character."""
last_char = text[-1]
return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
def _is_start_of_word(text):
"""Checks whether the first character in text is one of a punctuation, control or whitespace character."""
first_char = text[0]
return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
class PreTrainedTokenizer(PreTrainedTokenizerBase): class PreTrainedTokenizer(PreTrainedTokenizerBase):
""" Base class for all slow tokenizers. """ Base class for all slow tokenizers.
...@@ -104,7 +157,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -104,7 +157,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
super().__init__(**kwargs) super().__init__(**kwargs)
# Added tokens # Added tokens
self.added_tokens_encoder = {} self.added_tokens_encoder = {}
self.unique_added_tokens_encoder = set() self.unique_added_tokens_encoder = []
self.added_tokens_decoder = {} self.added_tokens_decoder = {}
@property @property
...@@ -124,7 +177,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -124,7 +177,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
""" Size of the full vocabulary with the added tokens """ """ Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder) return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens: Union[str, List[str]]) -> int: def add_tokens(self, new_tokens: Union[str, List[str]], special_token=False) -> int:
""" """
Add a list of new tokens to the tokenizer class. If the new tokens are not in the Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from length of the current vocabulary. vocabulary, they are added to it with indices starting from length of the current vocabulary.
...@@ -154,7 +207,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -154,7 +207,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
tokens_to_add = [] tokens_to_add = []
for token in new_tokens: for token in new_tokens:
assert isinstance(token, str) assert isinstance(token, (str, AddedToken))
if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens: if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
token = token.lower() token = token.lower()
if ( if (
...@@ -169,7 +222,9 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -169,7 +222,9 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add)) added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder) self.added_tokens_encoder.update(added_tok_encoder)
self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens)) self.unique_added_tokens_encoder = list(
set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
)
self.added_tokens_decoder.update(added_tok_decoder) self.added_tokens_decoder.update(added_tok_decoder)
return len(tokens_to_add) return len(tokens_to_add)
...@@ -204,24 +259,63 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -204,24 +259,63 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
text (:obj:`string`): The sequence to be encoded. text (:obj:`string`): The sequence to be encoded.
**kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method. **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
""" """
all_special_tokens = self.all_special_tokens # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
text = self.prepare_for_tokenization(text, **kwargs) all_special_tokens_extended = dict(
(str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
)
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
if kwargs:
logger.warning(f"Keyword arguments {kwargs} not recognized.")
# TODO: should this be in the base class? # TODO: should this be in the base class?
def lowercase_text(t): if self.init_kwargs.get("do_lower_case", False):
# convert non-special tokens to lowercase # convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens] escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
return re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), t) text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
if self.init_kwargs.get("do_lower_case", False):
text = lowercase_text(text)
def split_on_token(tok, text): def split_on_token(tok, text):
result = [] result = []
tok_extended = all_special_tokens_extended.get(tok, None)
split_text = text.split(tok) split_text = text.split(tok)
full_word = ""
for i, sub_text in enumerate(split_text): for i, sub_text in enumerate(split_text):
# AddedToken can control whitespace stripping around them.
# We use them for GPT2 and Roberta to have different behavior depending on the special token
# Cf. https://github.com/huggingface/transformers/pull/2778
# and https://github.com/huggingface/transformers/issues/3788
if isinstance(tok_extended, AddedToken):
if tok_extended.single_word:
# Try to avoid splitting on token
if (
i < len(split_text) - 1
and not _is_end_of_word(sub_text)
and not _is_start_of_word(split_text[i + 1])
):
# Don't extract the special token
full_word += sub_text + tok
elif full_word:
full_word += sub_text
result += [full_word]
full_word = ""
continue
# Strip white spaces on the right
if tok_extended.rstrip and i > 0:
# A bit counter-intuitive but we strip the left of the string
# since tok_extended.rstrip means the special token is eating all white spaces on its right
sub_text = sub_text.lstrip()
# Strip white spaces on the left
if tok_extended.lstrip and i < len(split_text) - 1:
sub_text = sub_text.rstrip() # Opposite here
else:
# We strip left and right by default
if i < len(split_text) - 1:
sub_text = sub_text.rstrip() sub_text = sub_text.rstrip()
if i > 0:
sub_text = sub_text.lstrip()
if i == 0 and not sub_text: if i == 0 and not sub_text:
result += [tok] result += [tok]
elif i == len(split_text) - 1: elif i == len(split_text) - 1:
...@@ -316,23 +410,17 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -316,23 +410,17 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False, return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False, return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True, verbose: bool = True,
**kwargs **kwargs
) -> BatchEncoding: ) -> BatchEncoding:
def get_input_ids(text): def get_input_ids(text):
if isinstance(text, str): if isinstance(text, str):
tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs) tokens = self.tokenize(text, **kwargs)
return self.convert_tokens_to_ids(tokens) return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
if is_pretokenized: if is_pretokenized:
tokens = list( tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
itertools.chain(
*(
self.tokenize(t, add_special_tokens=False, add_prefix_space=True, **kwargs)
for t in text
)
)
)
return self.convert_tokens_to_ids(tokens) return self.convert_tokens_to_ids(tokens)
else: else:
return self.convert_tokens_to_ids(text) return self.convert_tokens_to_ids(text)
...@@ -369,6 +457,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -369,6 +457,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_token_type_ids=return_token_type_ids, return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens, return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask, return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
verbose=verbose, verbose=verbose,
) )
...@@ -390,28 +479,21 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -390,28 +479,21 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
is_pretokenized: bool = False, is_pretokenized: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None, return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None, return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False, return_overflowing_tokens: bool = False,
return_special_tokens_masks: bool = False, return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_lengths: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
**kwargs **kwargs
) -> BatchEncoding: ) -> BatchEncoding:
def get_input_ids(text): def get_input_ids(text):
if isinstance(text, str): if isinstance(text, str):
tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs) tokens = self.tokenize(text, **kwargs)
return self.convert_tokens_to_ids(tokens) return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
if is_pretokenized: if is_pretokenized:
tokens = list( tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
itertools.chain(
*(
self.tokenize(t, add_special_tokens=False, add_prefix_space=True, **kwargs)
for t in text
)
)
)
return self.convert_tokens_to_ids(tokens) return self.convert_tokens_to_ids(tokens)
else: else:
return self.convert_tokens_to_ids(text) return self.convert_tokens_to_ids(text)
...@@ -449,11 +531,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -449,11 +531,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy, truncation_strategy=truncation_strategy,
max_length=max_length, max_length=max_length,
stride=stride, stride=stride,
return_attention_masks=return_attention_masks, return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids, return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens, return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_masks=return_special_tokens_masks, return_special_tokens_mask=return_special_tokens_mask,
return_lengths=return_lengths, return_length=return_length,
return_tensors=return_tensors, return_tensors=return_tensors,
verbose=verbose, verbose=verbose,
) )
...@@ -471,10 +553,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -471,10 +553,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
stride: int = 0, stride: int = 0,
return_tensors: Optional[str] = None, return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None, return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None, return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False, return_overflowing_tokens: bool = False,
return_special_tokens_masks: bool = False, return_special_tokens_mask: bool = False,
return_lengths: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
) -> BatchEncoding: ) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
...@@ -507,11 +589,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -507,11 +589,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy, truncation_strategy=truncation_strategy,
max_length=max_length, max_length=max_length,
stride=stride, stride=stride,
return_attention_mask=return_attention_masks, return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids, return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens, return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_masks, return_special_tokens_mask=return_special_tokens_mask,
return_lengths=return_lengths, return_length=return_length,
return_tensors=None, # We will convert the whole batch to tensors at the end return_tensors=None, # We will convert the whole batch to tensors at the end
prepend_batch_axis=False, prepend_batch_axis=False,
verbose=verbose, verbose=verbose,
...@@ -542,7 +624,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -542,7 +624,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_attention_mask: Optional[bool] = None, return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False, return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False, return_special_tokens_mask: bool = False,
return_lengths: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
) -> BatchEncoding: ) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
...@@ -615,7 +697,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -615,7 +697,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_attention_mask=return_attention_mask, return_attention_mask=return_attention_mask,
) )
if return_lengths: if return_length:
encoded_inputs["length"] = len(encoded_inputs["input_ids"]) encoded_inputs["length"] = len(encoded_inputs["input_ids"])
batch_outputs = BatchEncoding( batch_outputs = BatchEncoding(
...@@ -624,9 +706,13 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): ...@@ -624,9 +706,13 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return batch_outputs return batch_outputs
def prepare_for_tokenization(self, text: str, **kwargs) -> str: def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
""" Performs any necessary transformations before tokenization """ """ Performs any necessary transformations before tokenization.
return text
This method should pop the arguments from kwargs and return kwargs as well.
We test kwargs at the end of the encoding process to be sure all the arguments have been used.
"""
return (text, kwargs)
def truncate_sequences( def truncate_sequences(
self, self,
......
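Two of the behaviors changed in this file are easy to exercise: unknown keyword arguments reaching `tokenize()` now trigger a warning instead of being silently dropped, and `return_length=True` adds a `length` entry to the encoding. A hedged sketch; the warning text follows the f-string in the hunk above and `foo` is a deliberately unknown argument:

```python
import logging

from transformers import BertTokenizer

logging.basicConfig(level=logging.WARNING)  # make sure library warnings are visible
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# 'foo' is not consumed by prepare_for_tokenization, so tokenize() logs something like:
#   Keyword arguments {'foo': True} not recognized.
tokens = tokenizer.tokenize("Hello world", foo=True)

# return_length stores the length of input_ids in the output mapping.
enc = tokenizer.encode_plus("Hello world", return_length=True)
print(enc["length"], len(enc["input_ids"]))  # the two values match
```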
This diff is collapsed.
...@@ -21,12 +21,12 @@ import os ...@@ -21,12 +21,12 @@ import os
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from tokenizers import AddedToken as AddedTokenFast
from tokenizers import Encoding as EncodingFast from tokenizers import Encoding as EncodingFast
from tokenizers.decoders import Decoder as DecoderFast from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast
from .tokenization_utils_base import ( from .tokenization_utils_base import (
AddedToken,
BatchEncoding, BatchEncoding,
PaddingStrategy, PaddingStrategy,
PreTokenizedInput, PreTokenizedInput,
...@@ -134,11 +134,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -134,11 +134,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def decoder(self) -> DecoderFast: def decoder(self) -> DecoderFast:
return self._tokenizer._tokenizer.decoder return self._tokenizer._tokenizer.decoder
def _maybe_update_backend(self, value):
""" Update the backend fast tokenizer.
Override method from base class SpecialTokensMixin """
self._tokenizer.add_special_tokens(value)
def _convert_encoding( def _convert_encoding(
self, self,
encoding: EncodingFast, encoding: EncodingFast,
...@@ -147,6 +142,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -147,6 +142,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False, return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False, return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True, verbose: bool = True,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict. """ Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict.
...@@ -178,6 +174,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -178,6 +174,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
encoding_dict["special_tokens_mask"].append(e.special_tokens_mask) encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
if return_offsets_mapping: if return_offsets_mapping:
encoding_dict["offset_mapping"].append(e.offsets) encoding_dict["offset_mapping"].append(e.offsets)
if return_length:
encoding_dict["length"].append(len(e.ids))
return encoding_dict return encoding_dict
...@@ -208,14 +206,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -208,14 +206,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str: def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def add_tokens(self, new_tokens: List[Union[str, AddedTokenFast]]) -> int: def add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_token=False) -> int:
""" """
Add a list of new tokens to the tokenizer class. If the new tokens are not in the Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from length of the current vocabulary. vocabulary, they are added to it with indices starting from length of the current vocabulary.
Args: Args:
new_tokens: string or list of string or :class:`~transformers.AddedTokenFast`. Each string is a token to add. new_tokens: string or list of string or :class:`~transformers.AddedToken`. Each string is a token to add.
Tokens are only added if they are not already in the vocabulary. AddedTokenFast wraps a string token to Tokens are only added if they are not already in the vocabulary. AddedToken wraps a string token to
let you personalize its behavior (Whether this token should only match against single word, whether let you personalize its behavior (Whether this token should only match against single word, whether
this token should strip all potential whitespaces on the left side, Whether this token should strip this token should strip all potential whitespaces on the left side, Whether this token should strip
all potential whitespaces on the right side...). all potential whitespaces on the right side...).
...@@ -235,16 +233,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -235,16 +233,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
print('We have added', num_added_toks, 'tokens') print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
""" """
if isinstance(new_tokens, str): if not isinstance(new_tokens, (list, tuple)):
new_tokens = [new_tokens] new_tokens = [new_tokens]
# TODO This should be done in tokenizers to be really clean.
# Removing for now if special_token:
# tokens = [] return self._tokenizer.add_special_tokens(new_tokens)
# for token in new_tokens:
# if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
# token = token.lower()
# if token not in tokens:
# tokens.append(token)
return self._tokenizer.add_tokens(new_tokens) return self._tokenizer.add_tokens(new_tokens)
def num_special_tokens_to_add(self, pair: bool = False) -> int: def num_special_tokens_to_add(self, pair: bool = False) -> int:
...@@ -330,7 +324,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -330,7 +324,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False, return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False, return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_lengths: bool = False, return_length: bool = False,
verbose: bool = True, verbose: bool = True,
**kwargs **kwargs
) -> BatchEncoding: ) -> BatchEncoding:
...@@ -340,6 +334,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -340,6 +334,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs)) "batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
) )
if kwargs:
raise ValueError(f"Keyword arguments {kwargs} not recognized.")
# Set the truncation and padding strategy and restore the initial configuration # Set the truncation and padding strategy and restore the initial configuration
self.set_truncation_and_padding( self.set_truncation_and_padding(
padding_strategy=padding_strategy, padding_strategy=padding_strategy,
...@@ -381,6 +378,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -381,6 +378,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens=return_overflowing_tokens, return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask, return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping, return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose, verbose=verbose,
) )
for encoding in encodings for encoding in encodings
...@@ -419,6 +417,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -419,6 +417,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False, return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False, return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False, return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True, verbose: bool = True,
**kwargs **kwargs
) -> BatchEncoding: ) -> BatchEncoding:
...@@ -438,6 +437,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ...@@ -438,6 +437,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens=return_overflowing_tokens, return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask, return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping, return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose, verbose=verbose,
**kwargs, **kwargs,
) )
......
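For the fast tokenizers, `add_tokens` now accepts `AddedToken` objects directly and gains a `special_token` flag (named exactly as in this diff) that routes the tokens to the backend's `add_special_tokens`. An illustrative sketch; the token strings are made up and a model would still need `resize_token_embeddings` afterwards:

```python
from tokenizers import AddedToken
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Plain new tokens: strings and AddedToken objects can be mixed.
num_added = tokenizer.add_tokens(["new_tok", AddedToken("<ent>", lstrip=True)])
print(num_added)

# With special_token=True the tokens are registered as special tokens on the Rust
# backend, so they are never split and can be skipped when decoding.
tokenizer.add_tokens([AddedToken("<ctrl>", lstrip=True)], special_token=True)
```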
...@@ -390,7 +390,7 @@ class TokenizerTesterMixin: ...@@ -390,7 +390,7 @@ class TokenizerTesterMixin:
seq_1 = "With these inputs." seq_1 = "With these inputs."
sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False) sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False) attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
# Method is implemented (e.g. not GPT-2) # Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2: if len(attached_sequences) != 2:
...@@ -416,7 +416,7 @@ class TokenizerTesterMixin: ...@@ -416,7 +416,7 @@ class TokenizerTesterMixin:
stride=stride, stride=stride,
truncation="longest_first", truncation="longest_first",
return_overflowing_tokens=True, return_overflowing_tokens=True,
add_prefix_space=False, # add_prefix_space=False,
) )
# Overflowing tokens are handled quite differently in slow and fast tokenizers # Overflowing tokens are handled quite differently in slow and fast tokenizers
...@@ -468,7 +468,7 @@ class TokenizerTesterMixin: ...@@ -468,7 +468,7 @@ class TokenizerTesterMixin:
# We are not using the special tokens - a bit too hard to test all the tokenizers with this # We are not using the special tokens - a bit too hard to test all the tokenizers with this
# TODO try this again later # TODO try this again later
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=False, add_prefix_space=False) sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=False) # , add_prefix_space=False)
truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode( truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
seq_1, add_special_tokens=False seq_1, add_special_tokens=False
) )
...@@ -499,7 +499,7 @@ class TokenizerTesterMixin: ...@@ -499,7 +499,7 @@ class TokenizerTesterMixin:
stride=stride, stride=stride,
truncation="longest_first", truncation="longest_first",
return_overflowing_tokens=True, return_overflowing_tokens=True,
add_prefix_space=False, # add_prefix_space=False,
) )
# Overflowing tokens are handled quite differently in slow and fast tokenizers # Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast): if isinstance(tokenizer, PreTrainedTokenizerFast):
...@@ -531,7 +531,7 @@ class TokenizerTesterMixin: ...@@ -531,7 +531,7 @@ class TokenizerTesterMixin:
stride=stride, stride=stride,
truncation=True, truncation=True,
return_overflowing_tokens=True, return_overflowing_tokens=True,
add_prefix_space=False, # add_prefix_space=False,
) )
# Overflowing tokens are handled quite differently in slow and fast tokenizers # Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast): if isinstance(tokenizer, PreTrainedTokenizerFast):
...@@ -562,7 +562,7 @@ class TokenizerTesterMixin: ...@@ -562,7 +562,7 @@ class TokenizerTesterMixin:
stride=stride, stride=stride,
truncation="only_second", truncation="only_second",
return_overflowing_tokens=True, return_overflowing_tokens=True,
add_prefix_space=False, # add_prefix_space=False,
) )
# Overflowing tokens are handled quite differently in slow and fast tokenizers # Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast): if isinstance(tokenizer, PreTrainedTokenizerFast):
...@@ -638,7 +638,7 @@ class TokenizerTesterMixin: ...@@ -638,7 +638,7 @@ class TokenizerTesterMixin:
# Testing single inputs # Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus( encoded_sequence_dict = tokenizer.encode_plus(
sequence_0, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False sequence_0, add_special_tokens=True, return_special_tokens_mask=True # , add_prefix_space=False
) )
encoded_sequence_w_special = encoded_sequence_dict["input_ids"] encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
...@@ -660,7 +660,7 @@ class TokenizerTesterMixin: ...@@ -660,7 +660,7 @@ class TokenizerTesterMixin:
sequence_1, sequence_1,
add_special_tokens=True, add_special_tokens=True,
return_special_tokens_mask=True, return_special_tokens_mask=True,
add_prefix_space=False, # add_prefix_space=False,
) )
encoded_sequence_w_special = encoded_sequence_dict["input_ids"] encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
...@@ -1042,7 +1042,7 @@ class TokenizerTesterMixin: ...@@ -1042,7 +1042,7 @@ class TokenizerTesterMixin:
def test_pretokenized_inputs(self): def test_pretokenized_inputs(self):
# Test when inputs are pretokenized # Test when inputs are pretokenized
tokenizers = self.get_tokenizers(do_lower_case=False, add_prefix_space=True) tokenizers = self.get_tokenizers(do_lower_case=False) # , add_prefix_space=True)
for tokenizer in tokenizers: for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
......
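The test changes above drop the per-call `add_prefix_space` arguments and lean on the `is_pretokenized` flag instead (see the new `batch_kwargs` in the fast-tokenizer tests below). A short sketch of that flag; checkpoint and sentences are illustrative:

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

words = ["AllenNLP", "is", "a", "library", "."]

# Single pretokenized sequence: the input is already split into words.
enc = tokenizer.encode_plus(words, is_pretokenized=True, return_special_tokens_mask=True)

# Batch of pretokenized sequences.
batch = tokenizer.batch_encode_plus([words, words[:3]], is_pretokenized=True)
print(len(batch["input_ids"]))  # 2
```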
...@@ -63,6 +63,21 @@ class CommonFastTokenizerTest(unittest.TestCase): ...@@ -63,6 +63,21 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name) self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
self.fast_only(tokenizer_r) self.fast_only(tokenizer_r)
def test_pretokenized_tokenizers(self):
for tok_case in self.TOKENIZERS_CLASSES:
for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key].keys():
# Tokenizer.filter makes it possible to filter which Tokenizer to test based on all the
# information available in Tokenizer (name, rust class, python class, vocab key name)
if tok_case.filter is None or (
tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
):
with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, add_prefix_space=True)
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, add_prefix_space=True)
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name): def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
# Check is_fast is set correctly # Check is_fast is set correctly
self.assertFalse(tokenizer_p.is_fast) self.assertFalse(tokenizer_p.is_fast)
...@@ -75,7 +90,6 @@ class CommonFastTokenizerTest(unittest.TestCase): ...@@ -75,7 +90,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p) self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p) self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
self.assert_padding(tokenizer_r, tokenizer_p) self.assert_padding(tokenizer_r, tokenizer_p)
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p) self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
# TODO: enable for v3.0.0 # TODO: enable for v3.0.0
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p) # self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
...@@ -341,6 +355,14 @@ class CommonFastTokenizerTest(unittest.TestCase): ...@@ -341,6 +355,14 @@ class CommonFastTokenizerTest(unittest.TestCase):
"return_special_tokens_mask": True, "return_special_tokens_mask": True,
"return_offsets_mapping": False, # Not implemented in python tokenizers "return_offsets_mapping": False, # Not implemented in python tokenizers
} }
batch_kwargs = {
"is_pretokenized": True,
"return_token_type_ids": True,
"return_attention_mask": True, # we have an 's' here
"return_overflowing_tokens": False,
"return_special_tokens_mask": True, # we have an 's' here
"return_offsets_mapping": False, # Not implemented in python tokenizers
}
# Test encode_plus for pretokenized inputs # Test encode_plus for pretokenized inputs
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs) output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs) output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
...@@ -349,8 +371,8 @@ class CommonFastTokenizerTest(unittest.TestCase): ...@@ -349,8 +371,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
# Test batch_encode_plus for pretokenized inputs # Test batch_encode_plus for pretokenized inputs
input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair] input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
output_r = tokenizer_r.batch_encode_plus(input_batch, **kwargs) output_r = tokenizer_r.batch_encode_plus(input_batch, **batch_kwargs)
output_p = tokenizer_p.batch_encode_plus(input_batch, **kwargs) output_p = tokenizer_p.batch_encode_plus(input_batch, **batch_kwargs)
for key in output_p.keys(): for key in output_p.keys():
self.assertEqual(output_p[key], output_r[key]) self.assertEqual(output_p[key], output_r[key])
...@@ -370,8 +392,8 @@ class CommonFastTokenizerTest(unittest.TestCase): ...@@ -370,8 +392,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
pretokenized_input_simple + pretokenized_input_pair, pretokenized_input_simple + pretokenized_input_pair,
pretokenized_input_pair, pretokenized_input_pair,
] ]
output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **kwargs) output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **batch_kwargs)
output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **kwargs) output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **batch_kwargs)
for key in output_p.keys(): for key in output_p.keys():
self.assertEqual(output_p[key], output_r[key]) self.assertEqual(output_p[key], output_r[key])
...@@ -756,8 +778,8 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest): ...@@ -756,8 +778,8 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True) tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
# Rust correctly handles the space before the mask while Python doesn't # Rust correctly handles the space before the mask while Python doesn't
self.assertSequenceEqual(tokens_r["input_ids"], [0, 83, 6, 50264, 3823, 487, 21992, 3645, 4, 2]) self.assertSequenceEqual(tokens_r["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
self.assertSequenceEqual(tokens_p["input_ids"], [0, 83, 6, 50264, 3823, 487, 21992, 3645, 4, 2]) self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
# token_type_ids should put 0 everywhere # token_type_ids should put 0 everywhere
self.assertEquals(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"])) self.assertEquals(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
...@@ -768,9 +790,10 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest): ...@@ -768,9 +790,10 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]), sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
) )
# Rust should have 'Ġ' before <mask> which should be left as an entire token
tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"]) tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
self.assertSequenceEqual(tokens_r, ["<s>", "ĠA", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]) tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
self.assertSequenceEqual(tokens_r, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"])
self.assertSequenceEqual(tokens_p, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"])
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest): class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
...@@ -840,3 +863,7 @@ class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest): ...@@ -840,3 +863,7 @@ class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
@require_torch @require_torch
def test_all_tokenizers(self): def test_all_tokenizers(self):
super().test_all_tokenizers() super().test_all_tokenizers()
@require_torch
def test_pretokenized_tokenizers(self):
super().test_pretokenized_tokenizers()
...@@ -80,12 +80,12 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -80,12 +80,12 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer" text = "lower newer"
bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text, add_prefix_space=True) tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
self.assertListEqual(tokens, bpe_tokens) self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def roberta_dict_integration_testing(self): def roberta_dict_integration_testing(self):
...@@ -124,7 +124,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -124,7 +124,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
space_encoding = tokenizer.byte_encoder[" ".encode("utf-8")[0]] space_encoding = tokenizer.byte_encoder[" ".encode("utf-8")[0]]
# Testing encoder arguments # Testing encoder arguments
encoded = tokenizer.encode(sequence, add_special_tokens=False) encoded = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=False)
first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0] first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
self.assertNotEqual(first_char, space_encoding) self.assertNotEqual(first_char, space_encoding)
...@@ -135,7 +135,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -135,7 +135,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.add_special_tokens({"bos_token": "<s>"}) tokenizer.add_special_tokens({"bos_token": "<s>"})
encoded = tokenizer.encode(sequence, add_special_tokens=True) encoded = tokenizer.encode(sequence, add_special_tokens=True)
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0] first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
self.assertEqual(first_char, space_encoding) self.assertNotEqual(first_char, space_encoding)
# Testing spaces after special tokens # Testing spaces after special tokens
mask = "<mask>" mask = "<mask>"
......