Unverified Commit 2da88537 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

🚨🚨 🚨🚨 [`Tokenizer`] attemp to fix add_token issues🚨🚨 🚨🚨 (#23909)



* fix test for bart. Order is correct now let's skip BPEs

* ouf

* styling

* fix bert....

* slow refactoring

* current updates

* massive refactoring

* update

* NICE!

* update to see where I am at

* updates

* update

* update

* revert

* updates

* updates

* start supporting legacy_save

* styling

* big update

* revert some changes

* nits

* nniiiiiice

* small fixes

* kinda fix t5 with new behaviour

* major update

* fixup

* fix copies

* today's updates

* fix byt5

* upfate

* update

* update

* updates

* update vocab size test

* Barthez does not use not need the fairseq offset ids

* super calll must be after

* calll super

* move all super init

* move other super init

* fixup

* nits

* more fixes

* nits

* more fixes

* nits

* more fix

* remove useless files

* ouch all of them are affected

* and more!

* small imporvements

* no more sanitize token

* more changes around unique no split tokens

* partially fix more things

* keep legacy save but add warning

* so... more fixes

* updates

* guess deberta tokenizer could be nuked

* fixup

* fixup did some bad things

* nuke it if it breaks

* remove prints and pretrain fast from slow with new format.

* fixups

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fiou

* nit

* by default specials should not be normalized?

* update

* remove brakpoint

* updates

* a lot of updates

* fixup

* fixes revert some changes to match fast

* small nits

* that makes it cleaner

* fix camembert accordingly

* update

* some lest breaking changes

* update

* fixup

* fix byt5 and whisper mostly

* some more fixes, canine's byte vocab

* fix gpt2

* fix most of the perceiver tests (4 left)

* fix layout lmv3

* fixup

* fix copies for gpt2 style

* make sure to only warn once

* fix perciever and gpt2 tests

* some more backward compatibility: also read special tokens map because some ppl use it........////.....

* fixup

* add else when reading

* nits

* fresh updates

* fix copies

* will this make everything faster?

* fixes

* more fixes

* update

* more fixes

* fixup

* is the source of truth right?

* sorry camembert for the troubles

* current updates

* fixup

* update led

* update

* fix regression

* fix single word

* more model specific fixes

* fix t5 tests

* fixup

* more comments

* update

* fix nllb

* rstrip removed

* small fixes

* better handle additional_special_tokens and vocab sizes

* fixing

* styling

* fix 4 / 21

* fixup

* fix nlbb's tests

* some fixes

* fix t5

* fixes

* style

* fix canine tests

* damn this is nice

* nits

* m2m100 nit

* fixups

* fixes!

* fixup

* stash

* fix merge

* revert bad change

* fixup

* correct order for code Llama

* fix speecht5 post merge

* styling

* revert source of 11 fails

* small nits

* all changes in one go

* fnet hack

* fix 2 more tests

* update based on main branch of tokenizers

* fixup

* fix VITS issues

* more fixes

* fix mgp test

* fix camembert issues

* oups camembert still has 2 failing tests

* mluke fixes

* decode fixes

* small nits

* nits

* fix llama and vits

* fix camembert

* smal nits

* more fixes when initialising a fast from a slow and etc

* fix one of the last test

* fix CPM tokenizer test

* fixups

* fix pop2piano

* fixup

* ️ Change tokenizers required version ️

* ️ Change tokenizers required version ️

* "tokenizers>=0.14,<0.15", don't forget smaller than

* fix musicgen tests and pretraiendtokenizerfast

* fix owlvit and all

* update t5

* fix 800 red

* fix tests

* fix the fix of the fix of t5

* styling

* documentation nits

* cache _added_tokens_encoder

* fixups

* Nit

* fix red tests

* one last nit!

* make eveything a lot simpler

* Now it's over 😉



* few small nits

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* updates that work for now

* tests that should no be skipped / changed and fixed next

* fixup

* i am ashamed

* pushe the fix

* update

* fixups

* nits

* fix added_tokens_encoder

* fix canine test

* fix pegasus vocab

* fix transfoXL

* fixup

* whisper needs to be fixed for train new

* pegasus nits

* more pegasus fixes

* minor update

* better error message in failed test

* fix whisper failing test

* fix whisper failing test

* fix pegasus

* fixup

* fix **** pegasus

* reset things

* remove another file

* attempts to fix the strange custome encoder and offset

* nits here and there

* update

* fixup

* nit

* fix the whisper test

* nits nits

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* updates based on review

* some small update to potentially remove

* nits

* import rlu cache

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>

* move warning to `from_pretrained`

* update tests results now that the special tokens are always added

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>
parent 835b0a05
......@@ -354,21 +354,6 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
strip_accents: Optional[bool] = None,
**kwargs,
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
x_sep_token=x_sep_token,
pad_token=pad_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
self.unique_no_split_tokens.append(x_sep_token)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
......@@ -384,7 +369,21 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
x_sep_token=x_sep_token,
pad_token=pad_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def vocab_size(self):
......
......@@ -157,20 +157,6 @@ class RealmTokenizer(PreTrainedTokenizer):
strip_accents=None,
**kwargs,
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
......@@ -186,7 +172,20 @@ class RealmTokenizer(PreTrainedTokenizer):
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def do_lower_case(self):
......
......@@ -106,6 +106,10 @@ class ReformerTokenizer(PreTrainedTokenizer):
) -> None:
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
......@@ -114,10 +118,6 @@ class ReformerTokenizer(PreTrainedTokenizer):
**kwargs,
)
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
@property
def vocab_size(self):
return self.sp_model.get_piece_size()
......
......@@ -111,6 +111,13 @@ class RemBertTokenizer(PreTrainedTokenizer):
mask_token="[MASK]",
**kwargs,
):
self.do_lower_case = do_lower_case
self.remove_space = remove_space
self.keep_accents = keep_accents
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
super().__init__(
do_lower_case=do_lower_case,
remove_space=remove_space,
......@@ -125,14 +132,6 @@ class RemBertTokenizer(PreTrainedTokenizer):
**kwargs,
)
self.do_lower_case = do_lower_case
self.remove_space = remove_space
self.keep_accents = keep_accents
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
@property
def vocab_size(self):
return len(self.sp_model)
......
......@@ -203,28 +203,21 @@ class RobertaTokenizer(PreTrainedTokenizer):
**kwargs,
):
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
super().__init__(
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
add_prefix_space=add_prefix_space,
**kwargs,
mask_token = (
AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
if isinstance(mask_token, str)
else mask_token
)
# these special tokens are not part of the vocab.json, let's add them in the correct order
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
......@@ -241,12 +234,27 @@ class RobertaTokenizer(PreTrainedTokenizer):
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
super().__init__(
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
@property
def vocab_size(self):
return len(self.encoder)
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
vocab = dict(self.encoder).copy()
vocab.update(self.added_tokens_encoder)
return vocab
def bpe(self, token):
if token in self.cache:
......
......@@ -177,6 +177,11 @@ class RobertaTokenizerFast(PreTrainedTokenizerFast):
trim_offsets=True,
**kwargs,
):
mask_token = (
AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
if isinstance(mask_token, str)
else mask_token
)
super().__init__(
vocab_file,
merges_file,
......
......@@ -156,20 +156,6 @@ class RoCBertTokenizer(PreTrainedTokenizer):
strip_accents=None,
**kwargs,
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
for cur_file in [vocab_file, word_shape_file, word_pronunciation_file]:
if cur_file is None or not os.path.isfile(cur_file):
raise ValueError(
......@@ -195,7 +181,20 @@ class RoCBertTokenizer(PreTrainedTokenizer):
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = RoCBertWordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.wordpiece_tokenizer = RoCBertWordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def do_lower_case(self):
......
......@@ -378,20 +378,6 @@ class RoFormerTokenizer(PreTrainedTokenizer):
strip_accents=None,
**kwargs,
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
......@@ -407,7 +393,7 @@ class RoFormerTokenizer(PreTrainedTokenizer):
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
try:
import rjieba
except ImportError:
......@@ -417,6 +403,20 @@ class RoFormerTokenizer(PreTrainedTokenizer):
)
self.jieba = rjieba
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case
......
......@@ -122,23 +122,12 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
do_lower_case=False,
tgt_lang=None,
lang_codes=None,
additional_special_tokens=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
do_upper_case=do_upper_case,
do_lower_case=do_lower_case,
tgt_lang=tgt_lang,
lang_codes=lang_codes,
sp_model_kwargs=self.sp_model_kwargs,
**kwargs,
)
self.do_upper_case = do_upper_case
self.do_lower_case = do_lower_case
......@@ -152,18 +141,39 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
self.langs = LANGUAGES[lang_codes]
self.lang_tokens = [f"<lang:{lang}>" for lang in self.langs]
self.lang_code_to_id = {lang: self.sp_model.PieceToId(f"<lang:{lang}>") for lang in self.langs}
self._additional_special_tokens = self.lang_tokens
if additional_special_tokens is not None:
additional_special_tokens = self.lang_tokens + additional_special_tokens
else:
additional_special_tokens = self.lang_tokens
self._tgt_lang = tgt_lang if tgt_lang is not None else self.langs[0]
self.set_tgt_lang_special_tokens(self._tgt_lang)
else:
self.lang_code_to_id = {}
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
do_upper_case=do_upper_case,
do_lower_case=do_lower_case,
tgt_lang=tgt_lang,
lang_codes=lang_codes,
sp_model_kwargs=self.sp_model_kwargs,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self.encoder)
def get_vocab(self) -> Dict:
vocab = self.encoder.copy()
vocab.update(self.added_tokens_encoder)
return vocab
@property
def tgt_lang(self) -> str:
return self._tgt_lang
......@@ -241,11 +251,6 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def get_vocab(self) -> Dict:
vocab = self.encoder.copy()
vocab.update(self.added_tokens_encoder)
return vocab
def __getstate__(self) -> Dict:
state = self.__dict__.copy()
state["sp_model"] = None
......
......@@ -110,15 +110,6 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
merges_file=None,
**kwargs,
):
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
**kwargs,
)
self.do_lower_case = do_lower_case
with open(vocab_file, encoding="utf-8") as vocab_handle:
......@@ -137,6 +128,14 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
merges = [tuple(merge.split()[:2]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
**kwargs,
)
@property
def vocab_size(self) -> int:
......
......@@ -105,6 +105,12 @@ class SpeechT5Tokenizer(PreTrainedTokenizer):
**kwargs,
) -> None:
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.vocab_file = vocab_file
self.normalize = normalize
self._normalizer = None
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
super().__init__(
bos_token=bos_token,
......@@ -116,13 +122,6 @@ class SpeechT5Tokenizer(PreTrainedTokenizer):
**kwargs,
)
self.vocab_file = vocab_file
self.normalize = normalize
self._normalizer = None
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
normalize = kwargs.pop("normalize", self.normalize)
if is_split_into_words:
......
......@@ -137,20 +137,6 @@ class SplinterTokenizer(PreTrainedTokenizer):
strip_accents=None,
**kwargs,
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
......@@ -166,8 +152,21 @@ class SplinterTokenizer(PreTrainedTokenizer):
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
self.question_token = question_token
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def question_token_id(self):
......
......@@ -138,20 +138,6 @@ class SqueezeBertTokenizer(PreTrainedTokenizer):
strip_accents=None,
**kwargs,
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
......@@ -167,7 +153,22 @@ class SqueezeBertTokenizer(PreTrainedTokenizer):
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def do_lower_case(self):
......
......@@ -25,6 +25,7 @@ import sentencepiece as spm
from ...convert_slow_tokenizer import import_protobuf
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import AddedToken
if TYPE_CHECKING:
......@@ -152,18 +153,37 @@ class T5Tokenizer(PreTrainedTokenizer):
legacy=None,
**kwargs,
) -> None:
# Add extra_ids to the special token list
if extra_ids > 0 and additional_special_tokens is None:
additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
elif extra_ids > 0 and additional_special_tokens is not None:
# Check that we have the right number of extra_id special tokens
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
if extra_tokens != extra_ids:
pad_token = AddedToken(pad_token, rstrip=True, lstrip=True)
unk_token = AddedToken(unk_token, rstrip=True, lstrip=True)
eos_token = AddedToken(eos_token, rstrip=True, lstrip=True)
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.vocab_file = vocab_file
self._extra_ids = extra_ids
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
if additional_special_tokens is not None:
extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
if extra_ids > 0 and extra_ids != len(extra_tokens):
raise ValueError(
f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
" provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
" tokens"
)
else:
extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
additional_special_tokens = extra_tokens
# for legacy purpose, we keep this. Will be removed and tests updated. (when `added_tokens_decoder` is not passed as kwargs)
self._added_tokens_decoder = {}
for i in range(len(extra_tokens)):
self._added_tokens_decoder[len(self.sp_model) - 1 + extra_ids - i] = AddedToken(
f"<extra_id_{i}>", single_word=True, lstrip=True, rstrip=True, special=True
)
if legacy is None:
logger.warning_once(
f"You are using the default legacy behaviour of the {self.__class__}. If you see this, DO NOT PANIC! This is"
......@@ -175,7 +195,9 @@ class T5Tokenizer(PreTrainedTokenizer):
legacy = True
self.legacy = legacy
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.sp_model = self.get_spm_processor()
self.vocab_file = vocab_file
self._extra_ids = extra_ids
super().__init__(
eos_token=eos_token,
......@@ -188,11 +210,6 @@ class T5Tokenizer(PreTrainedTokenizer):
**kwargs,
)
self.vocab_file = vocab_file
self._extra_ids = extra_ids
self.sp_model = self.get_spm_processor()
def get_spm_processor(self):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
if self.legacy: # no dependency on protobuf
......@@ -234,7 +251,7 @@ class T5Tokenizer(PreTrainedTokenizer):
@property
def vocab_size(self):
return self.sp_model.get_piece_size() + self._extra_ids
return self.sp_model.get_piece_size()
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
......@@ -275,7 +292,7 @@ class T5Tokenizer(PreTrainedTokenizer):
)
def get_sentinel_token_ids(self):
return [self._convert_token_to_id(token) for token in self.get_sentinel_tokens()]
return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
......@@ -391,18 +408,11 @@ class T5Tokenizer(PreTrainedTokenizer):
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token.startswith("<extra_id_"):
match = re.match(r"<extra_id_(\d+)>", token)
num = int(match.group(1))
return self.vocab_size - num - 1
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index)
else:
token = f"<extra_id_{self.vocab_size - 1 - index}>"
return token
def convert_tokens_to_string(self, tokens):
......
......@@ -31,6 +31,7 @@ import numpy as np
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...tokenization_utils_base import (
ENCODE_KWARGS_DOCSTRING,
VERY_LARGE_INTEGER,
BatchEncoding,
EncodedInput,
PreTokenizedInput,
......@@ -351,6 +352,44 @@ class TapasTokenizer(PreTrainedTokenizer):
else:
additional_special_tokens = [empty_token]
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
# Additional properties
self.cell_trim_length = cell_trim_length
self.max_column_id = (
max_column_id
if max_column_id is not None
else model_max_length
if model_max_length is not None
else VERY_LARGE_INTEGER
)
self.max_row_id = (
max_row_id
if max_row_id is not None
else model_max_length
if model_max_length is not None
else VERY_LARGE_INTEGER
)
self.strip_column_names = strip_column_names
self.update_answer_coordinates = update_answer_coordinates
self.min_question_length = min_question_length
self.max_question_length = max_question_length
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
......@@ -375,32 +414,6 @@ class TapasTokenizer(PreTrainedTokenizer):
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
# Additional properties
self.cell_trim_length = cell_trim_length
self.max_column_id = max_column_id if max_column_id is not None else self.model_max_length
self.max_row_id = max_row_id if max_row_id is not None else self.model_max_length
self.strip_column_names = strip_column_names
self.update_answer_coordinates = update_answer_coordinates
self.min_question_length = min_question_length
self.max_question_length = max_question_length
@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case
......
......@@ -181,25 +181,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
language="en",
**kwargs,
):
super().__init__(
special=special,
min_freq=min_freq,
max_size=max_size,
lower_case=lower_case,
delimiter=delimiter,
vocab_file=vocab_file,
pretrained_vocab_file=pretrained_vocab_file,
never_split=never_split,
unk_token=unk_token,
eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
language=language,
**kwargs,
)
requires_backends(self, "sacremoses")
if never_split is None:
never_split = self.all_special_tokens
if special is None:
special = []
self.counter = Counter()
......@@ -209,7 +191,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.lower_case = lower_case
self.delimiter = delimiter
self.vocab_file = vocab_file
self.never_split = never_split
self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
self.punction_without_space_before_pattern = re.compile(rf"[^\s][{self.punctuation_symbols}]")
self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
......@@ -217,7 +198,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.moses_punct_normalizer = sm.MosesPunctNormalizer(language)
self.moses_tokenizer = sm.MosesTokenizer(language)
self.moses_detokenizer = sm.MosesDetokenizer(language)
self.idx2sym = []
self.sym2idx = OrderedDict()
# This try... catch... is not beautiful but honestly this tokenizer was not made to be used
# in a library like ours, at all.
try:
......@@ -241,7 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if vocab_dict is not None:
for key, value in vocab_dict.items():
if key not in self.__dict__:
if key not in self.__dict__ or key == "sym2idx":
self.__dict__[key] = value
elif vocab_file is not None:
self.build_vocab()
......@@ -256,6 +238,27 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if vocab_file is not None:
self.build_vocab()
super().__init__(
special=special,
min_freq=min_freq,
max_size=max_size,
lower_case=lower_case,
delimiter=delimiter,
vocab_file=vocab_file,
pretrained_vocab_file=pretrained_vocab_file,
never_split=never_split,
unk_token=unk_token,
eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
language=language,
**kwargs,
)
# these are not required to initialize the parent class as only used when tokenizing.
if never_split is None:
never_split = self.all_special_tokens
self.never_split = never_split
@property
def do_lower_case(self):
return self.lower_case
......@@ -305,7 +308,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
elif "<unk>" in self.sym2idx:
self.unk_idx = self.sym2idx["<unk>"]
else:
raise ValueError("No <unknown> token in vocabulary")
raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement.")
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if os.path.isdir(save_directory):
......@@ -323,7 +326,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if self.vocab_file:
logger.info(f"building vocab from {self.vocab_file}")
self._build_from_file(self.vocab_file)
logger.info(f"final vocab size {len(self)}")
logger.info(f"Final vocab size {len(self.sym2idx)}")
else:
logger.info(f"building vocab with min_freq={self.min_freq}, max_size={self.max_size}")
self.idx2sym = []
......@@ -337,7 +340,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
break
self.add_symbol(sym)
logger.info(f"final vocab size {len(self)} from {len(self.counter)} unique tokens")
logger.info(f"Final vocab size {len(self.sym2idx)} from {len(self.counter)} unique tokens")
@torch_only_method
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
......@@ -406,9 +409,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.sym2idx[current_sym] = idx
# Delete token from added_tokens
old_index = self.added_tokens_encoder[token]
del self.added_tokens_decoder[old_index]
del self.added_tokens_encoder[token]
old_index = self._added_tokens_encoder.pop(token)
self._added_tokens_decoder.pop(old_index)
def moses_punct_norm(self, text):
return self.moses_punct_normalizer.normalize(text)
......@@ -463,7 +465,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
elif "<UNK>" in self.sym2idx:
return self.sym2idx["<UNK>"]
else:
raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")
raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement.")
def convert_tokens_to_string(self, tokens):
"""
......@@ -482,7 +484,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
return len(self.idx2sym)
def get_vocab(self):
return dict(self.sym2idx, **self.added_tokens_encoder)
vocab = self.sym2idx.copy()
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, line, add_eos=False, add_double_eos=False):
line = line.strip()
......
......@@ -93,17 +93,6 @@ class VitsTokenizer(PreTrainedTokenizer):
is_uroman=False,
**kwargs,
) -> None:
super().__init__(
pad_token=pad_token,
unk_token=unk_token,
language=language,
add_blank=add_blank,
normalize=normalize,
phonemize=phonemize,
is_uroman=is_uroman,
**kwargs,
)
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
......@@ -115,12 +104,24 @@ class VitsTokenizer(PreTrainedTokenizer):
self.is_uroman = is_uroman
super().__init__(
pad_token=pad_token,
unk_token=unk_token,
language=language,
add_blank=add_blank,
normalize=normalize,
phonemize=phonemize,
is_uroman=is_uroman,
**kwargs,
)
@property
def vocab_size(self):
return len(self.encoder)
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def normalize_text(self, input_string):
......
......@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...utils import (
ModelOutput,
......@@ -174,18 +174,6 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
target_lang=None,
**kwargs,
):
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
word_delimiter_token=word_delimiter_token,
replace_word_delimiter_char=replace_word_delimiter_char,
target_lang=target_lang,
**kwargs,
)
self._word_delimiter_token = word_delimiter_token
self.do_lower_case = do_lower_case
......@@ -204,13 +192,28 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
self.decoder = {v: k for k, v in self.encoder.items()}
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
word_delimiter_token=word_delimiter_token,
replace_word_delimiter_char=replace_word_delimiter_char,
target_lang=target_lang,
**kwargs,
)
# make sure that tokens made of several
# characters are not split at tokenization
# TODO @ArthurZ add them or just update the trie?
unique_no_split_tokens = []
for token in self.encoder.keys():
if len(token) > 1:
self.unique_no_split_tokens.append(token)
unique_no_split_tokens.append(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
self._create_trie(self.unique_no_split_tokens)
self.add_tokens(unique_no_split_tokens)
def set_target_lang(self, target_lang: str):
"""
......@@ -266,7 +269,20 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
return len(self.decoder)
def get_vocab(self) -> Dict:
return dict(self.vocab, **self.added_tokens_encoder)
vocab = dict(self.encoder)
vocab.update(self.added_tokens_encoder)
return vocab
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
# Overwritten to never strip!
to_add = []
for token in new_tokens:
if isinstance(token, str):
to_add.append(AddedToken(token, rstrip=False, lstrip=False, normalize=False))
else:
to_add.append(token)
return super()._add_tokens(to_add, special_tokens)
def _tokenize(self, text, **kwargs):
"""
......@@ -645,64 +661,6 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
return (vocab_file,)
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
it with indices starting from length of the current vocabulary.
Args:
new_tokens (`List[str]`or `List[tokenizers.AddedToken]`):
Token(s) to add in vocabulary. A token is only added if it's not already in the vocabulary (tested by
checking if the tokenizer assign the index of the `unk_token` to them).
special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the tokens should be added as special tokens.
Returns:
`int`: The number of tokens actually added to the vocabulary.
Example:
```python
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
print("We have added", num_added_toks, "tokens")
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
```"""
new_tokens = [str(tok) for tok in new_tokens]
tokens_to_add = []
for token in new_tokens:
assert isinstance(token, str)
if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
token = token.lower()
if (
token != self.unk_token
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in tokens_to_add
):
tokens_to_add.append(token)
if self.verbose:
logger.info(f"Adding {token} to the vocabulary")
added_tok_encoder = {tok: len(self) + i for i, tok in enumerate(tokens_to_add)}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
# Make sure we don't split on any special tokens (even they were already in the vocab before)
for token in tokens_to_add:
if len(token) > 1:
self._additional_special_tokens.append(AddedToken(token))
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)
self._create_trie(self.unique_no_split_tokens)
return len(tokens_to_add)
class Wav2Vec2Tokenizer(PreTrainedTokenizer):
"""
......@@ -777,18 +735,6 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
return_attention_mask=False,
**kwargs,
):
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
do_normalize=do_normalize,
return_attention_mask=return_attention_mask,
word_delimiter_token=word_delimiter_token,
**kwargs,
)
warnings.warn(
"The class `Wav2Vec2Tokenizer` is deprecated and will be removed in version 5 of Transformers. Please use"
" `Wav2Vec2Processor` or `Wav2Vec2CTCTokenizer` instead.",
......@@ -806,6 +752,18 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
self.decoder = {v: k for k, v in self.encoder.items()}
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
do_normalize=do_normalize,
return_attention_mask=return_attention_mask,
word_delimiter_token=word_delimiter_token,
**kwargs,
)
@property
def word_delimiter_token(self) -> str:
"""
......
......@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import AddedToken
from ...utils import (
ModelOutput,
......@@ -143,19 +143,6 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
phonemizer_backend="espeak",
**kwargs,
):
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
word_delimiter_token=word_delimiter_token,
phone_delimiter_token=phone_delimiter_token,
do_phonemize=do_phonemize,
phonemizer_lang=phonemizer_lang,
phonemizer_backend=phonemizer_backend,
**kwargs,
)
self._word_delimiter_token = word_delimiter_token
self._phone_delimiter_token = phone_delimiter_token
self.do_phonemize = do_phonemize
......@@ -168,13 +155,38 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
word_delimiter_token=word_delimiter_token,
phone_delimiter_token=phone_delimiter_token,
do_phonemize=do_phonemize,
phonemizer_lang=phonemizer_lang,
phonemizer_backend=phonemizer_backend,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self.decoder)
def get_vocab(self) -> Dict:
return dict(self.encoder, **self.added_tokens_encoder)
vocab = dict(self.encoder)
vocab.update(self.added_tokens_encoder)
return vocab
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
# Overwritten to never strip!
to_add = []
for token in new_tokens:
if isinstance(token, str):
to_add.append(AddedToken(token, rstrip=False, lstrip=False, normalize=True))
else:
to_add.append(token)
return super()._add_tokens(to_add, special_tokens)
def init_backend(self, phonemizer_lang: str):
"""
......@@ -576,61 +588,3 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
it with indices starting from length of the current vocabulary.
Args:
new_tokens (`List[str]`or `List[tokenizers.AddedToken]`):
Token(s) to add in vocabulary. A token is only added if it's not already in the vocabulary (tested by
checking if the tokenizer assign the index of the `unk_token` to them).
special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the tokens should be added as special tokens.
Returns:
`int`: The number of tokens actually added to the vocabulary.
Examples:
```python
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
model = Wav2Vec2PhonemeForCTC.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
print("We have added", num_added_toks, "tokens")
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
```"""
new_tokens = [str(tok) for tok in new_tokens]
tokens_to_add = []
for token in new_tokens:
if not isinstance(token, str):
raise ValueError(f"Token {token} has to be of type string, but is of type {type(token)}.")
assert isinstance(token, str)
if (
token != self.unk_token
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in tokens_to_add
):
tokens_to_add.append(token)
if self.verbose:
logger.info(f"Adding {token} to the vocabulary")
added_tok_encoder = {tok: len(self) + i for i, tok in enumerate(tokens_to_add)}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
# Make sure we don't split on any special tokens (even they were already in the vocab before)
for token in tokens_to_add:
if len(token) > 1:
self._additional_special_tokens.append(AddedToken(token))
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)
self._create_trie(self.unique_no_split_tokens)
return len(tokens_to_add)
......@@ -272,18 +272,25 @@ class WhisperTokenizer(PreTrainedTokenizer):
predict_timestamps=False,
**kwargs,
):
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
super().__init__(
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
**kwargs,
bos_token = (
AddedToken(bos_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(bos_token, str)
else bos_token
)
eos_token = (
AddedToken(eos_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(eos_token, str)
else eos_token
)
unk_token = (
AddedToken(unk_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(unk_token, str)
else unk_token
)
pad_token = (
AddedToken(pad_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(pad_token, str)
else pad_token
)
with open(vocab_file, encoding="utf-8") as vocab_handle:
......@@ -309,18 +316,28 @@ class WhisperTokenizer(PreTrainedTokenizer):
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.language = language
super().__init__(
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
self.task = task
self.predict_timestamps = predict_timestamps
@property
def vocab_size(self) -> int:
return len(self.encoder)
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
@property
def vocab_size(self) -> int:
return len(self.encoder)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe with GPT2 -> Whisper
def bpe(self, token):
if token in self.cache:
......@@ -390,11 +407,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
@property
def prefix_tokens(self) -> List[int]:
all_special_ids = self.all_special_ids
bos_token_id = all_special_ids[-106]
translate_token_id = all_special_ids[-6]
transcribe_token_id = all_special_ids[-5]
notimestamps_token_id = all_special_ids[-1]
bos_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
translate_token_id = self.convert_tokens_to_ids("<|translate|>")
transcribe_token_id = self.convert_tokens_to_ids("<|transcribe|>")
notimestamps_token_id = self.convert_tokens_to_ids("<|notimestamps|>")
langs = tuple(LANGUAGES.keys())
if self.language is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment