Unverified commit ade7371a, authored by SaulLu, committed by GitHub

improve saving strategy of sentencepiece tokenizer (#15328)



* add new test

* add a feature to save the sentencepiece tokenizer model when the init file was deleted

* update marian

* update m2m_100

* fix marian

* update speech to text

* override test for layoutxlm

* fix saving bartpho

* remove hardcoded values in bartpho

* special token string version

* finish bartpho

* override layoutxlm test

* add mbart

* move special tokens list

* format

* Revert "format"

This reverts commit 37a40df37903a932c2f951cbd33acb684246bae7.

* simplify list of string of special tokens

* Re-write `self.fairseq_tokens_to_ids` initialization logic with special tokens

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
parent 196cce6e
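
The change applied to every SentencePiece-backed tokenizer below follows one pattern: `save_vocabulary` copies the original `.model`/`.spm` file when it still exists on disk, and otherwise falls back to dumping the model proto that the in-memory `SentencePieceProcessor` already holds. A minimal standalone sketch of that pattern (the helper name and paths are illustrative, not part of the diff):

import os
from shutil import copyfile

import sentencepiece as spm


def save_spm_vocab(sp_model: spm.SentencePieceProcessor, original_path: str, out_path: str) -> str:
    # Normal case: the file the tokenizer was built from still exists, copy it verbatim.
    if os.path.abspath(original_path) != os.path.abspath(out_path) and os.path.isfile(original_path):
        copyfile(original_path, out_path)
    # Fallback (the new behaviour): the original file was deleted, so write out the
    # serialized proto kept in memory instead of failing on the copy.
    elif not os.path.isfile(original_path):
        with open(out_path, "wb") as fi:
            fi.write(sp_model.serialized_model_proto())
    return out_path

Reloading the written file (for example with spm.SentencePieceProcessor(model_file=out_path)) should give the same vocabulary as the original model, which is what the new common test at the bottom of the diff verifies with a save/delete/save round trip.
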
@@ -343,7 +343,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -157,12 +157,20 @@ class BartphoTokenizer(PreTrainedTokenizer):
         self.sp_model.Load(str(vocab_file))

         # Load the reduced vocab
-        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # Keep order of special tokens for backward compatibility
+        self.fairseq_tokens_to_ids = {}
+        cnt = 0
+        for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
+            if str(token) not in self.fairseq_tokens_to_ids:
+                self.fairseq_tokens_to_ids[str(token)] = cnt
+                cnt += 1
         with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
             for line in f.readlines():
                 token = line.strip().split()[0]
                 self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
-        self.fairseq_tokens_to_ids["<mask>"] = len(self.fairseq_tokens_to_ids)
+        if str(mask_token) not in self.fairseq_tokens_to_ids:
+            self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)

         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
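
For BartPho, the hunk above replaces the hardcoded `{"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}` mapping with a loop over the configured special tokens, deduplicating by string so the historical ids are preserved even though several attributes resolve to the same string. A hedged sketch with BartPho's default token strings substituted in (the concrete strings are an assumption taken from the tokenizer's defaults, not from this diff):

# Stand-ins for str(bos_token), str(pad_token), str(eos_token), str(unk_token), str(sep_token), str(cls_token);
# with BartPho's defaults, sep_token == "</s>" and cls_token == "<s>", so they deduplicate away.
special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "</s>", "<s>"]

fairseq_tokens_to_ids = {}
cnt = 0
for token in special_tokens:
    if token not in fairseq_tokens_to_ids:
        fairseq_tokens_to_ids[token] = cnt
        cnt += 1

# Same mapping as the removed hardcoded dictionary.
assert fairseq_tokens_to_ids == {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
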
@@ -278,7 +286,7 @@ class BartphoTokenizer(PreTrainedTokenizer):
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         else:
-            return self.fairseq_tokens_to_ids["<unk>"]
+            return self.unk_token_id

     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
@@ -301,10 +309,21 @@ class BartphoTokenizer(PreTrainedTokenizer):
             (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

-        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(out_monolingual_vocab_file):
+        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
+            out_monolingual_vocab_file
+        ) and os.path.isfile(self.monolingual_vocab_file):
             copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
+        elif not os.path.isfile(self.monolingual_vocab_file):
+            with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
+                for token in self.fairseq_tokens_to_ids:
+                    if token not in self.all_special_tokens:
+                        fp.write(f"{str(token)} \n")

         return out_vocab_file, out_monolingual_vocab_file
@@ -160,7 +160,11 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -189,8 +189,12 @@ class BigBirdTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -288,7 +288,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -305,7 +305,11 @@ class FNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -331,8 +331,12 @@ class LayoutXLMTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tokenization classes for M2M100."""
 import json
+import os
 from contextlib import contextmanager
 from pathlib import Path
 from shutil import copyfile
@@ -312,8 +313,12 @@ class M2M100Tokenizer(PreTrainedTokenizer):
         save_json(self.encoder, vocab_save_path)

-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (str(vocab_save_path), str(spm_save_path))
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
+import os
 import re
 import warnings
 from contextlib import contextmanager
@@ -23,7 +23,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece

 from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {
     "source_spm": "source.spm",
@@ -277,21 +280,35 @@ class MarianTokenizer(PreTrainedTokenizer):
         return len(self.encoder)

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        save_dir = Path(save_directory)
-        assert save_dir.is_dir(), f"{save_directory} should be a directory"
-        save_json(
-            self.encoder,
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab"]),
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        saved_files = []
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
         )
-        for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
-            dest_path = save_dir / ((filename_prefix + "-" if filename_prefix else "") + Path(f).name)
-            if not dest_path.exists():
-                copyfile(f, save_dir / orig)

-        return tuple(
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + f) for f in self.vocab_files_names
-        )
+        save_json(self.encoder, out_vocab_file)
+        saved_files.append(out_vocab_file)
+
+        for spm_save_filename, spm_orig_path, spm_model in zip(
+            [VOCAB_FILES_NAMES["source_spm"], VOCAB_FILES_NAMES["target_spm"]],
+            self.spm_files,
+            [self.spm_source, self.spm_target],
+        ):
+            spm_save_path = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + spm_save_filename
+            )
+
+            if os.path.abspath(spm_orig_path) != os.path.abspath(spm_save_path) and os.path.isfile(spm_orig_path):
+                copyfile(spm_orig_path, spm_save_path)
+                saved_files.append(spm_save_path)
+            elif not os.path.isfile(spm_orig_path):
+                with open(spm_save_path, "wb") as fi:
+                    content_spiece_model = spm_model.serialized_model_proto()
+                    fi.write(content_spiece_model)
+                saved_files.append(spm_save_path)
+
+        return tuple(saved_files)

     def get_vocab(self) -> Dict:
         vocab = self.encoder.copy()
@@ -315,8 +315,12 @@ class MBartTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -245,8 +245,12 @@ class MBart50Tokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -285,7 +285,11 @@ class PegasusTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -167,7 +167,11 @@ class ReformerTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for Speech2Text."""
-
 import json
+import os
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -260,8 +260,12 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
         save_json(self.encoder, vocab_save_path)

-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (str(vocab_save_path), str(spm_save_path))
@@ -303,8 +303,11 @@ class T5Tokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
-            logger.info(f"Copy vocab file to {out_vocab_file}")
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -302,8 +302,12 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -310,7 +310,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -342,7 +342,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)
@@ -394,6 +394,33 @@ class TokenizerTesterMixin:
             self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
             self.check_subword_sampling(tokenizer_new)

+    def test_save_sentencepiece_tokenizer(self) -> None:
+        if not self.test_sentencepiece or not self.test_slow_tokenizer:
+            return
+        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
+        # build the tokenizer have been deleted in the meantime.
+        text = "This is text to test the tokenizer."
+
+        tokenizer_slow_1 = self.get_tokenizer()
+        encoding_tokenizer_slow_1 = tokenizer_slow_1(text)
+
+        tmpdirname_1 = tempfile.mkdtemp()
+        tmpdirname_2 = tempfile.mkdtemp()
+
+        tokenizer_slow_1.save_pretrained(tmpdirname_1)
+        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
+        encoding_tokenizer_slow_2 = tokenizer_slow_2(text)
+
+        shutil.rmtree(tmpdirname_1)
+        tokenizer_slow_2.save_pretrained(tmpdirname_2)
+
+        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
+        encoding_tokenizer_slow_3 = tokenizer_slow_3(text)
+        shutil.rmtree(tmpdirname_2)
+
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+
     def test_model_input_names_signature(self):
         accepted_model_main_input_names = [
             "input_ids",  # nlp models
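
Outside the test suite, the round trip that this new common test exercises looks roughly like the following sketch, using `t5-small` as an example checkpoint (the checkpoint choice and the direct encoding comparisons are illustrative, not part of the diff):

import shutil
import tempfile

from transformers import T5Tokenizer

text = "This is text to test the tokenizer."

tokenizer_1 = T5Tokenizer.from_pretrained("t5-small")
encoding_1 = tokenizer_1(text)

tmpdir_1, tmpdir_2 = tempfile.mkdtemp(), tempfile.mkdtemp()
tokenizer_1.save_pretrained(tmpdir_1)
tokenizer_2 = T5Tokenizer.from_pretrained(tmpdir_1)

# Delete the directory holding spiece.model; tokenizer_2 now only has the model in memory.
shutil.rmtree(tmpdir_1)

# Previously this second save could not copy the now-deleted spiece.model; with this PR the
# serialized proto held by the SentencePieceProcessor is written out instead.
tokenizer_2.save_pretrained(tmpdir_2)
tokenizer_3 = T5Tokenizer.from_pretrained(tmpdir_2)

assert encoding_1 == tokenizer_2(text)
assert encoding_1 == tokenizer_3(text)
shutil.rmtree(tmpdir_2)
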
@@ -99,6 +99,44 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         output_text = "unwanted, running"
         return input_text, output_text

+    # override test in `test_tokenization_common.py` because of the required input format of the `__call__` method of
+    # this tokenizer
+    def test_save_sentencepiece_tokenizer(self) -> None:
+        if not self.test_sentencepiece or not self.test_slow_tokenizer:
+            return
+        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
+        # build the tokenizer have been deleted in the meantime.
+        words, boxes = self.get_words_and_boxes()
+
+        tokenizer_slow_1 = self.get_tokenizer()
+        encoding_tokenizer_slow_1 = tokenizer_slow_1(
+            words,
+            boxes=boxes,
+        )
+
+        tmpdirname_1 = tempfile.mkdtemp()
+        tmpdirname_2 = tempfile.mkdtemp()
+
+        tokenizer_slow_1.save_pretrained(tmpdirname_1)
+        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
+        encoding_tokenizer_slow_2 = tokenizer_slow_2(
+            words,
+            boxes=boxes,
+        )
+
+        shutil.rmtree(tmpdirname_1)
+        tokenizer_slow_2.save_pretrained(tmpdirname_2)
+
+        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
+        encoding_tokenizer_slow_3 = tokenizer_slow_3(
+            words,
+            boxes=boxes,
+        )
+        shutil.rmtree(tmpdirname_2)
+
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")