Unverified commit 7ac91107, authored by Thomas Wolf, committed by GitHub

Add more tests on tokenizers serialization - fix bugs (#5056)

* update tests for fast tokenizers + fix small bug in saving/loading

* better tests on serialization

* fixing serialization

* comment cleanup
parent 0148c262
@@ -20,7 +20,7 @@ import itertools
 import logging
 import re
 import unicodedata
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from .file_utils import add_end_docstrings
 from .tokenization_utils_base import (
@@ -155,10 +155,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-        # Added tokens
-        self.added_tokens_encoder = {}
-        self.unique_added_tokens_encoder = []
-        self.added_tokens_decoder = {}
+        # Added tokens - We store this for both slow and fast tokenizers
+        # until the serialization of Fast tokenizers is updated
+        self.added_tokens_encoder: Dict[str, int] = {}
+        self.added_tokens_decoder: Dict[int, str] = {}
+        self.unique_no_split_tokens: List[str] = []
 
     @property
     def is_fast(self) -> bool:
@@ -173,11 +175,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
         raise NotImplementedError()
 
+    def get_added_vocab(self) -> Dict[str, int]:
+        return self.added_tokens_encoder
+
     def __len__(self):
         """ Size of the full vocabulary with the added tokens """
         return self.vocab_size + len(self.added_tokens_encoder)
 
-    def add_tokens(self, new_tokens: Union[str, List[str]], special_token=False) -> int:
+    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
         """
         Add a list of new tokens to the tokenizer class. If the new tokens are not in the
         vocabulary, they are added to it with indices starting from length of the current vocabulary.
@@ -199,16 +204,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        """
-        if not new_tokens:
-            return 0
-
-        if not isinstance(new_tokens, list):
-            new_tokens = [new_tokens]
+        new_tokens = [str(tok) for tok in new_tokens]
 
         tokens_to_add = []
         for token in new_tokens:
-            assert isinstance(token, (str, AddedToken))
-            if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
+            assert isinstance(token, str)
+            if not special_tokens and self.init_kwargs.get("do_lower_case", False):
                 token = token.lower()
             if (
                 token != self.unk_token
@@ -222,11 +223,15 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
         added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
-        self.unique_added_tokens_encoder = list(
-            set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
-        )
         self.added_tokens_decoder.update(added_tok_decoder)
 
+        # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
+        if special_tokens:
+            self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
+        else:
+            # Or on the newly added tokens
+            self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
+
         return len(tokens_to_add)
 
     def num_special_tokens_to_add(self, pair=False):
@@ -340,7 +345,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             for tok in tok_list:
                 tokenized_text = []
                 for sub_text in text_list:
-                    if sub_text not in self.unique_added_tokens_encoder:
+                    if sub_text not in self.unique_no_split_tokens:
                         tokenized_text += split_on_token(tok, sub_text)
                     else:
                         tokenized_text += [sub_text]
@@ -349,14 +354,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             return list(
                 itertools.chain.from_iterable(
                     (
-                        self._tokenize(token) if token not in self.unique_added_tokens_encoder else [token]
+                        self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
                         for token in tokenized_text
                     )
                 )
            )

-        added_tokens = self.unique_added_tokens_encoder
-        tokenized_text = split_on_tokens(added_tokens, text)
+        no_split_token = self.unique_no_split_tokens
+        tokenized_text = split_on_tokens(no_split_token, text)
        return tokenized_text

    def _tokenize(self, text, **kwargs):
...
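
Illustration (not part of the diff): the hunks above rename `unique_added_tokens_encoder` to `unique_no_split_tokens` and route additions through the new private `_add_tokens`, which only lower-cases non-special tokens and protects every added token from being split. A minimal usage sketch of that behavior, assuming the `bert-base-uncased` checkpoint is available (token strings are made up for illustration):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Regular added token: lower-cased (do_lower_case=True for this checkpoint) and
    # registered in unique_no_split_tokens so tokenize() never splits it.
    tokenizer.add_tokens(["new_tok"])

    # Special token: casing is preserved and it is also protected from splitting.
    tokenizer.add_tokens(["[NEW_SPECIAL]"], special_tokens=True)

    print(tokenizer.tokenize("hello new_tok [NEW_SPECIAL] world"))
    # roughly: ['hello', 'new_tok', '[NEW_SPECIAL]', 'world']
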
@@ -62,9 +62,12 @@ PreTokenizedInputPair = Tuple[List[str], List[str]]
 EncodedInputPair = Tuple[List[int], List[int]]
 
+# Slow tokenizers used to be saved in three separate files
 SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
 ADDED_TOKENS_FILE = "added_tokens.json"
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+# Fast tokenizers (provided by HuggingFace's tokenizers library) can be saved in a single file
 FULL_TOKENIZER_FILE = "tokenizer.json"
 
@@ -589,10 +592,14 @@ class SpecialTokensMixin:
         self._additional_special_tokens = []
         self.verbose = verbose
 
+        # We directly set the hidden value to allow initialization with special tokens
+        # which are not yet in the vocabulary. Necessary for serialization/de-serialization
+        # TODO clean this up at some point (probably by switching to fast tokenizers)
         for key, value in kwargs.items():
             if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                 if key == "additional_special_tokens":
                     assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+                    setattr(self, key, value)
                 elif isinstance(value, (str, AddedToken)):
                     setattr(self, key, value)
                 else:
@@ -607,7 +614,7 @@ class SpecialTokensMixin:
         Return:
             Number of tokens added in the vocabulary during the operation.
         """
-        return self.add_tokens(self.all_special_tokens_extended, special_token=True)
+        return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
 
     def add_special_tokens(self, special_tokens_dict):
         """
@@ -652,22 +659,56 @@ class SpecialTokensMixin:
         added_tokens = 0
         for key, value in special_tokens_dict.items():
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES
             if self.verbose:
                 logger.info("Assigning %s to the %s key of the tokenizer", value, key)
             setattr(self, key, value)
 
             if key == "additional_special_tokens":
                 assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
-                added_tokens += self.add_tokens(value, special_token=True)
+                added_tokens += self.add_tokens(value, special_tokens=True)
             else:
                 assert isinstance(value, str)
-                added_tokens += self.add_tokens([value], special_token=True)
+                added_tokens += self.add_tokens([value], special_tokens=True)
 
         return added_tokens
 
-    def add_tokens(self, value, special_token=False):
-        """ To be overriden by derived class to add a token in the vocabulary. """
-        pass
+    def add_tokens(self, new_tokens: Union[str, AddedToken, List[str], List[AddedToken]], special_tokens=False) -> int:
+        """
+        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
+        vocabulary, they are added to it with indices starting from length of the current vocabulary.
+
+        Args:
+            new_tokens: string or list of string or :class:`~transformers.AddedToken`. Each string is a token to add.
+                Tokens are only added if they are not already in the vocabulary. AddedToken wraps a string token to
+                let you personalize its behavior (whether this token should only match against a single word, whether
+                this token should strip all potential whitespaces on the left side, whether this token should strip
+                all potential whitespaces on the right side...).
+            special_tokens: can be used to specify if the token is a special token. This mostly changes the normalization
+                behavior (special tokens like CLS or [MASK] are usually not lower-cased for instance).
+
+                See details for :class:`~transformers.AddedToken` in the HuggingFace tokenizers library.
+
+        Returns:
+            Number of tokens added to the vocabulary.
+
+        Examples::
+
+            # Let's see how to increase the vocabulary of Bert model and tokenizer
+            tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+            model = BertModel.from_pretrained('bert-base-uncased')
+
+            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
+            print('We have added', num_added_toks, 'tokens')
+            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+        """
+        if not new_tokens:
+            return 0
+
+        if not isinstance(new_tokens, (list, tuple)):
+            new_tokens = [new_tokens]
+
+        return self._add_tokens(new_tokens, special_tokens=special_tokens)
 
     @property
     def bos_token(self):
@@ -964,11 +1005,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
     padding_side: str = "right"
 
-    def __init__(self, model_max_length=None, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, **kwargs):
+        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+        self.init_inputs = ()
+        self.init_kwargs = kwargs
 
         # For backward compatibility we fallback to set model_max_length from max_len if provided
-        model_max_length = model_max_length if model_max_length is not None else kwargs.pop("max_len", None)
+        model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
         self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
 
         # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
@@ -979,9 +1022,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         ], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
         self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
 
-        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
-        self.init_inputs = ()
-        self.init_kwargs = {}
+        super().__init__(**kwargs)
 
     @property
     def max_len(self) -> int:
@@ -1125,8 +1166,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             "added_tokens_file": ADDED_TOKENS_FILE,
             "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
             "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
+            "full_tokenizer_file": FULL_TOKENIZER_FILE,
         }
-        # Look for the tokenizer main vocabulary files + the additional tokens files
+        # Look for the tokenizer files
         for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
             if os.path.isdir(pretrained_model_name_or_path):
                 full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
@@ -1215,18 +1257,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         # Merge resolved_vocab_files arguments in init_kwargs.
         added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
-        special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
         for args_name, file_path in resolved_vocab_files.items():
             if args_name not in init_kwargs:
                 init_kwargs[args_name] = file_path
-        if special_tokens_map_file is not None:
-            with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
-                special_tokens_map = json.load(special_tokens_map_handle)
-            for key, value in special_tokens_map.items():
-                if isinstance(value, dict):
-                    value = AddedToken(**value)
-                if key not in init_kwargs:
-                    init_kwargs[key] = value
 
         # Instantiate tokenizer.
         try:
@@ -1241,20 +1274,39 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         tokenizer.init_inputs = init_inputs
         tokenizer.init_kwargs = init_kwargs
 
-        # update unique_added_tokens_encoder with special tokens for correct tokenization
-        if hasattr(tokenizer, "unique_added_tokens_encoder"):
-            union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.all_special_tokens)
-            tokenizer.unique_added_tokens_encoder = list(union)
+        # If there is a complementary special token map, load it
+        special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
+        if special_tokens_map_file is not None:
+            with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
+                special_tokens_map = json.load(special_tokens_map_handle)
+
+            for key, value in special_tokens_map.items():
+                if isinstance(value, dict):
+                    value = AddedToken(**value)
+                setattr(tokenizer, key, value)
 
         # Add supplementary tokens.
+        special_tokens = tokenizer.all_special_tokens
         if added_tokens_file is not None:
             with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
                 added_tok_encoder = json.load(added_tokens_handle)
-            added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
-            tokenizer.added_tokens_encoder.update(added_tok_encoder)
-            tokenizer.added_tokens_decoder.update(added_tok_decoder)
-            union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.added_tokens_encoder.keys())
-            tokenizer.unique_added_tokens_encoder = list(union)
+
+            # Sort added tokens by index
+            added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
+
+            for token, index in added_tok_encoder_sorted:
+                assert index == len(tokenizer), (
+                    f"Non-consecutive added token '{token}' found. "
+                    f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
+                )
+                tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
+
+        # Check all our special tokens are registered as "no split" tokens (we don't cut them) and are in the vocab
+        added_tokens = tokenizer.sanitize_special_tokens()
+        if added_tokens:
+            logger.warning(
+                "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
+            )
 
         return tokenizer
 
@@ -1296,9 +1348,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                     write_dict[key] = value
             f.write(json.dumps(write_dict, ensure_ascii=False))
 
-        if hasattr(self, "added_tokens_encoder") and len(self.added_tokens_encoder) > 0:
+        added_vocab = self.get_added_vocab()
+        if added_vocab:
             with open(added_tokens_file, "w", encoding="utf-8") as f:
-                out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
+                out_str = json.dumps(added_vocab, ensure_ascii=False)
                 f.write(out_str)
 
         vocab_files = self.save_vocabulary(save_directory)
...
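
For reference, the loading step added above applies `special_tokens_map.json` via `setattr`, then re-adds the entries of `added_tokens.json` in index order, asserting that each saved index is consecutive with the current vocabulary length, and finally calls `sanitize_special_tokens()` so every special token is registered as a no-split token. A hedged sketch of that replay logic; the helper name `replay_added_tokens` and its path argument are illustrative, not part of the library:

    import json

    def replay_added_tokens(tokenizer, added_tokens_path):
        with open(added_tokens_path, encoding="utf-8") as f:
            added_tok_encoder = json.load(f)  # e.g. {"bim": 30522, "bambam": 30523}

        special_tokens = tokenizer.all_special_tokens
        for token, index in sorted(added_tok_encoder.items(), key=lambda x: x[1]):
            # Saved indices must be consecutive with the current length so each token
            # gets back exactly the id it had before saving.
            assert index == len(tokenizer), f"Non-consecutive added token '{token}'."
            tokenizer.add_tokens(token, special_tokens=token in special_tokens)
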
@@ -123,6 +123,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def get_vocab(self) -> Dict[str, int]:
         return self._tokenizer.get_vocab(with_added_tokens=True)
 
+    def get_added_vocab(self) -> Dict[str, int]:
+        base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
+        full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
+        added_vocab = dict((tok, index) for tok, index in full_vocab.items() if tok not in base_vocab)
+        return added_vocab
+
     def __len__(self) -> int:
         return self._tokenizer.get_vocab_size(with_added_tokens=True)
 
@@ -206,37 +212,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
         return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
 
-    def add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_token=False) -> int:
-        """
-        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
-        vocabulary, they are added to it with indices starting from length of the current vocabulary.
-
-        Args:
-            new_tokens: string or list of string or :class:`~transformers.AddedToken`. Each string is a token to add.
-                Tokens are only added if they are not already in the vocabulary. AddedToken wrap a string token to
-                let you personnalize it's behavior (Whether this token should only match against single word, whether
-                this token should strip all potential whitespaces on the left side, Whether this token should strip
-                all potential whitespaces on the right side...).
-
-                See details for :class:`~transformers.AddedToken` in HuggingFace tokenizers library.
-
-        Returns:
-            Number of tokens added to the vocabulary.
-
-        Examples::
-
-            # Let's see how to increase the vocabulary of Bert model and tokenizer
-            tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-            model = BertModel.from_pretrained('bert-base-uncased')
-
-            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
-            print('We have added', num_added_toks, 'tokens')
-            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
-        """
-        if not isinstance(new_tokens, (list, tuple)):
-            new_tokens = [new_tokens]
-
-        if special_token:
+    def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
+        if special_tokens:
             return self._tokenizer.add_special_tokens(new_tokens)
 
         return self._tokenizer.add_tokens(new_tokens)
...
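
The new `get_added_vocab` above recovers the added vocabulary by diffing the backend vocab with and without added tokens. The same idea, sketched directly against the HuggingFace `tokenizers` backend; the local "tokenizer.json" path is assumed for illustration:

    from tokenizers import Tokenizer

    backend = Tokenizer.from_file("tokenizer.json")  # the FULL_TOKENIZER_FILE above
    backend.add_tokens(["bim", "bambam"])

    base_vocab = backend.get_vocab(with_added_tokens=False)
    full_vocab = backend.get_vocab(with_added_tokens=True)

    # Same computation as PreTrainedTokenizerFast.get_added_vocab() in the diff above
    added_vocab = {tok: idx for tok, idx in full_vocab.items() if tok not in base_vocab}
    print(added_vocab)  # e.g. {'bim': 30522, 'bambam': 30523} for a BERT-sized vocab
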
@@ -20,10 +20,10 @@ import re
 import shutil
 import tempfile
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 from tests.utils import require_tf, require_torch
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
 
 if TYPE_CHECKING:
@@ -93,7 +93,7 @@ class TokenizerTesterMixin:
         output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
         return output_txt, output_ids
 
-    def get_tokenizers(self, fast=True, **kwargs) -> PreTrainedTokenizer:
+    def get_tokenizers(self, fast=True, **kwargs) -> List[PreTrainedTokenizerBase]:
         if fast and self.test_rust_tokenizer:
             return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
         return [self.get_tokenizer(**kwargs)]
@@ -101,7 +101,7 @@ class TokenizerTesterMixin:
     def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
         return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
 
-    def get_rust_tokenizer(self, **kwargs):
+    def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
         raise NotImplementedError
 
     # def get_input_output_texts(self) -> Tuple[str, str]:
@@ -156,28 +156,62 @@ class TokenizerTesterMixin:
     def test_save_and_load_tokenizer(self):
         # safety check on max_len default value so we are sure the test works
-        tokenizers = self.get_tokenizers(fast=False)
+        tokenizers = self.get_tokenizers()
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
                 self.assertNotEqual(tokenizer.max_len, 42)
 
         # Now let's start the test
-        tokenizers = self.get_tokenizers(fast=False, model_max_length=42)
+        tokenizers = self.get_tokenizers()
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
-                sample_text = "He is very happy, UNwant\u00E9d,running"
+                # Isolate this from the other tests because we save additional tokens/etc
+                tmpdirname = tempfile.mkdtemp()
+
+                sample_text = " He is very happy, UNwant\u00E9d,running"
                 before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
+                before_vocab = tokenizer.get_vocab()
+                tokenizer.save_pretrained(tmpdirname)
+
+                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
+                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
+                after_vocab = after_tokenizer.get_vocab()
+                self.assertListEqual(before_tokens, after_tokens)
+                self.assertDictEqual(before_vocab, after_vocab)
+
+                shutil.rmtree(tmpdirname)
 
-                tokenizer.save_pretrained(self.tmpdirname)
-                tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
+        # Now let's start the test
+        tokenizers = self.get_tokenizers(model_max_length=42)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                # Isolate this from the other tests because we save additional tokens/etc
+                tmpdirname = tempfile.mkdtemp()
+
+                sample_text = " He is very happy, UNwant\u00E9d,running"
+                tokenizer.add_tokens(["bim", "bambam"])
+                additional_special_tokens = tokenizer.additional_special_tokens
+                additional_special_tokens.append("new_additional_special_token")
+                tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+                before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
+                before_vocab = tokenizer.get_vocab()
+                tokenizer.save_pretrained(tmpdirname)
 
-                after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
+                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
+                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
+                after_vocab = after_tokenizer.get_vocab()
                 self.assertListEqual(before_tokens, after_tokens)
+                self.assertDictEqual(before_vocab, after_vocab)
+                self.assertIn("bim", after_vocab)
+                self.assertIn("bambam", after_vocab)
+                self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
+                self.assertEqual(after_tokenizer.model_max_length, 42)
 
-                self.assertEqual(tokenizer.model_max_length, 42)
-                tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname, model_max_length=43)
+                tokenizer = tokenizer.__class__.from_pretrained(tmpdirname, model_max_length=43)
                 self.assertEqual(tokenizer.model_max_length, 43)
+
+                shutil.rmtree(tmpdirname)
 
     def test_pickle_tokenizer(self):
         """Google pickle __getstate__ __setstate__ if you are struggling with this."""
         tokenizers = self.get_tokenizers()
@@ -265,7 +299,10 @@ class TokenizerTesterMixin:
         all_size = len(tokenizer)
 
         self.assertNotEqual(vocab_size, 0)
-        self.assertEqual(vocab_size, all_size)
+
+        # We usually have added tokens from the start in tests because our vocab fixtures are
+        # smaller than the original vocabs - let's not assert this
+        # self.assertEqual(vocab_size, all_size)
 
         new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
         added_toks = tokenizer.add_tokens(new_toks)
...
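
A condensed, standalone version of what the second branch of `test_save_and_load_tokenizer` checks for a fast tokenizer; this is a sketch runnable outside the test mixin, assuming access to the `bert-base-uncased` checkpoint:

    import tempfile

    from transformers import BertTokenizerFast

    tok = BertTokenizerFast.from_pretrained("bert-base-uncased", model_max_length=42)
    tok.add_tokens(["bim", "bambam"])
    tok.add_special_tokens({"additional_special_tokens": ["new_additional_special_token"]})

    with tempfile.TemporaryDirectory() as tmpdir:
        tok.save_pretrained(tmpdir)
        reloaded = BertTokenizerFast.from_pretrained(tmpdir)

    # Added tokens, special tokens and the configured model_max_length survive the round trip.
    assert "bim" in reloaded.get_vocab() and "bambam" in reloaded.get_vocab()
    assert "new_additional_special_token" in reloaded.additional_special_tokens
    assert reloaded.model_max_length == 42
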