Unverified Commit dc42e770 authored by Sylvain Gugger, committed by GitHub

Easily train a new fast tokenizer from a given one (#12361)



* [WIP] Easily train a new fast tokenizer from a given one

* Fix test

* Roll out to other tokenizers and add tests

* Fix bug with unk id and add emoji to test

* Really use something different in test

* Implement special tokens map

* Map special tokens in the Transformers tokenizers

* Fix test

* Make test more robust

* Fix test for BPE

* More robust map and test

Co-authored-by: SaulLu

* Test file

* Stronger tests
Co-authored-by: SaulLu <lucilesaul.com@gmail.com>

* Map unk token for Wordpiece and address review comment

* Fix lowercase test and address review comment

* Fix all tests

* Simplify test

* Fix tests for realsies

* Easily train a new fast tokenizer from a given one - tackle the special tokens format (str or AddedToken) (#12420)

* Propose change in tests regarding lower case

* add new test for special tokens types

* put back the test part about decoding

* add feature: the AddedToken is re-build with the different mapped content

* Address review comment: simplify AddedToken building
Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* Update src/transformers/tokenization_utils_fast.py
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: SaulLu <lucilesaul.com@gmail.com>
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
parent b440b8d1
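For context before the diff: a minimal usage sketch of the API this commit introduces. The checkpoint name, corpus, and output directory below are illustrative placeholders, not part of the change itself.

from transformers import AutoTokenizer

# Start from any existing fast tokenizer; "bert-base-cased" is only an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# The training corpus is an iterator of batches of texts,
# for instance a list of lists of strings kept in memory.
corpus = [
    ["This is the first sentence.", "This is the second one."],
    ["A third sentence to enrich the new vocabulary."],
]

# Train a tokenizer with the same pipeline and special tokens as the original,
# but with a freshly learned vocabulary of (at most) the requested size.
new_tokenizer = tokenizer.train_new_from_iterator(corpus, vocab_size=1000)
new_tokenizer.save_pretrained("my-new-tokenizer")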
@@ -104,7 +104,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
    def __init__(
        self,
-       vocab_file,
+       vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
...
@@ -117,7 +117,7 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
    def __init__(
        self,
-       vocab_file,
+       vocab_file=None,
        tokenizer_file=None,
        bos_token="<s>",
        eos_token="</s>",
...
@@ -124,7 +124,7 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
    def __init__(
        self,
-       vocab_file,
+       vocab_file=None,
        tokenizer_file=None,
        do_lower_case=False,
        remove_space=True,
...
@@ -16,7 +16,6 @@
 Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
 see tokenization_utils.py
"""
import json
import os
from collections import defaultdict
@@ -25,6 +24,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from tokenizers import Encoding as EncodingFast
from tokenizers import Tokenizer as TokenizerFast
from tokenizers.decoders import Decoder as DecoderFast
+from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer

from .convert_slow_tokenizer import convert_slow_tokenizer
from .file_utils import PaddingStrategy, add_end_docstrings
@@ -36,6 +36,7 @@ from .tokenization_utils_base import (
    PreTokenizedInput,
    PreTokenizedInputPair,
    PreTrainedTokenizerBase,
+   SpecialTokensMixin,
    TextInput,
    TextInputPair,
    TruncationStrategy,
@@ -60,6 +61,13 @@ INIT_TOKENIZER_DOCSTRING += """
            from 🤗 tokenizers <../fast_tokenizers>` for more information.
"""

+MODEL_TO_TRAINER_MAPPING = {
+    "BPE": BpeTrainer,
+    "Unigram": UnigramTrainer,
+    "WordLevel": WordLevelTrainer,
+    "WordPiece": WordPieceTrainer,
+}


@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
@@ -555,3 +563,162 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
            file_names = file_names + (tokenizer_file,)

        return file_names

    def train_new_from_iterator(
        self, text_iterator, vocab_size, new_special_tokens=None, special_tokens_map=None, **kwargs
    ):
        """
        Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization
        pipeline) as the current one.

        Args:
            text_iterator (generator of :obj:`List[str]`):
                The training corpus. Should be a generator of batches of texts, for instance a list of lists of
                texts if you have everything in memory.
            vocab_size (:obj:`int`):
                The size of the vocabulary you want for your tokenizer.
            new_special_tokens (list of :obj:`str` or :obj:`AddedToken`, `optional`):
                A list of new special tokens to add to the tokenizer you are training.
            special_tokens_map (:obj:`Dict[str, str]`, `optional`):
                If you want to rename some of the special tokens this tokenizer uses, pass a mapping from old
                special token name to new special token name in this argument.
            kwargs:
                Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library.

        Returns:
            :class:`~transformers.PreTrainedTokenizerFast`: A new tokenizer of the same type as the original one,
            trained on :obj:`text_iterator`.
        """
        tokenizer_json = json.loads(self._tokenizer.to_str())
        # Remove added tokens for now (uses IDs of tokens)
        added_tokens = tokenizer_json.pop("added_tokens")
        # Remove post processor for now (uses IDs of tokens)
        post_processor = tokenizer_json.pop("post_processor")

        unk_token = None
        # Remove vocab
        if tokenizer_json["model"]["type"] == "BPE":
            tokenizer_json["model"]["vocab"] = {}
            tokenizer_json["model"]["merges"] = []
        elif tokenizer_json["model"]["type"] == "Unigram":
            if tokenizer_json["model"]["unk_id"] is not None:
                unk_id = tokenizer_json["model"]["unk_id"]
                unk_token = tokenizer_json["model"]["vocab"][unk_id][0]
                if special_tokens_map is not None and unk_token in special_tokens_map:
                    unk_token = special_tokens_map[unk_token]
                tokenizer_json["model"]["unk_id"] = 0
                tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]]
        elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]:
            tokenizer_json["model"]["vocab"] = {}
        else:
            raise ValueError(
                f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}); "
                "only BPE, Unigram, WordLevel and WordPiece are supported."
            )

        if (
            special_tokens_map is not None
            and "unk_token" in tokenizer_json["model"]
            and tokenizer_json["model"]["unk_token"] in special_tokens_map
        ):
            tokenizer_json["model"]["unk_token"] = special_tokens_map[tokenizer_json["model"]["unk_token"]]

        tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
        # Get the special tokens from the current tokenizer if none are specified.
        special_tokens = []
        for added_token in added_tokens:
            special = added_token.pop("special", None)
            _ = added_token.pop("id", None)
            if tokenizer_json["model"]["type"] != "Unigram" and not special:
                continue
            if special_tokens_map is not None and added_token["content"] in special_tokens_map:
                added_token["content"] = special_tokens_map[added_token["content"]]
            special_tokens.append(AddedToken(**added_token))

        if new_special_tokens is not None:
            special_tokens.extend(new_special_tokens)

        # The trainer needs to know the continuing-subword prefix and end-of-word suffix for BPE.
        if (
            tokenizer_json["model"]["type"] == "BPE"
            and "continuing_subword_prefix" not in kwargs
            and tokenizer_json["model"]["continuing_subword_prefix"] is not None
        ):
            kwargs["continuing_subword_prefix"] = tokenizer_json["model"]["continuing_subword_prefix"]
        if (
            tokenizer_json["model"]["type"] == "BPE"
            and "end_of_word_suffix" not in kwargs
            and tokenizer_json["model"]["end_of_word_suffix"] is not None
        ):
            kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]

        trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
        trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
        tokenizer.train_from_iterator(text_iterator, trainer=trainer)

        if unk_token is not None:
            # For Unigram tokenizers we need to set back the unk id of the model (bug in Tokenizers?)
            trained_tokenizer_json = json.loads(tokenizer.to_str())
            vocab = trained_tokenizer_json["model"]["vocab"]
            unk_id = 0
            while unk_id < len(vocab) and vocab[unk_id][0] != unk_token:
                unk_id += 1
            if unk_id < len(vocab):
                trained_tokenizer_json["model"]["unk_id"] = unk_id
            tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))

        if post_processor is not None:
            trained_tokenizer_json = json.loads(tokenizer.to_str())
            # Almost done, we just have to adjust the token IDs in the post processor
            if "special_tokens" in post_processor:
                for key in post_processor["special_tokens"]:
                    tokens = post_processor["special_tokens"][key]["tokens"]
                    if special_tokens_map is not None:
                        tokens = [special_tokens_map.get(token, token) for token in tokens]
                    post_processor["special_tokens"][key]["tokens"] = tokens
                    post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens]

            for special_token in ["cls", "sep"]:
                if special_token in post_processor:
                    token, _ = post_processor[special_token]
                    if special_tokens_map is not None and token in special_tokens_map:
                        token = special_tokens_map[token]
                    token_id = tokenizer.token_to_id(token)
                    post_processor[special_token] = [token, token_id]

            trained_tokenizer_json["post_processor"] = post_processor
            tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
        kwargs = self.init_kwargs.copy()
        # Map pad/cls/mask token at the Transformers level
        special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
        special_tokens_list.remove("additional_special_tokens")
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
            if getattr(self, f"_{token}") is not None:
                special_token = getattr(self, token)
                if special_tokens_map is not None and special_token in special_tokens_map:
                    special_token = special_tokens_map[special_token]

                special_token_full = getattr(self, f"_{token}")
                if isinstance(special_token_full, AddedToken):
                    # Create an added token with the same parameters except the content
                    kwargs[token] = AddedToken(
                        special_token,
                        single_word=special_token_full.single_word,
                        lstrip=special_token_full.lstrip,
                        rstrip=special_token_full.rstrip,
                        normalized=special_token_full.normalized,
                    )
                else:
                    kwargs[token] = special_token

        additional_special_tokens = self.additional_special_tokens
        if new_special_tokens is not None:
            additional_special_tokens.extend(new_special_tokens)
        if len(additional_special_tokens) > 0:
            kwargs["additional_special_tokens"] = additional_special_tokens

        return self.__class__(tokenizer_object=tokenizer, **kwargs)
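The special_tokens_map argument above is exercised by the tests further down; as a hedged sketch (the checkpoint name and corpus are placeholders), renaming the classification token while retraining looks like this:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # placeholder fast tokenizer
corpus = [["Some text to train on.", "Another short batch of text."]]

# Rename the original [CLS] special token to "<cls>" in the retrained tokenizer.
new_tokenizer = tokenizer.train_new_from_iterator(
    corpus,
    vocab_size=500,
    special_tokens_map={tokenizer.cls_token: "<cls>"},
)

assert new_tokenizer.cls_token == "<cls>"
assert new_tokenizer.cls_token_id == new_tokenizer.get_vocab()["<cls>"]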
@@ -33,6 +33,7 @@ from transformers import (
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
+   SpecialTokensMixin,
    is_tf_available,
    is_torch_available,
)
@@ -57,6 +58,11 @@ if TYPE_CHECKING:
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]

+SMALL_TRAINING_CORPUS = [
+    ["This is the first sentence.", "This is the second one."],
+    ["This sentence (contains #) over symbols and numbers 12 3.", "But not this one."],
+]


def filter_non_english(_, pretrained_name: str):
    """Filter all the model for non-english language"""
@@ -390,7 +396,11 @@ class TokenizerTesterMixin:
        tokenizer = self.get_rust_tokenizer()

        for parameter_name, parameter in signature.parameters.items():
-           if parameter.default != inspect.Parameter.empty and parameter_name != "tokenizer_file":
+           if parameter.default != inspect.Parameter.empty and parameter_name not in [
+               "vocab_file",
+               "merges_file",
+               "tokenizer_file",
+           ]:
                self.assertIn(parameter_name, tokenizer.init_kwargs)

    def test_rust_and_python_full_tokenizers(self):
@@ -3144,6 +3154,146 @@ class TokenizerTesterMixin:
                self.assertTrue(special_token_id in p_output)
                self.assertTrue(special_token_id in cr_output)

    def test_training_new_tokenizer(self):
        # This feature only exists for fast tokenizers
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_rust_tokenizer()
        new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)

        # Test we can use the new tokenizer with something not seen during training
        inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
        self.assertEqual(len(inputs["input_ids"]), 2)
        decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
        expected_result = "This is the first sentence"

        # OpenAIGPT always lower-cases and has no do_lower_case arg.
        if new_tokenizer.init_kwargs.get("do_lower_case", False) or tokenizer.__class__.__name__.startswith(
            "OpenAIGPT"
        ):
            expected_result = expected_result.lower()
        self.assertEqual(expected_result, decoded_input)

        # We check that the parameters of the tokenizer remained the same
        # Check we have the same number of added_tokens for both pair and non-pair inputs.
        self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False))
        self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True))

        # Check we have the correct max_length for both pair and non-pair inputs.
        self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence)
        self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair)

        # Assert the set of special tokens match as we didn't ask to change them
        self.assertSequenceEqual(
            tokenizer.all_special_tokens_extended,
            new_tokenizer.all_special_tokens_extended,
        )

        self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
    def test_training_new_tokenizer_with_special_tokens_change(self):
        # This feature only exists for fast tokenizers
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_rust_tokenizer()

        # Test with a special tokens map
        class_signature = inspect.signature(tokenizer.__class__)
        if "cls_token" in class_signature.parameters:
            new_tokenizer = tokenizer.train_new_from_iterator(
                SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: "<cls>"}
            )
            cls_id = new_tokenizer.get_vocab()["<cls>"]
            self.assertEqual(new_tokenizer.cls_token, "<cls>")
            self.assertEqual(new_tokenizer.cls_token_id, cls_id)

        # Create a new mapping from the special tokens defined in the original tokenizer
        special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
        special_tokens_list.remove("additional_special_tokens")
        special_tokens_map = {}
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
            if getattr(tokenizer, f"_{token}") is not None:
                special_token = getattr(tokenizer, token)
                special_tokens_map[special_token] = f"{special_token}a"

        # Train new tokenizer
        new_tokenizer = tokenizer.train_new_from_iterator(
            SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
        )

        # Check the changes
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
            if getattr(tokenizer, f"_{token}") is None:
                continue
            special_token = getattr(tokenizer, token)
            if special_token in special_tokens_map:
                new_special_token = getattr(new_tokenizer, token)
                self.assertEqual(special_tokens_map[special_token], new_special_token)

                new_id = new_tokenizer.get_vocab()[new_special_token]
                self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)

        # Check if the AddedToken / string format has been kept
        for special_token in tokenizer.all_special_tokens_extended:
            if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
                # The special token must appear identically in the list of the new tokenizer.
                self.assertTrue(
                    special_token in new_tokenizer.all_special_tokens_extended,
                    f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
                )
            elif isinstance(special_token, AddedToken):
                # The special token must appear in the list of the new tokenizer as an object of type AddedToken with
                # the same parameters as the old AddedToken except the content that the user has requested to change.
                special_token_str = special_token.content
                new_special_token_str = special_tokens_map[special_token_str]

                find = False
                for candidate in new_tokenizer.all_special_tokens_extended:
                    if (
                        isinstance(candidate, AddedToken)
                        and candidate.content == new_special_token_str
                        and candidate.lstrip == special_token.lstrip
                        and candidate.rstrip == special_token.rstrip
                        and candidate.normalized == special_token.normalized
                        and candidate.single_word == special_token.single_word
                    ):
                        find = True
                        break
                self.assertTrue(
                    find,
                    (
                        f"'{new_special_token_str}' doesn't appear in the list "
                        f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
                        f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
                    ),
                )
            elif special_token not in special_tokens_map:
                # The special token must appear identically in the list of the new tokenizer.
                self.assertTrue(
                    special_token in new_tokenizer.all_special_tokens_extended,
                    f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
                )
            else:
                # The special token must appear in the list of the new tokenizer as an object of type string.
                self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended)

        # Test we can use the new tokenizer with something not seen during training
        inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
        self.assertEqual(len(inputs["input_ids"]), 2)
        decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
        expected_result = "This is the first sentence"

        # OpenAIGPT always lower-cases and has no do_lower_case arg.
        if new_tokenizer.init_kwargs.get("do_lower_case", False) or tokenizer.__class__.__name__.startswith(
            "OpenAIGPT"
        ):
            expected_result = expected_result.lower()
        self.assertEqual(expected_result, decoded_input)
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
...