Unverified commit b9af152e, authored by Stas Bekman, committed by GitHub
Browse files

[tokenizer] sanitize saved config (#21483)

* [tokenizer] sanitize saved config

* rm config["name_or_path"] test
parent 67d07487
...@@ -2153,6 +2153,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -2153,6 +2153,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if self._auto_class is not None: if self._auto_class is not None:
custom_object_save(self, save_directory, config=tokenizer_config) custom_object_save(self, save_directory, config=tokenizer_config)
# remove private information
if "name_or_path" in tokenizer_config:
tokenizer_config.pop("name_or_path")
with open(tokenizer_config_file, "w", encoding="utf-8") as f: with open(tokenizer_config_file, "w", encoding="utf-8") as f:
out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n" out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str) f.write(out_str)
......
...@@ -230,8 +230,6 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -230,8 +230,6 @@ class AutoTokenizerTest(unittest.TestCase):
# Check the class of the tokenizer was properly saved (note that it always saves the slow class). # Check the class of the tokenizer was properly saved (note that it always saves the slow class).
self.assertEqual(config["tokenizer_class"], "BertTokenizer") self.assertEqual(config["tokenizer_class"], "BertTokenizer")
# Check other keys just to make sure the config was properly saved /reloaded.
self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)
def test_new_tokenizer_registration(self): def test_new_tokenizer_registration(self):
try: try:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment