Unverified commit df6eee92, authored by Yih-Dar, committed by GitHub

Follow up for #31973 (#32025)



* fix

* [test_all] trigger full CI

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent de231889
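
The hunks below apply one refactor across all of the push-to-hub test suites (generation config, processor, configuration, feature extractor, image processor, Flax/TF/PyTorch model, and tokenizer tests): the class-level tearDownClass() cleanup of hard-coded repo names is replaced by a per-test repo whose name is suffixed with the basename of a temporary directory and which is deleted in a try/finally via a best-effort _try_delete_repo() helper. A minimal sketch of that pattern, with stand-in USER/TOKEN values rather than the staging fixtures the real tests import:

import tempfile
import unittest
from pathlib import Path

from huggingface_hub import delete_repo
from transformers import GenerationConfig

USER = "__DUMMY_TRANSFORMERS_USER__"  # placeholder: the real tests use a dummy staging user
TOKEN = "hf_..."  # placeholder: the real tests use a staging-endpoint token


class ExamplePushToHubTest(unittest.TestCase):
    @staticmethod
    def _try_delete_repo(repo_id, token):
        try:
            # Best-effort cleanup; the repo may not exist if the test failed early.
            delete_repo(repo_id=repo_id, token=token)
        except:  # noqa: E722
            pass

    def test_push_to_hub(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            try:
                # Path(tmp_dir).name (e.g. "tmpabc123") makes the repo id unique per run,
                # so concurrent CI jobs no longer collide on a shared hard-coded repo.
                tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
                config = GenerationConfig(do_sample=True, temperature=0.7, length_penalty=1.0)
                config.push_to_hub(tmp_repo, token=TOKEN)

                new_config = GenerationConfig.from_pretrained(tmp_repo)
                self.assertEqual(new_config.temperature, config.temperature)
            finally:
                # Runs whether the assertions pass or fail, replacing the old
                # class-level tearDownClass() cleanup.
                self._try_delete_repo(repo_id=tmp_repo, token=TOKEN)

The same try/finally wrapper, with a suite-specific name prefix, is what each hunk below introduces.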
@@ -18,10 +18,10 @@ import os
 import tempfile
 import unittest
 import warnings
+from pathlib import Path
 
 from huggingface_hub import HfFolder, delete_repo
 from parameterized import parameterized
-from requests.exceptions import HTTPError
 
 from transformers import AutoConfig, GenerationConfig
 from transformers.generation import GenerationMode
@@ -228,72 +228,88 @@ class ConfigPushToHubTester(unittest.TestCase):
         cls._token = TOKEN
         HfFolder.save_token(TOKEN)
 
-    @classmethod
-    def tearDownClass(cls):
+    @staticmethod
+    def _try_delete_repo(repo_id, token):
         try:
-            delete_repo(token=cls._token, repo_id="test-generation-config")
-        except HTTPError:
-            pass
-
-        try:
-            delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org")
-        except HTTPError:
+            # Reset repo
+            delete_repo(repo_id=repo_id, token=token)
+        except:  # noqa E722
             pass
 
     def test_push_to_hub(self):
-        config = GenerationConfig(
-            do_sample=True,
-            temperature=0.7,
-            length_penalty=1.0,
-        )
-        config.push_to_hub("test-generation-config", token=self._token)
-
-        new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
-        for k, v in config.to_dict().items():
-            if k != "transformers_version":
-                self.assertEqual(v, getattr(new_config, k))
-
-        try:
-            # Reset repo
-            delete_repo(token=self._token, repo_id="test-generation-config")
-        except:  # noqa E722
-            pass
-
-        # Push to hub via save_pretrained
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            config.save_pretrained(tmp_dir, repo_id="test-generation-config", push_to_hub=True, token=self._token)
-
-        new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
-        for k, v in config.to_dict().items():
-            if k != "transformers_version":
-                self.assertEqual(v, getattr(new_config, k))
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            try:
+                tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
+                config = GenerationConfig(
+                    do_sample=True,
+                    temperature=0.7,
+                    length_penalty=1.0,
+                )
+                config.push_to_hub(tmp_repo, token=self._token)
+
+                new_config = GenerationConfig.from_pretrained(tmp_repo)
+                for k, v in config.to_dict().items():
+                    if k != "transformers_version":
+                        self.assertEqual(v, getattr(new_config, k))
+            finally:
+                # Always (try to) delete the repo.
+                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
+
+    def test_push_to_hub_via_save_pretrained(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            try:
+                tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
+                config = GenerationConfig(
+                    do_sample=True,
+                    temperature=0.7,
+                    length_penalty=1.0,
+                )
+                # Push to hub via save_pretrained
+                config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
+
+                new_config = GenerationConfig.from_pretrained(tmp_repo)
+                for k, v in config.to_dict().items():
+                    if k != "transformers_version":
+                        self.assertEqual(v, getattr(new_config, k))
+            finally:
+                # Always (try to) delete the repo.
+                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
 
     def test_push_to_hub_in_organization(self):
-        config = GenerationConfig(
-            do_sample=True,
-            temperature=0.7,
-            length_penalty=1.0,
-        )
-        config.push_to_hub("valid_org/test-generation-config-org", token=self._token)
-
-        new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
-        for k, v in config.to_dict().items():
-            if k != "transformers_version":
-                self.assertEqual(v, getattr(new_config, k))
-
-        try:
-            # Reset repo
-            delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org")
-        except:  # noqa E722
-            pass
-
-        # Push to hub via save_pretrained
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            config.save_pretrained(
-                tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, token=self._token
-            )
-
-        new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
-        for k, v in config.to_dict().items():
-            if k != "transformers_version":
-                self.assertEqual(v, getattr(new_config, k))
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            try:
+                tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
+                config = GenerationConfig(
+                    do_sample=True,
+                    temperature=0.7,
+                    length_penalty=1.0,
+                )
+                config.push_to_hub(tmp_repo, token=self._token)
+
+                new_config = GenerationConfig.from_pretrained(tmp_repo)
+                for k, v in config.to_dict().items():
+                    if k != "transformers_version":
+                        self.assertEqual(v, getattr(new_config, k))
+            finally:
+                # Always (try to) delete the repo.
+                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
+
+    def test_push_to_hub_in_organization_via_save_pretrained(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            try:
+                tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
+                config = GenerationConfig(
+                    do_sample=True,
+                    temperature=0.7,
+                    length_penalty=1.0,
+                )
+                # Push to hub via save_pretrained
+                config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
+
+                new_config = GenerationConfig.from_pretrained(tmp_repo)
+                for k, v in config.to_dict().items():
+                    if k != "transformers_version":
+                        self.assertEqual(v, getattr(new_config, k))
+            finally:
+                # Always (try to) delete the repo.
+                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
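
For reference, a standalone illustration of where the unique repo suffix comes from; the printed name is an example, not a value from any CI run:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp_dir:
    # tmp_dir looks like /tmp/tmp3x1_k9qw; its basename is unique for every test run.
    suffix = Path(tmp_dir).name
    print(f"__DUMMY_TRANSFORMERS_USER__/test-generation-config-{suffix}")
    # e.g. __DUMMY_TRANSFORMERS_USER__/test-generation-config-tmp3x1_k9qw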
@@ -20,10 +20,8 @@ import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from uuid import uuid4
from huggingface_hub import HfFolder, Repository, create_repo, delete_repo from huggingface_hub import HfFolder, Repository, create_repo, delete_repo
from requests.exceptions import HTTPError
import transformers import transformers
from transformers import ( from transformers import (
@@ -374,50 +372,57 @@ class ProcessorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-processor") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-processor-org")
except HTTPError:
pass pass
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
delete_repo(token=cls._token, repo_id="test-dynamic-processor") tmp_repo = f"{USER}/test-processor-{Path(tmp_dir).name}"
except HTTPError:
pass
def test_push_to_hub(self):
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir: # Push to hub via save_pretrained
processor.save_pretrained(os.path.join(tmp_dir, "test-processor"), push_to_hub=True, token=self._token) processor.save_pretrained(tmp_repo, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor") new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
for k, v in processor.feature_extractor.__dict__.items(): for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k)) self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab()) self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-processor-org-{Path(tmp_dir).name}"
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir: # Push to hub via save_pretrained
processor.save_pretrained( processor.save_pretrained(
os.path.join(tmp_dir, "test-processor-org"), tmp_dir,
repo_id=tmp_repo,
push_to_hub=True, push_to_hub=True,
token=self._token, token=self._token,
organization="valid_org",
) )
new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org") new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
for k, v in processor.feature_extractor.__dict__.items(): for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k)) self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab()) self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_processor(self): def test_push_to_hub_dynamic_processor(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-processor-{Path(tmp_dir).name}"
CustomFeatureExtractor.register_for_auto_class() CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class() CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class() CustomProcessor.register_for_auto_class()
@@ -432,11 +437,8 @@ class ProcessorPushToHubTester(unittest.TestCase):
processor = CustomProcessor(feature_extractor, tokenizer) processor = CustomProcessor(feature_extractor, tokenizer)
random_repo_id = f"{USER}/test-dynamic-processor-{uuid4()}" create_repo(tmp_repo, token=self._token)
try: repo = Repository(tmp_dir, clone_from=tmp_repo, token=self._token)
with tempfile.TemporaryDirectory() as tmp_dir:
create_repo(random_repo_id, token=self._token)
repo = Repository(tmp_dir, clone_from=random_repo_id, token=self._token)
processor.save_pretrained(tmp_dir) processor.save_pretrained(tmp_dir)
# This has added the proper auto_map field to the feature extractor config # This has added the proper auto_map field to the feature extractor config
@@ -466,8 +468,10 @@ class ProcessorPushToHubTester(unittest.TestCase):
repo.push_to_hub() repo.push_to_hub()
new_processor = AutoProcessor.from_pretrained(random_repo_id, trust_remote_code=True) new_processor = AutoProcessor.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module # Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
self.assertEqual(new_processor.__class__.__name__, "CustomProcessor") self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")
finally: finally:
delete_repo(repo_id=random_repo_id) # Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@@ -98,88 +98,106 @@ class ConfigPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-config")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="valid_org/test-config-org") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
delete_repo(token=cls._token, repo_id="test-dynamic-config") tmp_repo = f"{USER}/test-config-{Path(tmp_dir).name}"
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
config.push_to_hub("test-config", token=self._token) config.push_to_hub(tmp_repo, token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config") new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items(): for k, v in config.to_dict().items():
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"{USER}/test-config-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="test-config")
except: # noqa E722
pass
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config") new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items(): for k, v in config.to_dict().items():
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-config-org-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
config.push_to_hub("valid_org/test-config-org", token=self._token) config.push_to_hub(tmp_repo, token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org") new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items(): for k, v in config.to_dict().items():
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"valid_org/test-config-org-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="valid_org/test-config-org") config = BertConfig(
except: # noqa E722 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
pass )
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
config.save_pretrained(tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org") new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items(): for k, v in config.to_dict().items():
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_config(self): def test_push_to_hub_dynamic_config(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-config-{Path(tmp_dir).name}"
CustomConfig.register_for_auto_class() CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42) config = CustomConfig(attribute=42)
config.push_to_hub("test-dynamic-config", token=self._token) config.push_to_hub(tmp_repo, token=self._token)
# This has added the proper auto_map field to the config # This has added the proper auto_map field to the config
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"}) self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True) new_config = AutoConfig.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module # Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "CustomConfig") self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42) self.assertEqual(new_config.attribute, 42)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class ConfigTestUtils(unittest.TestCase): class ConfigTestUtils(unittest.TestCase):
...
@@ -60,76 +60,81 @@ class FeatureExtractorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-feature-extractor")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="valid_org/test-feature-extractor-org") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
delete_repo(token=cls._token, repo_id="test-dynamic-feature-extractor") tmp_repo = f"{USER}/test-feature-extractor-{Path(tmp_dir).name}"
except HTTPError:
pass
def test_push_to_hub(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("test-feature-extractor", token=self._token) feature_extractor.push_to_hub(tmp_repo, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor") new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items(): for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k)) self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"{USER}/test-feature-extractor-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="test-feature-extractor") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
except: # noqa E722
pass
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: feature_extractor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
feature_extractor.save_pretrained(
tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, token=self._token
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor") new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items(): for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k)) self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-feature-extractor-{Path(tmp_dir).name}"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("valid_org/test-feature-extractor", token=self._token) feature_extractor.push_to_hub(tmp_repo, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor") new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items(): for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k)) self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"valid_org/test-feature-extractor-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
except: # noqa E722
pass
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: feature_extractor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
feature_extractor.save_pretrained(
tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, token=self._token
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org") new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items(): for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k)) self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_feature_extractor(self): def test_push_to_hub_dynamic_feature_extractor(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-feature-extractor-{Path(tmp_dir).name}"
CustomFeatureExtractor.register_for_auto_class() CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("test-dynamic-feature-extractor", token=self._token) feature_extractor.push_to_hub(tmp_repo, token=self._token)
# This has added the proper auto_map field to the config # This has added the proper auto_map field to the config
self.assertDictEqual( self.assertDictEqual(
@@ -137,8 +142,9 @@ class FeatureExtractorPushToHubTester(unittest.TestCase):
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"}, {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
) )
new_feature_extractor = AutoFeatureExtractor.from_pretrained( new_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_repo, trust_remote_code=True)
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
)
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor") self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@@ -71,76 +71,80 @@ class ImageProcessorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-image-processor")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="valid_org/test-image-processor-org") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="test-dynamic-image-processor")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-image-processor-{Path(tmp_dir).name}"
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR) image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("test-image-processor", token=self._token) image_processor.push_to_hub(tmp_repo, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor") new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items(): for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k)) self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"{USER}/test-image-processor-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="test-image-processor") image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
except: # noqa E722
pass
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
image_processor.save_pretrained(
tmp_dir, repo_id="test-image-processor", push_to_hub=True, token=self._token
)
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor") new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items(): for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k)) self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-image-processor-{Path(tmp_dir).name}"
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR) image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("valid_org/test-image-processor", token=self._token) image_processor.push_to_hub(tmp_repo, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor") new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items(): for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k)) self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"valid_org/test-image-processor-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="valid_org/test-image-processor") image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
except: # noqa E722
pass
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
image_processor.save_pretrained(
tmp_dir, repo_id="valid_org/test-image-processor-org", push_to_hub=True, token=self._token
)
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor-org") new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items(): for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k)) self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_image_processor(self): def test_push_to_hub_dynamic_image_processor(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-image-processor-{Path(tmp_dir).name}"
CustomImageProcessor.register_for_auto_class() CustomImageProcessor.register_for_auto_class()
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR) image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("test-dynamic-image-processor", token=self._token) image_processor.push_to_hub(tmp_repo, token=self._token)
# This has added the proper auto_map field to the config # This has added the proper auto_map field to the config
self.assertDictEqual( self.assertDictEqual(
@@ -148,11 +152,12 @@ class ImageProcessorPushToHubTester(unittest.TestCase):
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"}, {"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
) )
new_image_processor = AutoImageProcessor.from_pretrained( new_image_processor = AutoImageProcessor.from_pretrained(tmp_repo, trust_remote_code=True)
f"{USER}/test-dynamic-image-processor", trust_remote_code=True
)
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module # Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor") self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class ImageProcessingUtilsTester(unittest.TestCase): class ImageProcessingUtilsTester(unittest.TestCase):
...
@@ -14,10 +14,10 @@
import tempfile import tempfile
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
from huggingface_hub import HfFolder, delete_repo, snapshot_download from huggingface_hub import HfFolder, delete_repo, snapshot_download
from requests.exceptions import HTTPError
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -55,26 +55,25 @@ class FlaxModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-model-flax")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-flax-org") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-flax-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model = FlaxBertModel(config) model = FlaxBertModel(config)
model.push_to_hub("test-model-flax", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax") new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params)) base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params)) new_params = flatten_dict(unfreeze(new_model.params))
@@ -82,18 +81,22 @@ class FlaxModelPushToHubTester(unittest.TestCase):
for key in base_params.keys(): for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"{USER}/test-model-flax-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="test-model-flax") config = BertConfig(
except: # noqa E722 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
pass )
model = FlaxBertModel(config)
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, token=self._token)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax") new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params)) base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params)) new_params = flatten_dict(unfreeze(new_model.params))
@@ -101,15 +104,21 @@ class FlaxModelPushToHubTester(unittest.TestCase):
for key in base_params.keys(): for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-model-flax-org-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model = FlaxBertModel(config) model = FlaxBertModel(config)
model.push_to_hub("valid_org/test-model-flax-org", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org") new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params)) base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params)) new_params = flatten_dict(unfreeze(new_model.params))
@@ -117,20 +126,22 @@ class FlaxModelPushToHubTester(unittest.TestCase):
for key in base_params.keys(): for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
try: def test_push_to_hub_in_organization_via_save_pretrained(self):
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-model-flax-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model = FlaxBertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org") new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params)) base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params)) new_params = flatten_dict(unfreeze(new_model.params))
@@ -138,6 +149,9 @@ class FlaxModelPushToHubTester(unittest.TestCase):
for key in base_params.keys(): for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def check_models_equal(model1, model2): def check_models_equal(model1, model2):
...
@@ -23,6 +23,7 @@ import random
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
from pathlib import Path
from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
@@ -682,24 +683,18 @@ class TFModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-model-tf")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-tf-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
@@ -710,54 +705,66 @@ class TFModelPushToHubTester(unittest.TestCase):
logging.set_verbosity_info() logging.set_verbosity_info()
logger = logging.get_logger("transformers.utils.hub") logger = logging.get_logger("transformers.utils.hub")
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
model.push_to_hub("test-model-tf", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
logging.set_verbosity_warning() logging.set_verbosity_warning()
# Check the model card was created and uploaded. # Check the model card was created and uploaded.
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out) self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf") new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True models_equal = True
for p1, p2 in zip(model.weights, new_model.weights): for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2): if not tf.math.reduce_all(p1 == p2):
models_equal = False models_equal = False
break break
self.assertTrue(models_equal) self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"{USER}/test-model-tf-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="test-model-tf") config = BertConfig(
except: # noqa E722 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
pass )
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, token=self._token)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf") new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True models_equal = True
for p1, p2 in zip(model.weights, new_model.weights): for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2): if not tf.math.reduce_all(p1 == p2):
models_equal = False models_equal = False
break break
self.assertTrue(models_equal) self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@is_pt_tf_cross_test @is_pt_tf_cross_test
def test_push_to_hub_callback(self): def test_push_to_hub_callback(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-tf-callback-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model = TFBertForMaskedLM(config) model = TFBertForMaskedLM(config)
model.compile() model.compile()
with tempfile.TemporaryDirectory() as tmp_dir:
push_to_hub_callback = PushToHubCallback( push_to_hub_callback = PushToHubCallback(
output_dir=tmp_dir, output_dir=tmp_dir,
hub_model_id="test-model-tf-callback", hub_model_id=tmp_repo,
hub_token=self._token, hub_token=self._token,
) )
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback]) model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback") new_model = TFBertForMaskedLM.from_pretrained(tmp_repo)
models_equal = True models_equal = True
for p1, p2 in zip(model.weights, new_model.weights): for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2): if not tf.math.reduce_all(p1 == p2):
@@ -770,8 +777,14 @@ class TFModelPushToHubTester(unittest.TestCase):
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters) pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
pt_push_to_hub_params.pop("deprecated_kwargs") pt_push_to_hub_params.pop("deprecated_kwargs")
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params) self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-model-tf-org-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
@@ -779,30 +792,40 @@ class TFModelPushToHubTester(unittest.TestCase):
# Make sure model is properly initialized # Make sure model is properly initialized
model.build_in_name_scope() model.build_in_name_scope()
model.push_to_hub("valid_org/test-model-tf-org", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org") new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True models_equal = True
for p1, p2 in zip(model.weights, new_model.weights): for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2): if not tf.math.reduce_all(p1 == p2):
models_equal = False models_equal = False
break break
self.assertTrue(models_equal) self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"valid_org/test-model-tf-org-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org") config = BertConfig(
except: # noqa E722 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
pass )
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo)
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-tf-org")
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org") new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True models_equal = True
for p1, p2 in zip(model.weights, new_model.weights): for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2): if not tf.math.reduce_all(p1 == p2):
models_equal = False models_equal = False
break break
self.assertTrue(models_equal) self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
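
The callback hunk above drives pushes from model.fit via PushToHubCallback. Restated as a standalone sketch with a placeholder repo id and token (not the suite's fixtures), assuming TensorFlow is installed:

import tempfile

from transformers import BertConfig, TFBertForMaskedLM
from transformers.keras_callbacks import PushToHubCallback

config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = TFBertForMaskedLM(config)
model.compile()

with tempfile.TemporaryDirectory() as tmp_dir:
    callback = PushToHubCallback(
        output_dir=tmp_dir,
        hub_model_id="your-username/test-model-tf-callback-demo",  # placeholder repo id
        hub_token="hf_...",  # placeholder token
    )
    # The callback pushes checkpoints to the Hub at the end of each epoch and at the end of training.
    model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[callback])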
@@ -1876,55 +1876,55 @@ class ModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-model")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-org")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="test-dynamic-model") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
except HTTPError:
pass pass
@unittest.skip(reason="This test is flaky") @unittest.skip(reason="This test is flaky")
def test_push_to_hub(self): def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model = BertModel(config) model = BertModel(config)
model.push_to_hub("test-model", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
new_model = BertModel.from_pretrained(f"{USER}/test-model") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@unittest.skip(reason="This test is flaky")
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="test-model") config = BertConfig(
except: # noqa E722 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
pass )
model = BertModel(config)
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, token=self._token)
new_model = BertModel.from_pretrained(f"{USER}/test-model") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_with_description(self): def test_push_to_hub_with_description(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
@@ -1937,61 +1937,84 @@ The commit description supports markdown synthax see:
``` ```
""" """
commit_details = model.push_to_hub( commit_details = model.push_to_hub(
"test-model", use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION tmp_repo, use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION
) )
self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION) self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@unittest.skip(reason="This test is flaky") @unittest.skip(reason="This test is flaky")
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-model-org-{Path(tmp_dir).name}"
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model = BertModel(config) model = BertModel(config)
model.push_to_hub("valid_org/test-model-org", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
new_model = BertModel.from_pretrained("valid_org/test-model-org") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@unittest.skip(reason="This test is flaky")
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"valid_org/test-model-org-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="valid_org/test-model-org") config = BertConfig(
except: # noqa E722 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
pass )
model = BertModel(config)
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo)
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-org")
new_model = BertModel.from_pretrained("valid_org/test-model-org") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_model(self): def test_push_to_hub_dynamic_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-model-{Path(tmp_dir).name}"
CustomConfig.register_for_auto_class() CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class() CustomModel.register_for_auto_class()
config = CustomConfig(hidden_size=32) config = CustomConfig(hidden_size=32)
model = CustomModel(config) model = CustomModel(config)
model.push_to_hub("test-dynamic-model", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
# checks # checks
self.assertDictEqual( self.assertDictEqual(
config.auto_map, config.auto_map,
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"}, {"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
) )
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) new_model = AutoModel.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module # Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "CustomModel") self.assertEqual(new_model.__class__.__name__, "CustomModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) config = AutoConfig.from_pretrained(tmp_repo, trust_remote_code=True)
new_model = AutoModel.from_config(config, trust_remote_code=True) new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "CustomModel") self.assertEqual(new_model.__class__.__name__, "CustomModel")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_with_tags(self): def test_push_to_hub_with_tags(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-model-with-tags-{Path(tmp_dir).name}"
from huggingface_hub import ModelCard from huggingface_hub import ModelCard
new_tags = ["tag-1", "tag-2"] new_tags = ["tag-1", "tag-2"]
...@@ -2008,10 +2031,13 @@ The commit description supports markdown syntax see:
self.assertTrue(model.model_tags == new_tags) self.assertTrue(model.model_tags == new_tags)
model.push_to_hub("test-dynamic-model-with-tags", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags") loaded_model_card = ModelCard.load(tmp_repo)
self.assertEqual(loaded_model_card.data.tags, new_tags) self.assertEqual(loaded_model_card.data.tags, new_tags)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_torch @require_torch
...
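Between the two test files, a short illustrative sketch of the pattern the model tests above converge on may help reviewers: derive a unique repo id from the TemporaryDirectory name, push, reload, and always attempt cleanup in a finally block. This is not part of the diff; the function and parameter names (push_roundtrip, user, token, try_delete_repo) are stand-ins for the test-class attributes (USER, self._token, self._try_delete_repo) used in the real tests.

import tempfile
from pathlib import Path


def push_roundtrip(model, model_cls, user, token, try_delete_repo):
    # Illustrative only: mirrors the shape of the updated tests, not their exact code.
    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            # Path(tmp_dir).name is unique per run, so concurrent or retried CI
            # jobs never race on the same Hub repo name.
            tmp_repo = f"{user}/test-model-{Path(tmp_dir).name}"
            model.push_to_hub(tmp_repo, token=token)
            reloaded = model_cls.from_pretrained(tmp_repo)
            return reloaded
        finally:
            # Always (try to) delete the repo, even if the push or reload failed.
            try_delete_repo(repo_id=tmp_repo, token=token)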
...@@ -118,92 +118,114 @@ class TokenizerPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try:
delete_repo(token=cls._token, repo_id="test-tokenizer")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
vocab_file = os.path.join(tmp_dir, "vocab.txt") vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file) tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub("test-tokenizer", token=self._token) tokenizer.push_to_hub(tmp_repo, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer") new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="test-tokenizer") vocab_file = os.path.join(tmp_dir, "vocab.txt")
except: # noqa E722 with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
pass vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer") new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
vocab_file = os.path.join(tmp_dir, "vocab.txt") vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file) tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub("valid_org/test-tokenizer-org", token=self._token) tokenizer.push_to_hub(tmp_repo, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org") new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try: try:
# Reset repo tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org") vocab_file = os.path.join(tmp_dir, "vocab.txt")
except: # noqa E722 with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
pass vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
# Push to hub via save_pretrained # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
tokenizer.save_pretrained(
tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, token=self._token
)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org") new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_tokenizers @require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self): def test_push_to_hub_dynamic_tokenizer(self):
CustomTokenizer.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
CustomTokenizer.register_for_auto_class()
vocab_file = os.path.join(tmp_dir, "vocab.txt") vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file) tokenizer = CustomTokenizer(vocab_file)
# No fast custom tokenizer # No fast custom tokenizer
tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token) tokenizer.push_to_hub(tmp_repo, token=self._token)
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module # Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer") self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer_with_both_slow_and_fast_classes(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
CustomTokenizer.register_for_auto_class()
# Fast and slow custom tokenizer # Fast and slow custom tokenizer
CustomTokenizerFast.register_for_auto_class() CustomTokenizerFast.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
...@@ -212,16 +234,17 @@ class TokenizerPushToHubTester(unittest.TestCase):
bert_tokenizer.save_pretrained(tmp_dir) bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir) tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token) tokenizer.push_to_hub(tmp_repo, token=self._token)
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast") self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(tmp_repo, use_fast=False, trust_remote_code=True)
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer") self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class TrieTest(unittest.TestCase): class TrieTest(unittest.TestCase):
...
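To close the loop on the cleanup side, here is an illustrative, standalone version of the best-effort delete helper that both tester classes now define as a @staticmethod. delete_repo is the real huggingface_hub function; the standalone function name and the use of `except Exception` (instead of the diff's bare `except:  # noqa E722`) are the only liberties taken.

from huggingface_hub import delete_repo


def try_delete_repo(repo_id, token):
    # Best-effort cleanup: a repo that was never created, or a transient Hub
    # error during deletion, must not turn an otherwise passing test into a failure.
    try:
        delete_repo(repo_id=repo_id, token=token)
    except Exception:
        pass

If the installed huggingface_hub version supports it, delete_repo(repo_id, token=token, missing_ok=True) covers the "repo does not exist" case on its own, though the surrounding try/except is still what shields the tests from transient network errors.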