Unverified Commit dc42e770 authored by Sylvain Gugger, committed by GitHub

Easily train a new fast tokenizer from a given one (#12361)



* [WIP] Easily train a new fast tokenizer from a given one

* Fix test

* Roll out to other tokenizers and add tests

* Fix bug with unk id and add emoji to test

* Really use something different in test

* Implement special tokens map

* Map special tokens in the Transformers tokenizers

* Fix test

* Make test more robust

* Fix test for BPE

* More robust map and test

Co-authored-by: SaulLu

* Test file

* Stronger tests
Co-authored-by: SaulLu <lucilesaul.com@gmail.com>

* Map unk token for Wordpiece and address review comment

* Fix lowercase test and address review comment

* Fix all tests

* Simplify test

* Fix tests for realsies

* Easily train a new fast tokenizer from a given one - tackle the special tokens format (str or AddedToken) (#12420)

* Propose change in tests regarding lower case

* add new test for special tokens types

* put back the test part about decoding

* add feature: the AddedToken is re-build with the different mapped content

* Address review comment: simplify AddedToken building
Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* Update src/transformers/tokenization_utils_fast.py
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: SaulLu <lucilesaul.com@gmail.com>
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
parent b440b8d1
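
A minimal usage sketch of the feature this commit introduces, for context before the diff. The method name train_new_from_iterator matches the helper added to src/transformers/tokenization_utils_fast.py in the merged version; the checkpoint name, vocab_size value, toy corpus, and output directory below are illustrative assumptions, not part of this diff.

# Sketch: retrain an existing fast tokenizer on a new corpus, keeping its
# normalizer, pre-tokenizer, post-processor and special tokens.
from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Any iterable of strings (or batches of strings) can serve as the training corpus.
training_corpus = [
    "This is a toy corpus used only for illustration.",
    "Replace it with an iterator over your own texts.",
]

new_tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, vocab_size=25000)
new_tokenizer.save_pretrained("my-new-tokenizer")  # hypothetical output directory
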
@@ -121,7 +121,7 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         do_lower_case=True,
         remove_space=True,

@@ -109,7 +109,7 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         bos_token="<s>",
         eos_token="</s>",

@@ -162,7 +162,7 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         do_lower_case=True,
         unk_token="[UNK]",

@@ -103,7 +103,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         unk_token="<unk>",
         bos_token="<s>",

@@ -63,8 +63,8 @@ class BlenderbotSmallTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
-        merges_file,
+        vocab_file=None,
+        merges_file=None,
         unk_token="<|endoftext|>",
         bos_token="<|endoftext|>",
         eos_token="<|endoftext|>",

@@ -105,7 +105,7 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         bos_token="<s>",
         eos_token="</s>",

@@ -105,8 +105,8 @@ class CLIPTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
-        merges_file,
+        vocab_file=None,
+        merges_file=None,
         tokenizer_file=None,
         unk_token="<|endoftext|>",
         bos_token="<|startoftext|>",

@@ -95,8 +95,8 @@ class DebertaTokenizerFast(GPT2TokenizerFast):
     def __init__(
         self,
-        vocab_file,
-        merges_file,
+        vocab_file=None,
+        merges_file=None,
         tokenizer_file=None,
         errors="replace",
         bos_token="[CLS]",

@@ -88,7 +88,7 @@ class FunnelTokenizerFast(BertTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         do_lower_case=True,
         unk_token="<unk>",

@@ -125,8 +125,8 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
-        merges_file,
+        vocab_file=None,
+        merges_file=None,
         tokenizer_file=None,
         unk_token="<|endoftext|>",
         bos_token="<|endoftext|>",

@@ -67,8 +67,8 @@ class HerbertTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
-        merges_file,
+        vocab_file=None,
+        merges_file=None,
         tokenizer_file=None,
         cls_token="<s>",
         unk_token="<unk>",

@@ -121,7 +121,10 @@ class MBartTokenizer(XLMRobertaTokenizer):
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
 
         if additional_special_tokens is not None:
-            self._additional_special_tokens.extend(additional_special_tokens)
+            # Only add those special tokens if they are not already there.
+            self._additional_special_tokens.extend(
+                [t for t in additional_special_tokens if t not in self._additional_special_tokens]
+            )
 
         self._src_lang = src_lang if src_lang is not None else "en_XX"
         self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]

@@ -110,7 +110,7 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         src_lang=None,
         tgt_lang=None,
         tokenizer_file=None,

@@ -113,10 +113,16 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     suffix_tokens: List[int] = []
 
     def __init__(
-        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        src_lang=None,
+        tgt_lang=None,
+        additional_special_tokens=None,
+        **kwargs
     ):
         super().__init__(
-            *args,
+            vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
             src_lang=src_lang,
             tgt_lang=tgt_lang,

@@ -127,7 +133,10 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
 
         if additional_special_tokens is not None:
-            _additional_special_tokens.extend(additional_special_tokens)
+            # Only add those special tokens if they are not already there.
+            _additional_special_tokens.extend(
+                [t for t in additional_special_tokens if t not in _additional_special_tokens]
+            )
 
         self.add_special_tokens({"additional_special_tokens": _additional_special_tokens})

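The two MBart hunks above guard against duplicate entries when user-supplied additional_special_tokens overlap with the built-in language codes. A hedged sketch of the intended behavior, assuming the standard facebook/mbart-large-cc25 checkpoint and a made-up extra token:

from transformers import MBartTokenizerFast

tok = MBartTokenizerFast.from_pretrained(
    "facebook/mbart-large-cc25",
    additional_special_tokens=["en_XX", "<new_domain_token>"],
)
# "en_XX" is already one of the FAIRSEQ language codes, so with the guard above
# it should appear only once in the resulting special tokens list.
print(tok.additional_special_tokens.count("en_XX"))  # expected: 1
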
@@ -106,7 +106,7 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         do_lower_case=True,
         bos_token="<s>",

@@ -64,7 +64,7 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
     model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = OpenAIGPTTokenizer
 
-    def __init__(self, vocab_file, merges_file, tokenizer_file=None, unk_token="<unk>", **kwargs):
+    def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="<unk>", **kwargs):
         super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs)
 
     @property

@@ -98,7 +98,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         pad_token="<pad>",
         eos_token="</s>",

@@ -87,7 +87,7 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
         self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         eos_token="</s>",
         unk_token="<unk>",

@@ -143,8 +143,8 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
     def __init__(
         self,
-        vocab_file,
-        merges_file,
+        vocab_file=None,
+        merges_file=None,
         tokenizer_file=None,
         errors="replace",
         bos_token="<s>",

@@ -73,7 +73,7 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast):
     def __init__(
        self,
-        vocab_file,
+        vocab_file=None,
         tokenizer_file=None,
         do_lower_case=True,
         unk_token="[UNK]",

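The repeated vocab_file=None / merges_file=None changes above make the slow-tokenizer vocabulary files optional, so a fast tokenizer can be constructed without them, for example from a serialized tokenizers JSON file alone or via the retraining path sketched earlier. A minimal sketch, assuming a local tokenizer.json whose path below is a placeholder:

from transformers import GPT2TokenizerFast

# With vocab_file and merges_file now defaulting to None, only the tokenizers
# JSON file is needed; "path/to/tokenizer.json" is hypothetical.
tokenizer = GPT2TokenizerFast(tokenizer_file="path/to/tokenizer.json")
print(tokenizer("hello world").input_ids)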