Unverified Commit df6eee92 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Follow up for #31973 (#32025)



* fix

* [test_all] trigger full CI

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent de231889
...@@ -18,10 +18,10 @@ import os ...@@ -18,10 +18,10 @@ import os
import tempfile import tempfile
import unittest import unittest
import warnings import warnings
from pathlib import Path
from huggingface_hub import HfFolder, delete_repo from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import AutoConfig, GenerationConfig from transformers import AutoConfig, GenerationConfig
from transformers.generation import GenerationMode from transformers.generation import GenerationMode
...@@ -228,72 +228,88 @@ class ConfigPushToHubTester(unittest.TestCase): ...@@ -228,72 +228,88 @@ class ConfigPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-generation-config")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org")
except HTTPError:
pass
def test_push_to_hub(self):
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub("test-generation-config", token=self._token)
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try: try:
# Reset repo # Reset repo
delete_repo(token=self._token, repo_id="test-generation-config") delete_repo(repo_id=repo_id, token=token)
except: # noqa E722 except: # noqa E722
pass pass
# Push to hub via save_pretrained def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id="test-generation-config", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config") config = GenerationConfig(
for k, v in config.to_dict().items(): do_sample=True,
if k != "transformers_version": temperature=0.7,
self.assertEqual(v, getattr(new_config, k)) length_penalty=1.0,
)
config.push_to_hub(tmp_repo, token=self._token)
new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub("valid_org/test-generation-config-org", token=self._token)
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
) config = GenerationConfig(
do_sample=True,
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org") temperature=0.7,
for k, v in config.to_dict().items(): length_penalty=1.0,
if k != "transformers_version": )
self.assertEqual(v, getattr(new_config, k)) config.push_to_hub(tmp_repo, token=self._token)
new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
...@@ -20,10 +20,8 @@ import tempfile ...@@ -20,10 +20,8 @@ import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from uuid import uuid4
from huggingface_hub import HfFolder, Repository, create_repo, delete_repo from huggingface_hub import HfFolder, Repository, create_repo, delete_repo
from requests.exceptions import HTTPError
import transformers import transformers
from transformers import ( from transformers import (
...@@ -374,69 +372,73 @@ class ProcessorPushToHubTester(unittest.TestCase): ...@@ -374,69 +372,73 @@ class ProcessorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-processor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-processor-org")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="test-dynamic-processor") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
def test_push_to_hub(self): def test_push_to_hub_via_save_pretrained(self):
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained(os.path.join(tmp_dir, "test-processor"), push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-processor-{Path(tmp_dir).name}"
new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor") processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
for k, v in processor.feature_extractor.__dict__.items(): # Push to hub via save_pretrained
self.assertEqual(v, getattr(new_processor.feature_extractor, k)) processor.save_pretrained(tmp_repo, repo_id=tmp_repo, push_to_hub=True, token=self._token)
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
def test_push_to_hub_in_organization(self): for k, v in processor.feature_extractor.__dict__.items():
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained( try:
os.path.join(tmp_dir, "test-processor-org"), tmp_repo = f"valid_org/test-processor-org-{Path(tmp_dir).name}"
push_to_hub=True, processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
token=self._token,
organization="valid_org", # Push to hub via save_pretrained
) processor.save_pretrained(
tmp_dir,
repo_id=tmp_repo,
push_to_hub=True,
token=self._token,
)
new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org") new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
for k, v in processor.feature_extractor.__dict__.items(): for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k)) self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab()) self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_processor(self): def test_push_to_hub_dynamic_processor(self):
CustomFeatureExtractor.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
CustomTokenizer.register_for_auto_class() try:
CustomProcessor.register_for_auto_class() tmp_repo = f"{USER}/test-dynamic-processor-{Path(tmp_dir).name}"
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir: feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)
processor = CustomProcessor(feature_extractor, tokenizer) with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)
random_repo_id = f"{USER}/test-dynamic-processor-{uuid4()}" processor = CustomProcessor(feature_extractor, tokenizer)
try:
with tempfile.TemporaryDirectory() as tmp_dir: create_repo(tmp_repo, token=self._token)
create_repo(random_repo_id, token=self._token) repo = Repository(tmp_dir, clone_from=tmp_repo, token=self._token)
repo = Repository(tmp_dir, clone_from=random_repo_id, token=self._token)
processor.save_pretrained(tmp_dir) processor.save_pretrained(tmp_dir)
# This has added the proper auto_map field to the feature extractor config # This has added the proper auto_map field to the feature extractor config
...@@ -466,8 +468,10 @@ class ProcessorPushToHubTester(unittest.TestCase): ...@@ -466,8 +468,10 @@ class ProcessorPushToHubTester(unittest.TestCase):
repo.push_to_hub() repo.push_to_hub()
new_processor = AutoProcessor.from_pretrained(random_repo_id, trust_remote_code=True) new_processor = AutoProcessor.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module # Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
self.assertEqual(new_processor.__class__.__name__, "CustomProcessor") self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")
finally:
delete_repo(repo_id=random_repo_id) finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
...@@ -98,88 +98,106 @@ class ConfigPushToHubTester(unittest.TestCase): ...@@ -98,88 +98,106 @@ class ConfigPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-config")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-config-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-config")
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.push_to_hub("test-config", token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try: try:
# Reset repo # Reset repo
delete_repo(token=self._token, repo_id="test-config") delete_repo(repo_id=repo_id, token=token)
except: # noqa E722 except: # noqa E722
pass pass
# Push to hub via save_pretrained def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-config-{Path(tmp_dir).name}"
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.to_dict().items(): config = BertConfig(
if k != "transformers_version": vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
self.assertEqual(v, getattr(new_config, k)) )
config.push_to_hub(tmp_repo, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-config-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.push_to_hub("valid_org/test-config-org", token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-config-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, token=self._token) try:
tmp_repo = f"valid_org/test-config-org-{Path(tmp_dir).name}"
new_config = BertConfig.from_pretrained("valid_org/test-config-org") config = BertConfig(
for k, v in config.to_dict().items(): vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
if k != "transformers_version": )
self.assertEqual(v, getattr(new_config, k)) config.push_to_hub(tmp_repo, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-config-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_config(self): def test_push_to_hub_dynamic_config(self):
CustomConfig.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
config = CustomConfig(attribute=42) try:
tmp_repo = f"{USER}/test-dynamic-config-{Path(tmp_dir).name}"
CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42)
config.push_to_hub("test-dynamic-config", token=self._token) config.push_to_hub(tmp_repo, token=self._token)
# This has added the proper auto_map field to the config # This has added the proper auto_map field to the config
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"}) self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True) new_config = AutoConfig.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module # Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "CustomConfig") self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42) self.assertEqual(new_config.attribute, 42)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class ConfigTestUtils(unittest.TestCase): class ConfigTestUtils(unittest.TestCase):
......
...@@ -60,85 +60,91 @@ class FeatureExtractorPushToHubTester(unittest.TestCase): ...@@ -60,85 +60,91 @@ class FeatureExtractorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-feature-extractor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-feature-extractor-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-feature-extractor")
except HTTPError:
pass
def test_push_to_hub(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("test-feature-extractor", token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
try: try:
# Reset repo # Reset repo
delete_repo(token=self._token, repo_id="test-feature-extractor") delete_repo(repo_id=repo_id, token=token)
except: # noqa E722 except: # noqa E722
pass pass
# Push to hub via save_pretrained def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained( try:
tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, token=self._token tmp_repo = f"{USER}/test-feature-extractor-{Path(tmp_dir).name}"
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_in_organization(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("valid_org/test-feature-extractor", token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
for k, v in feature_extractor.__dict__.items(): feature_extractor.push_to_hub(tmp_repo, token=self._token)
self.assertEqual(v, getattr(new_feature_extractor, k))
try: new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
# Reset repo for k, v in feature_extractor.__dict__.items():
delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor") self.assertEqual(v, getattr(new_feature_extractor, k))
except: # noqa E722 finally:
pass # Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
# Push to hub via save_pretrained def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, token=self._token tmp_repo = f"{USER}/test-feature-extractor-{Path(tmp_dir).name}"
) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
# Push to hub via save_pretrained
feature_extractor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org") def test_push_to_hub_in_organization(self):
for k, v in feature_extractor.__dict__.items(): with tempfile.TemporaryDirectory() as tmp_dir:
self.assertEqual(v, getattr(new_feature_extractor, k)) try:
tmp_repo = f"valid_org/test-feature-extractor-{Path(tmp_dir).name}"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub(tmp_repo, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-feature-extractor-{Path(tmp_dir).name}"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
# Push to hub via save_pretrained
feature_extractor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_feature_extractor(self): def test_push_to_hub_dynamic_feature_extractor(self):
CustomFeatureExtractor.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) try:
tmp_repo = f"{USER}/test-dynamic-feature-extractor-{Path(tmp_dir).name}"
feature_extractor.push_to_hub("test-dynamic-feature-extractor", token=self._token) CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
# This has added the proper auto_map field to the config
self.assertDictEqual( feature_extractor.push_to_hub(tmp_repo, token=self._token)
feature_extractor.auto_map,
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"}, # This has added the proper auto_map field to the config
) self.assertDictEqual(
feature_extractor.auto_map,
new_feature_extractor = AutoFeatureExtractor.from_pretrained( {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True )
)
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module new_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_repo, trust_remote_code=True)
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor") # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
...@@ -71,88 +71,93 @@ class ImageProcessorPushToHubTester(unittest.TestCase): ...@@ -71,88 +71,93 @@ class ImageProcessorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-image-processor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-image-processor-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-image-processor")
except HTTPError:
pass
def test_push_to_hub(self):
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("test-image-processor", token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
try: try:
# Reset repo # Reset repo
delete_repo(token=self._token, repo_id="test-image-processor") delete_repo(repo_id=repo_id, token=token)
except: # noqa E722 except: # noqa E722
pass pass
# Push to hub via save_pretrained def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained( try:
tmp_dir, repo_id="test-image-processor", push_to_hub=True, token=self._token tmp_repo = f"{USER}/test-image-processor-{Path(tmp_dir).name}"
) image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
for k, v in image_processor.__dict__.items(): new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
self.assertEqual(v, getattr(new_image_processor, k)) for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-image-processor-{Path(tmp_dir).name}"
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("valid_org/test-image-processor", token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor")
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-image-processor")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-image-processor-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-image-processor-{Path(tmp_dir).name}"
) image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor-org")
for k, v in image_processor.__dict__.items(): new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
self.assertEqual(v, getattr(new_image_processor, k)) for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-image-processor-{Path(tmp_dir).name}"
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_image_processor(self): def test_push_to_hub_dynamic_image_processor(self):
CustomImageProcessor.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR) try:
tmp_repo = f"{USER}/test-dynamic-image-processor-{Path(tmp_dir).name}"
image_processor.push_to_hub("test-dynamic-image-processor", token=self._token) CustomImageProcessor.register_for_auto_class()
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# This has added the proper auto_map field to the config
self.assertDictEqual( image_processor.push_to_hub(tmp_repo, token=self._token)
image_processor.auto_map,
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"}, # This has added the proper auto_map field to the config
) self.assertDictEqual(
image_processor.auto_map,
new_image_processor = AutoImageProcessor.from_pretrained( {"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
f"{USER}/test-dynamic-image-processor", trust_remote_code=True )
)
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module new_image_processor = AutoImageProcessor.from_pretrained(tmp_repo, trust_remote_code=True)
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor") # Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class ImageProcessingUtilsTester(unittest.TestCase): class ImageProcessingUtilsTester(unittest.TestCase):
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
import tempfile import tempfile
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
from huggingface_hub import HfFolder, delete_repo, snapshot_download from huggingface_hub import HfFolder, delete_repo, snapshot_download
from requests.exceptions import HTTPError
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
...@@ -55,89 +55,103 @@ class FlaxModelPushToHubTester(unittest.TestCase): ...@@ -55,89 +55,103 @@ class FlaxModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-model-flax")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-flax-org")
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub("test-model-flax", token=self._token)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
try: try:
# Reset repo # Reset repo
delete_repo(token=self._token, repo_id="test-model-flax") delete_repo(repo_id=repo_id, token=token)
except: # noqa E722 except: # noqa E722
pass pass
# Push to hub via save_pretrained def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-model-flax-{Path(tmp_dir).name}"
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax") config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
base_params = flatten_dict(unfreeze(model.params)) )
new_params = flatten_dict(unfreeze(new_model.params)) model = FlaxBertModel(config)
model.push_to_hub(tmp_repo, token=self._token)
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() new_model = FlaxBertModel.from_pretrained(tmp_repo)
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-flax-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub("valid_org/test-model-flax-org", token=self._token)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-model-flax-org-{Path(tmp_dir).name}"
) config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org") )
model = FlaxBertModel(config)
base_params = flatten_dict(unfreeze(model.params)) model.push_to_hub(tmp_repo, token=self._token)
new_params = flatten_dict(unfreeze(new_model.params))
new_model = FlaxBertModel.from_pretrained(tmp_repo)
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() base_params = flatten_dict(unfreeze(model.params))
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-model-flax-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def check_models_equal(model1, model2): def check_models_equal(model1, model2):
......
...@@ -23,6 +23,7 @@ import random ...@@ -23,6 +23,7 @@ import random
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
from pathlib import Path
from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
...@@ -682,127 +683,149 @@ class TFModelPushToHubTester(unittest.TestCase): ...@@ -682,127 +683,149 @@ class TFModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-model-tf")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig( with tempfile.TemporaryDirectory() as tmp_dir:
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 try:
) tmp_repo = f"{USER}/test-model-tf-{Path(tmp_dir).name}"
model = TFBertModel(config) config = BertConfig(
# Make sure model is properly initialized vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
model.build_in_name_scope() )
model = TFBertModel(config)
logging.set_verbosity_info() # Make sure model is properly initialized
logger = logging.get_logger("transformers.utils.hub") model.build_in_name_scope()
with CaptureLogger(logger) as cl:
model.push_to_hub("test-model-tf", token=self._token)
logging.set_verbosity_warning()
# Check the model card was created and uploaded.
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
try: logging.set_verbosity_info()
# Reset repo logger = logging.get_logger("transformers.utils.hub")
delete_repo(token=self._token, repo_id="test-model-tf") with CaptureLogger(logger) as cl:
except: # noqa E722 model.push_to_hub(tmp_repo, token=self._token)
pass logging.set_verbosity_warning()
# Check the model card was created and uploaded.
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
# Push to hub via save_pretrained new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-model-tf-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf") # Push to hub via save_pretrained
models_equal = True model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2): new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = False models_equal = True
break for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(models_equal) if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@is_pt_tf_cross_test @is_pt_tf_cross_test
def test_push_to_hub_callback(self): def test_push_to_hub_callback(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertForMaskedLM(config)
model.compile()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
push_to_hub_callback = PushToHubCallback( try:
output_dir=tmp_dir, tmp_repo = f"{USER}/test-model-tf-callback-{Path(tmp_dir).name}"
hub_model_id="test-model-tf-callback", config = BertConfig(
hub_token=self._token, vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback]) model = TFBertForMaskedLM(config)
model.compile()
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback") push_to_hub_callback = PushToHubCallback(
models_equal = True output_dir=tmp_dir,
for p1, p2 in zip(model.weights, new_model.weights): hub_model_id=tmp_repo,
if not tf.math.reduce_all(p1 == p2): hub_token=self._token,
models_equal = False )
break model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
self.assertTrue(models_equal)
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters) new_model = TFBertForMaskedLM.from_pretrained(tmp_repo)
tf_push_to_hub_params.pop("base_model_card_args") models_equal = True
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters) for p1, p2 in zip(model.weights, new_model.weights):
pt_push_to_hub_params.pop("deprecated_kwargs") if not tf.math.reduce_all(p1 == p2):
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params) models_equal = False
break
self.assertTrue(models_equal)
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
tf_push_to_hub_params.pop("base_model_card_args")
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
pt_push_to_hub_params.pop("deprecated_kwargs")
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig( with tempfile.TemporaryDirectory() as tmp_dir:
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 try:
) tmp_repo = f"valid_org/test-model-tf-org-{Path(tmp_dir).name}"
model = TFBertModel(config) config = BertConfig(
# Make sure model is properly initialized vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
model.build_in_name_scope() )
model = TFBertModel(config)
model.push_to_hub("valid_org/test-model-tf-org", token=self._token) # Make sure model is properly initialized
model.build_in_name_scope()
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
try: model.push_to_hub(tmp_repo, token=self._token)
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
except: # noqa E722
pass
# Push to hub via save_pretrained new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-tf-org") try:
tmp_repo = f"valid_org/test-model-tf-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org") # Push to hub via save_pretrained
models_equal = True model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo)
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2): new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = False models_equal = True
break for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(models_equal) if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
...@@ -1876,142 +1876,168 @@ class ModelPushToHubTester(unittest.TestCase): ...@@ -1876,142 +1876,168 @@ class ModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-model")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-org") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="test-dynamic-model")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
except HTTPError:
pass pass
@unittest.skip(reason="This test is flaky") @unittest.skip(reason="This test is flaky")
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig( with tempfile.TemporaryDirectory() as tmp_dir:
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 try:
) tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
model = BertModel(config) config = BertConfig(
model.push_to_hub("test-model", token=self._token) vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
new_model = BertModel.from_pretrained(f"{USER}/test-model") model = BertModel(config)
for p1, p2 in zip(model.parameters(), new_model.parameters()): model.push_to_hub(tmp_repo, token=self._token)
self.assertTrue(torch.equal(p1, p2))
try: new_model = BertModel.from_pretrained(tmp_repo)
# Reset repo for p1, p2 in zip(model.parameters(), new_model.parameters()):
delete_repo(token=self._token, repo_id="test-model") self.assertTrue(torch.equal(p1, p2))
except: # noqa E722 finally:
pass # Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
# Push to hub via save_pretrained @unittest.skip(reason="This test is flaky")
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = BertModel.from_pretrained(f"{USER}/test-model") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_with_description(self): def test_push_to_hub_with_description(self):
config = BertConfig( with tempfile.TemporaryDirectory() as tmp_dir:
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 try:
) tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
model = BertModel(config) config = BertConfig(
COMMIT_DESCRIPTION = """ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
COMMIT_DESCRIPTION = """
The commit description supports markdown synthax see: The commit description supports markdown synthax see:
```python ```python
>>> form transformers import AutoConfig >>> form transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased") >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
``` ```
""" """
commit_details = model.push_to_hub( commit_details = model.push_to_hub(
"test-model", use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION tmp_repo, use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION
) )
self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION) self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@unittest.skip(reason="This test is flaky") @unittest.skip(reason="This test is flaky")
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig( with tempfile.TemporaryDirectory() as tmp_dir:
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 try:
) tmp_repo = f"valid_org/test-model-org-{Path(tmp_dir).name}"
model = BertModel(config) config = BertConfig(
model.push_to_hub("valid_org/test-model-org", token=self._token) vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
new_model = BertModel.from_pretrained("valid_org/test-model-org") model = BertModel(config)
for p1, p2 in zip(model.parameters(), new_model.parameters()): model.push_to_hub(tmp_repo, token=self._token)
self.assertTrue(torch.equal(p1, p2))
try: new_model = BertModel.from_pretrained(tmp_repo)
# Reset repo for p1, p2 in zip(model.parameters(), new_model.parameters()):
delete_repo(token=self._token, repo_id="valid_org/test-model-org") self.assertTrue(torch.equal(p1, p2))
except: # noqa E722 finally:
pass # Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
# Push to hub via save_pretrained @unittest.skip(reason="This test is flaky")
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-org") try:
tmp_repo = f"valid_org/test-model-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo)
new_model = BertModel.from_pretrained("valid_org/test-model-org") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_model(self): def test_push_to_hub_dynamic_model(self):
CustomConfig.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
CustomModel.register_for_auto_class() try:
tmp_repo = f"{USER}/test-dynamic-model-{Path(tmp_dir).name}"
config = CustomConfig(hidden_size=32) CustomConfig.register_for_auto_class()
model = CustomModel(config) CustomModel.register_for_auto_class()
model.push_to_hub("test-dynamic-model", token=self._token) config = CustomConfig(hidden_size=32)
# checks model = CustomModel(config)
self.assertDictEqual(
config.auto_map, model.push_to_hub(tmp_repo, token=self._token)
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"}, # checks
) self.assertDictEqual(
config.auto_map,
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
)
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) new_model = AutoModel.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module # Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "CustomModel") self.assertEqual(new_model.__class__.__name__, "CustomModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) config = AutoConfig.from_pretrained(tmp_repo, trust_remote_code=True)
new_model = AutoModel.from_config(config, trust_remote_code=True) new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "CustomModel") self.assertEqual(new_model.__class__.__name__, "CustomModel")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_with_tags(self): def test_push_to_hub_with_tags(self):
from huggingface_hub import ModelCard with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-model-with-tags-{Path(tmp_dir).name}"
from huggingface_hub import ModelCard
new_tags = ["tag-1", "tag-2"] new_tags = ["tag-1", "tag-2"]
CustomConfig.register_for_auto_class() CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class() CustomModel.register_for_auto_class()
config = CustomConfig(hidden_size=32) config = CustomConfig(hidden_size=32)
model = CustomModel(config) model = CustomModel(config)
self.assertTrue(model.model_tags is None) self.assertTrue(model.model_tags is None)
model.add_model_tags(new_tags) model.add_model_tags(new_tags)
self.assertTrue(model.model_tags == new_tags) self.assertTrue(model.model_tags == new_tags)
model.push_to_hub("test-dynamic-model-with-tags", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags") loaded_model_card = ModelCard.load(tmp_repo)
self.assertEqual(loaded_model_card.data.tags, new_tags) self.assertEqual(loaded_model_card.data.tags, new_tags)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_torch @require_torch
......
...@@ -118,110 +118,133 @@ class TokenizerPushToHubTester(unittest.TestCase): ...@@ -118,110 +118,133 @@ class TokenizerPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-tokenizer")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
except HTTPError:
pass
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub("test-tokenizer", token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
try: try:
# Reset repo # Reset repo
delete_repo(token=self._token, repo_id="test-tokenizer") delete_repo(repo_id=repo_id, token=token)
except: # noqa E722 except: # noqa E722
pass pass
# Push to hub via save_pretrained def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer") vocab_file = os.path.join(tmp_dir, "vocab.txt")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub(tmp_repo, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
# Push to hub via save_pretrained
tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") try:
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) vocab_file = os.path.join(tmp_dir, "vocab.txt")
tokenizer = BertTokenizer(vocab_file) with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer.push_to_hub("valid_org/test-tokenizer-org", token=self._token) tokenizer = BertTokenizer(vocab_file)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) tokenizer.push_to_hub(tmp_repo, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
try: self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
# Reset repo finally:
delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org") # Always (try to) delete the repo.
except: # noqa E722 self._try_delete_repo(repo_id=tmp_repo, token=self._token)
pass
def test_push_to_hub_in_organization_via_save_pretrained(self):
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
) vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org") vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) tokenizer = BertTokenizer(vocab_file)
# Push to hub via save_pretrained
tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self):
    # End-to-end Hub test: push a custom (remote-code) slow tokenizer and reload it
    # through AutoTokenizer with trust_remote_code=True.
    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            # Unique per-run repo name (temp-dir suffix) so concurrent CI jobs don't collide.
            tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
            # Register the custom class so its module code is uploaded next to the tokenizer files.
            CustomTokenizer.register_for_auto_class()

            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = CustomTokenizer(vocab_file)

            # No fast custom tokenizer
            tokenizer.push_to_hub(tmp_repo, token=self._token)

            tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
            # Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
            self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
        finally:
            # Always (try to) delete the repo.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer_with_both_slow_and_fast_classes(self):
    # End-to-end Hub test: push a custom tokenizer that registers BOTH a slow and a fast
    # remote-code class, then check AutoTokenizer resolves each variant by name.
    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            # Unique per-run repo name (temp-dir suffix) so concurrent CI jobs don't collide.
            tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
            CustomTokenizer.register_for_auto_class()
            # Fast and slow custom tokenizer
            CustomTokenizerFast.register_for_auto_class()

            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))

            # NOTE(review): the fast tokenizer is first built/saved via BertTokenizerFast so the
            # required tokenizer.json exists in tmp_dir before CustomTokenizerFast loads it.
            bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
            bert_tokenizer.save_pretrained(tmp_dir)
            tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)

            tokenizer.push_to_hub(tmp_repo, token=self._token)

            # Default resolution (use_fast=True) should yield the fast custom class.
            tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
            # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
            self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
            # use_fast=False should fall back to the slow custom class.
            tokenizer = AutoTokenizer.from_pretrained(tmp_repo, use_fast=False, trust_remote_code=True)
            # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
            self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
        finally:
            # Always (try to) delete the repo.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class TrieTest(unittest.TestCase): class TrieTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment