Unverified Commit 3b39b906 authored by Arthur, committed by GitHub

[`TokenizerFast`] `can_save_slow_tokenizer` as a property for when `vocab_file`'s folder was removed (#25626)

* pad token should be None by default

* fix tests

* nits

* check with isfile whether vocab_file exists

* add warning if sp model folder was deleted

* save the SPM model when the folder is missing, for slow tokenizers

* update `can_save_slow_tokenizer` to be a property

* first batch

* second batch

* missing one
parent 99fc3ac8
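
For context, a minimal standalone sketch of the behavioral difference between the old attribute and the new property (illustrative classes only, not the real tokenizer code):

import os

class OldStyle:
    """Illustrative only; not the actual transformers class."""

    def __init__(self, vocab_file=None):
        self.vocab_file = vocab_file
        # Frozen at construction time: stays True even if the file is deleted later.
        self.can_save_slow_tokenizer = False if not vocab_file else True


class NewStyle:
    """Illustrative only; not the actual transformers class."""

    def __init__(self, vocab_file=None):
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # Re-evaluated on every access, so a vocab file removed after loading
        # is detected at save time instead of failing mid-save.
        return os.path.isfile(self.vocab_file) if self.vocab_file else False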
@@ -165,7 +165,10 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
         self.remove_space = remove_space
         self.keep_accents = keep_accents
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -146,7 +146,10 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -398,8 +398,12 @@ class BertweetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
             copyfile(self.merges_file, out_merge_file)
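
The fallback branch above rebuilds the vocab file from the in-memory SentencePiece model. A short sketch of that mechanism in isolation, assuming the `sentencepiece` package and hypothetical file paths:

import sentencepiece as spm

# Load an existing SentencePiece model, then write its serialized proto back out.
# This mirrors what the new `elif not os.path.isfile(self.vocab_file)` branch does:
# even if the original vocab file was deleted, the in-memory sp_model still holds
# everything needed to regenerate it.
sp_model = spm.SentencePieceProcessor()
sp_model.Load("spiece.model")  # hypothetical input path
with open("rebuilt_spiece.model", "wb") as fi:  # hypothetical output path
    fi.write(sp_model.serialized_model_proto())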
@@ -150,7 +150,10 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -140,7 +140,10 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -141,7 +141,6 @@ class CpmTokenizerFast(PreTrainedTokenizerFast):
         self.remove_space = remove_space
         self.keep_accents = keep_accents
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
         try:
             import jieba
@@ -153,6 +152,10 @@ class CpmTokenizerFast(PreTrainedTokenizerFast):
         self.jieba = jieba
         self.translator = str.maketrans(" \n", "\u2582\u2583")
 
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
     # Copied from transformers.models.xlnet.tokenization_xlnet_fast.XLNetTokenizerFast.build_inputs_with_special_tokens
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -148,7 +148,10 @@ class DebertaV2TokenizerFast(PreTrainedTokenizerFast):
         self.do_lower_case = do_lower_case
         self.split_by_punct = split_by_punct
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         """
@@ -132,7 +132,10 @@ class FNetTokenizerFast(PreTrainedTokenizerFast):
         self.remove_space = remove_space
         self.keep_accents = keep_accents
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -257,7 +257,6 @@ class LayoutXLMTokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
         # additional properties
         self.cls_token_box = cls_token_box
@@ -266,6 +265,10 @@ class LayoutXLMTokenizerFast(PreTrainedTokenizerFast):
         self.pad_token_label = pad_token_label
         self.only_label_first_subword = only_label_first_subword
 
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
     @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
     def __call__(
         self,
@@ -128,7 +128,10 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
         self.update_post_processor()
         self.use_default_system_prompt = use_default_system_prompt
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def update_post_processor(self):
         """
@@ -129,7 +129,6 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
         _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
@@ -149,6 +148,10 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
         self.tgt_lang = tgt_lang
         self.set_src_lang_special_tokens(self._src_lang)
 
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
     @property
     def src_lang(self) -> str:
         return self._src_lang
@@ -147,7 +147,6 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
         self.lang_code_to_id = {
             lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
@@ -158,6 +157,10 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
         self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
         self.set_src_lang_special_tokens(self._src_lang)
 
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
     @property
     def src_lang(self) -> str:
         return self._src_lang
@@ -1494,8 +1494,12 @@ class MLukeTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         entity_vocab_file = os.path.join(
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"]
@@ -175,7 +175,6 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
         _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
@@ -195,6 +194,10 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
         self.tgt_lang = tgt_lang
         self.set_src_lang_special_tokens(self._src_lang)
 
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
     @property
     def src_lang(self) -> str:
         return self._src_lang
@@ -152,7 +152,10 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def _special_token_mask(self, seq):
         all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
@@ -324,8 +324,12 @@ class PhobertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
             copyfile(self.merges_file, out_merge_file)
@@ -110,7 +110,10 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not self.can_save_slow_tokenizer:
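
For reference, a hedged sketch of how such a guarded `save_vocabulary` typically looks inside a SentencePiece-backed fast tokenizer; the class name, error message, and the `spiece.model` filename here are illustrative, not copied from the Reformer code:

import os
from shutil import copyfile
from typing import Optional, Tuple

class SketchTokenizerFast:  # stands in for a PreTrainedTokenizerFast subclass
    def __init__(self, vocab_file=None):
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # With the property, this guard now also trips when the folder holding the
        # original SentencePiece file was deleted after the tokenizer was loaded.
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the "
                "vocabulary for a slow tokenizer."
            )
        if not os.path.isdir(save_directory):
            raise ValueError(f"Vocabulary path ({save_directory}) should be a directory")
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "spiece.model"
        )
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        return (out_vocab_file,)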
@@ -263,7 +263,11 @@ class RemBertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -139,7 +139,10 @@ class RemBertTokenizerFast(PreTrainedTokenizerFast):
         self.remove_space = remove_space
         self.keep_accents = keep_accents
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -142,9 +142,12 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
         )
 
         self.vocab_file = vocab_file
-        self.can_save_slow_tokenizer = False if not self.vocab_file else True
         self._extra_ids = extra_ids
 
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
     @staticmethod
     def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
         if pretrained_model_name_or_path in T5TokenizerFast.max_model_input_sizes: