"tools/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d88838ff713ec67b854c420edca542d668ae159e"
Unverified Commit f1660d7e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Remote code improvements (#23959)



* Fix model load when it has both code on the Hub and locally

* Add input check with timeout

* Add tests

* Apply suggestions from code review
Co-authored-by: default avatarLysandre Debut <lysandre.debut@reseau.eseo.fr>

* Some non-saved stuff

* Add feature extractors

* Add image processor

* Add model

* Add processor and tokenizer

* Reduce timeout

---------
Co-authored-by: default avatarLysandre Debut <lysandre.debut@reseau.eseo.fr>
parent 60825f2c
...@@ -18,6 +18,7 @@ import importlib ...@@ -18,6 +18,7 @@ import importlib
import os import os
import re import re
import shutil import shutil
import signal
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
...@@ -513,3 +514,46 @@ def custom_object_save(obj, folder, config=None): ...@@ -513,3 +514,46 @@ def custom_object_save(obj, folder, config=None):
result.append(dest_file) result.append(dest_file)
return result return result
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute the configuration file in that repo on your local machine. We "
"asked if it was okay but did not get an answer. Make sure you have read the code there to avoid malicious "
"use, then set the option `trust_remote_code=True` to remove this error."
)
# Seconds to wait for the interactive confirmation prompt before aborting.
# Tests set this to 0 to force the non-interactive error path.
TIME_OUT_REMOTE_CODE = 15


def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
    """Resolve the final value of `trust_remote_code` for a model load.

    Args:
        trust_remote_code (`bool` or `None`): the value passed by the user; `None` means undecided.
        model_name (`str`): repo id or path, used in the prompt and error messages.
        has_local_code (`bool`): whether a local (in-library) class can load this model.
        has_remote_code (`bool`): whether the config advertises custom code on the Hub.

    Returns:
        `bool`: the resolved `trust_remote_code`.

    Raises:
        ValueError: if remote code is required but not trusted, if the interactive
            prompt times out, or if the platform cannot run the prompt with a timeout.
    """
    if trust_remote_code is None:
        if has_local_code:
            # Prefer the safe local implementation when one exists.
            trust_remote_code = False
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            try:
                # `signal.SIGALRM`/`signal.alarm` are POSIX-only; on Windows (or in
                # environments without them) this raises and we fall back to an error
                # below instead of crashing with an AttributeError.
                signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                while trust_remote_code is None:
                    answer = input(
                        f"Loading {model_name} requires to execute some code in that repo, you can inspect the content of "
                        f"the repository at https://hf.co/{model_name}. You can dismiss this prompt by passing "
                        "`trust_remote_code=True`.\nDo you accept? [y/N] "
                    )
                    if answer.lower() in ["yes", "y", "1"]:
                        trust_remote_code = True
                    elif answer.lower() in ["no", "n", "0", ""]:
                        trust_remote_code = False
                signal.alarm(0)
            except Exception:
                # OS does not support SIGALRM (e.g. Windows) or the prompt timed out.
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    "Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
        elif has_remote_code:
            # For the CI which puts the timeout at 0
            _raise_timeout_error(None, None)

    if has_remote_code and not has_local_code and not trust_remote_code:
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    return trust_remote_code
...@@ -18,7 +18,7 @@ import importlib ...@@ -18,7 +18,7 @@ import importlib
from collections import OrderedDict from collections import OrderedDict
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import copy_func, logging, requires_backends from ...utils import copy_func, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
...@@ -404,19 +404,14 @@ class _BaseAutoModelClass: ...@@ -404,19 +404,14 @@ class _BaseAutoModelClass:
@classmethod @classmethod
def from_config(cls, config, **kwargs): def from_config(cls, config, **kwargs):
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", None)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
if not trust_remote_code: has_local_code = type(config) in cls._model_mapping.keys()
raise ValueError( trust_remote_code = resolve_trust_remote_code(
"Loading this model requires you to execute the modeling file in that repo " trust_remote_code, config._name_or_path, has_local_code, has_remote_code
"on your local machine. Make sure you have read the code there to avoid malicious use, then set " )
"the option `trust_remote_code=True` to remove this error."
) if has_remote_code and trust_remote_code:
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
class_ref = config.auto_map[cls.__name__] class_ref = config.auto_map[cls.__name__]
if "--" in class_ref: if "--" in class_ref:
repo_id, class_ref = class_ref.split("--") repo_id, class_ref = class_ref.split("--")
...@@ -437,7 +432,7 @@ class _BaseAutoModelClass: ...@@ -437,7 +432,7 @@ class _BaseAutoModelClass:
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
hub_kwargs_names = [ hub_kwargs_names = [
"cache_dir", "cache_dir",
...@@ -470,13 +465,12 @@ class _BaseAutoModelClass: ...@@ -470,13 +465,12 @@ class _BaseAutoModelClass:
if kwargs_orig.get("torch_dtype", None) == "auto": if kwargs_orig.get("torch_dtype", None) == "auto":
kwargs["torch_dtype"] = "auto" kwargs["torch_dtype"] = "auto"
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
if not trust_remote_code: has_local_code = type(config) in cls._model_mapping.keys()
raise ValueError( trust_remote_code = resolve_trust_remote_code(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo " trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
"on your local machine. Make sure you have read the code there to avoid malicious use, then set " )
"the option `trust_remote_code=True` to remove this error." if has_remote_code and trust_remote_code:
)
class_ref = config.auto_map[cls.__name__] class_ref = config.auto_map[cls.__name__]
model_class = get_class_from_dynamic_module( model_class = get_class_from_dynamic_module(
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
......
...@@ -20,7 +20,7 @@ from collections import OrderedDict ...@@ -20,7 +20,7 @@ from collections import OrderedDict
from typing import List, Union from typing import List, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, logging from ...utils import CONFIG_NAME, logging
...@@ -940,15 +940,15 @@ class AutoConfig: ...@@ -940,15 +940,15 @@ class AutoConfig:
```""" ```"""
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
kwargs["name_or_path"] = pretrained_model_name_or_path kwargs["name_or_path"] = pretrained_model_name_or_path
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", None)
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]: has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
if not trust_remote_code: has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
raise ValueError( trust_remote_code = resolve_trust_remote_code(
f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that" trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then" )
" set the option `trust_remote_code=True` to remove this error."
) if has_remote_code and trust_remote_code:
class_ref = config_dict["auto_map"]["AutoConfig"] class_ref = config_dict["auto_map"]["AutoConfig"]
config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None) _ = kwargs.pop("code_revision", None)
......
...@@ -21,7 +21,7 @@ from typing import Dict, Optional, Union ...@@ -21,7 +21,7 @@ from typing import Dict, Optional, Union
# Build the list of all feature extractors # Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging
from .auto_factory import _LazyAutoMapping from .auto_factory import _LazyAutoMapping
...@@ -307,7 +307,7 @@ class AutoFeatureExtractor: ...@@ -307,7 +307,7 @@ class AutoFeatureExtractor:
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/") >>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
```""" ```"""
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
...@@ -326,21 +326,21 @@ class AutoFeatureExtractor: ...@@ -326,21 +326,21 @@ class AutoFeatureExtractor:
feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"] feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]
if feature_extractor_class is not None: if feature_extractor_class is not None:
# If we have custom code for a feature extractor, we get the proper class. feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
if feature_extractor_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
"in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error."
)
feature_extractor_class = get_class_from_dynamic_module(
feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
else:
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
has_remote_code = feature_extractor_auto_map is not None
has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code and trust_remote_code:
feature_extractor_class = get_class_from_dynamic_module(
feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
return feature_extractor_class.from_dict(config_dict, **kwargs)
elif feature_extractor_class is not None:
return feature_extractor_class.from_dict(config_dict, **kwargs) return feature_extractor_class.from_dict(config_dict, **kwargs)
# Last try: we use the FEATURE_EXTRACTOR_MAPPING. # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
elif type(config) in FEATURE_EXTRACTOR_MAPPING: elif type(config) in FEATURE_EXTRACTOR_MAPPING:
......
...@@ -21,7 +21,7 @@ from typing import Dict, Optional, Union ...@@ -21,7 +21,7 @@ from typing import Dict, Optional, Union
# Build the list of all image processors # Build the list of all image processors
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...image_processing_utils import ImageProcessingMixin from ...image_processing_utils import ImageProcessingMixin
from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging
from .auto_factory import _LazyAutoMapping from .auto_factory import _LazyAutoMapping
...@@ -314,7 +314,7 @@ class AutoImageProcessor: ...@@ -314,7 +314,7 @@ class AutoImageProcessor:
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/") >>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
```""" ```"""
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
...@@ -351,21 +351,21 @@ class AutoImageProcessor: ...@@ -351,21 +351,21 @@ class AutoImageProcessor:
image_processor_auto_map = config.auto_map["AutoImageProcessor"] image_processor_auto_map = config.auto_map["AutoImageProcessor"]
if image_processor_class is not None: if image_processor_class is not None:
# If we have custom code for a image processor, we get the proper class. image_processor_class = image_processor_class_from_name(image_processor_class)
if image_processor_auto_map is not None:
if not trust_remote_code: has_remote_code = image_processor_auto_map is not None
raise ValueError( has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
f"Loading {pretrained_model_name_or_path} requires you to execute the image processor file " trust_remote_code = resolve_trust_remote_code(
"in that repo on your local machine. Make sure you have read the code there to avoid " trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
"malicious use, then set the option `trust_remote_code=True` to remove this error." )
)
image_processor_class = get_class_from_dynamic_module(
image_processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
else:
image_processor_class = image_processor_class_from_name(image_processor_class)
if has_remote_code and trust_remote_code:
image_processor_class = get_class_from_dynamic_module(
image_processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
return image_processor_class.from_dict(config_dict, **kwargs)
elif image_processor_class is not None:
return image_processor_class.from_dict(config_dict, **kwargs) return image_processor_class.from_dict(config_dict, **kwargs)
# Last try: we use the IMAGE_PROCESSOR_MAPPING. # Last try: we use the IMAGE_PROCESSOR_MAPPING.
elif type(config) in IMAGE_PROCESSOR_MAPPING: elif type(config) in IMAGE_PROCESSOR_MAPPING:
......
...@@ -20,7 +20,7 @@ from collections import OrderedDict ...@@ -20,7 +20,7 @@ from collections import OrderedDict
# Build the list of all feature extractors # Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin from ...image_processing_utils import ImageProcessingMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE from ...tokenization_utils import TOKENIZER_CONFIG_FILE
...@@ -194,7 +194,7 @@ class AutoProcessor: ...@@ -194,7 +194,7 @@ class AutoProcessor:
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/") >>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
```""" ```"""
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
processor_class = None processor_class = None
...@@ -248,28 +248,28 @@ class AutoProcessor: ...@@ -248,28 +248,28 @@ class AutoProcessor:
processor_auto_map = config.auto_map["AutoProcessor"] processor_auto_map = config.auto_map["AutoProcessor"]
if processor_class is not None: if processor_class is not None:
# If we have custom code for a feature extractor, we get the proper class. processor_class = processor_class_from_name(processor_class)
if processor_auto_map is not None:
if not trust_remote_code: has_remote_code = processor_auto_map is not None
raise ValueError( has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file " trust_remote_code = resolve_trust_remote_code(
"in that repo on your local machine. Make sure you have read the code there to avoid " trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
"malicious use, then set the option `trust_remote_code=True` to remove this error." )
)
processor_class = get_class_from_dynamic_module(
processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
else:
processor_class = processor_class_from_name(processor_class)
if has_remote_code and trust_remote_code:
processor_class = get_class_from_dynamic_module(
processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
return processor_class.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
elif processor_class is not None:
return processor_class.from_pretrained( return processor_class.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
) )
# Last try: we use the PROCESSOR_MAPPING. # Last try: we use the PROCESSOR_MAPPING.
if type(config) in PROCESSOR_MAPPING: elif type(config) in PROCESSOR_MAPPING:
return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs) return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
# At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a # At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a
......
...@@ -21,7 +21,7 @@ from collections import OrderedDict ...@@ -21,7 +21,7 @@ from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
...@@ -608,7 +608,7 @@ class AutoTokenizer: ...@@ -608,7 +608,7 @@ class AutoTokenizer:
use_fast = kwargs.pop("use_fast", True) use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None) tokenizer_type = kwargs.pop("tokenizer_type", None)
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", None)
# First, let's see whether the tokenizer_type is passed so that we can leverage it # First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None: if tokenizer_type is not None:
...@@ -662,31 +662,28 @@ class AutoTokenizer: ...@@ -662,31 +662,28 @@ class AutoTokenizer:
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map: if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
tokenizer_auto_map = config.auto_map["AutoTokenizer"] tokenizer_auto_map = config.auto_map["AutoTokenizer"]
# If we have the tokenizer class from the tokenizer config or the model config we're good! has_remote_code = tokenizer_auto_map is not None
if config_tokenizer_class is not None: has_local_code = config_tokenizer_class is not None or type(config) in TOKENIZER_MAPPING
tokenizer_class = None trust_remote_code = resolve_trust_remote_code(
if tokenizer_auto_map is not None: trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
if not trust_remote_code: )
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use,"
" then set the option `trust_remote_code=True` to remove this error."
)
if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1]
else:
class_ref = tokenizer_auto_map[0]
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
elif use_fast and not config_tokenizer_class.endswith("Fast"): if has_remote_code and trust_remote_code:
if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1]
else:
class_ref = tokenizer_auto_map[0]
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif config_tokenizer_class is not None:
tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast" tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None: if tokenizer_class is None:
tokenizer_class_candidate = config_tokenizer_class tokenizer_class_candidate = config_tokenizer_class
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None: if tokenizer_class is None:
raise ValueError( raise ValueError(
f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported." f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
......
...@@ -21,6 +21,7 @@ import tempfile ...@@ -21,6 +21,7 @@ import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
import transformers
import transformers.models.auto import transformers.models.auto
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.configuration_bert import BertConfig
...@@ -37,6 +38,9 @@ SAMPLE_ROBERTA_CONFIG = get_tests_dir("fixtures/dummy-config.json") ...@@ -37,6 +38,9 @@ SAMPLE_ROBERTA_CONFIG = get_tests_dir("fixtures/dummy-config.json")
class AutoConfigTest(unittest.TestCase): class AutoConfigTest(unittest.TestCase):
def setUp(self):
transformers.dynamic_module_utils.TIME_OUT_REMOTE_CODE = 0
def test_module_spec(self): def test_module_spec(self):
self.assertIsNotNone(transformers.models.auto.__spec__) self.assertIsNotNone(transformers.models.auto.__spec__)
self.assertIsNotNone(importlib.util.find_spec("transformers.models.auto")) self.assertIsNotNone(importlib.util.find_spec("transformers.models.auto"))
...@@ -108,6 +112,13 @@ class AutoConfigTest(unittest.TestCase): ...@@ -108,6 +112,13 @@ class AutoConfigTest(unittest.TestCase):
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo") _ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
def test_from_pretrained_dynamic_config(self): def test_from_pretrained_dynamic_config(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model")
# If remote code is disabled, we can't load this config.
with self.assertRaises(ValueError):
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=False)
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(config.__class__.__name__, "NewModelConfig") self.assertEqual(config.__class__.__name__, "NewModelConfig")
...@@ -116,3 +127,25 @@ class AutoConfigTest(unittest.TestCase): ...@@ -116,3 +127,25 @@ class AutoConfigTest(unittest.TestCase):
config.save_pretrained(tmp_dir) config.save_pretrained(tmp_dir)
reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True) reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig") self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig")
def test_from_pretrained_dynamic_config_conflict(self):
class NewModelConfigLocal(BertConfig):
model_type = "new-model"
try:
AutoConfig.register("new-model", NewModelConfigLocal)
# If remote code is not set, the default is to use local
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model")
self.assertEqual(config.__class__.__name__, "NewModelConfigLocal")
# If remote code is disabled, we load the local one.
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=False)
self.assertEqual(config.__class__.__name__, "NewModelConfigLocal")
# If remote is enabled, we load from the Hub
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(config.__class__.__name__, "NewModelConfig")
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
...@@ -19,6 +19,7 @@ import tempfile ...@@ -19,6 +19,7 @@ import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
import transformers
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING, FEATURE_EXTRACTOR_MAPPING,
...@@ -42,6 +43,9 @@ SAMPLE_CONFIG = get_tests_dir("fixtures/dummy-config.json") ...@@ -42,6 +43,9 @@ SAMPLE_CONFIG = get_tests_dir("fixtures/dummy-config.json")
class AutoFeatureExtractorTest(unittest.TestCase): class AutoFeatureExtractorTest(unittest.TestCase):
def setUp(self):
transformers.dynamic_module_utils.TIME_OUT_REMOTE_CODE = 0
def test_feature_extractor_from_model_shortcut(self): def test_feature_extractor_from_model_shortcut(self):
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
...@@ -96,6 +100,17 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -96,6 +100,17 @@ class AutoFeatureExtractorTest(unittest.TestCase):
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model") _ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
def test_from_pretrained_dynamic_feature_extractor(self): def test_from_pretrained_dynamic_feature_extractor(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor"
)
# If remote code is disabled, we can't load this config.
with self.assertRaises(ValueError):
feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=False
)
feature_extractor = AutoFeatureExtractor.from_pretrained( feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
) )
...@@ -127,3 +142,37 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -127,3 +142,37 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del CONFIG_MAPPING._extra_content["custom"] del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content: if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content:
del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig] del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_feature_extractor_conflict(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
is_local = True
try:
AutoConfig.register("custom", CustomConfig)
AutoFeatureExtractor.register(CustomConfig, NewFeatureExtractor)
# If remote code is not set, the default is to use local
feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor"
)
self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
self.assertTrue(feature_extractor.is_local)
# If remote code is disabled, we load the local one.
feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=False
)
self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
self.assertTrue(feature_extractor.is_local)
# If remote is enabled, we load from the Hub
feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
)
self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
self.assertTrue(not hasattr(feature_extractor, "is_local"))
finally:
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content:
del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig]
...@@ -19,6 +19,7 @@ import tempfile ...@@ -19,6 +19,7 @@ import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
import transformers
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
IMAGE_PROCESSOR_MAPPING, IMAGE_PROCESSOR_MAPPING,
...@@ -37,6 +38,9 @@ from test_module.custom_image_processing import CustomImageProcessor # noqa E40 ...@@ -37,6 +38,9 @@ from test_module.custom_image_processing import CustomImageProcessor # noqa E40
class AutoImageProcessorTest(unittest.TestCase): class AutoImageProcessorTest(unittest.TestCase):
def setUp(self):
transformers.dynamic_module_utils.TIME_OUT_REMOTE_CODE = 0
def test_image_processor_from_model_shortcut(self): def test_image_processor_from_model_shortcut(self):
config = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32") config = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
self.assertIsInstance(config, CLIPImageProcessor) self.assertIsInstance(config, CLIPImageProcessor)
...@@ -130,6 +134,15 @@ class AutoImageProcessorTest(unittest.TestCase): ...@@ -130,6 +134,15 @@ class AutoImageProcessorTest(unittest.TestCase):
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/config-no-model") _ = AutoImageProcessor.from_pretrained("hf-internal-testing/config-no-model")
def test_from_pretrained_dynamic_image_processor(self): def test_from_pretrained_dynamic_image_processor(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
image_processor = AutoImageProcessor.from_pretrained("hf-internal-testing/test_dynamic_image_processor")
# If remote code is disabled, we can't load this config.
with self.assertRaises(ValueError):
image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=False
)
image_processor = AutoImageProcessor.from_pretrained( image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True "hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True
) )
...@@ -171,3 +184,35 @@ class AutoImageProcessorTest(unittest.TestCase): ...@@ -171,3 +184,35 @@ class AutoImageProcessorTest(unittest.TestCase):
del CONFIG_MAPPING._extra_content["custom"] del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in IMAGE_PROCESSOR_MAPPING._extra_content: if CustomConfig in IMAGE_PROCESSOR_MAPPING._extra_content:
del IMAGE_PROCESSOR_MAPPING._extra_content[CustomConfig] del IMAGE_PROCESSOR_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_image_processor_conflict(self):
class NewImageProcessor(CLIPImageProcessor):
is_local = True
try:
AutoConfig.register("custom", CustomConfig)
AutoImageProcessor.register(CustomConfig, NewImageProcessor)
# If remote code is not set, the default is to use local
image_processor = AutoImageProcessor.from_pretrained("hf-internal-testing/test_dynamic_image_processor")
self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor")
self.assertTrue(image_processor.is_local)
# If remote code is disabled, we load the local one.
image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=False
)
self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor")
self.assertTrue(image_processor.is_local)
# If remote is enabled, we load from the Hub
image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True
)
self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor")
self.assertTrue(not hasattr(image_processor, "is_local"))
finally:
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in IMAGE_PROCESSOR_MAPPING._extra_content:
del IMAGE_PROCESSOR_MAPPING._extra_content[CustomConfig]
...@@ -22,6 +22,7 @@ from pathlib import Path ...@@ -22,6 +22,7 @@ from pathlib import Path
import pytest import pytest
import transformers
from transformers import BertConfig, GPT2Model, is_safetensors_available, is_torch_available from transformers import BertConfig, GPT2Model, is_safetensors_available, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import ( from transformers.testing_utils import (
...@@ -92,6 +93,9 @@ if is_torch_available(): ...@@ -92,6 +93,9 @@ if is_torch_available():
@require_torch @require_torch
class AutoModelTest(unittest.TestCase): class AutoModelTest(unittest.TestCase):
def setUp(self):
transformers.dynamic_module_utils.TIME_OUT_REMOTE_CODE = 0
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...@@ -312,6 +316,13 @@ class AutoModelTest(unittest.TestCase): ...@@ -312,6 +316,13 @@ class AutoModelTest(unittest.TestCase):
del MODEL_MAPPING._extra_content[CustomConfig] del MODEL_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_model_distant(self): def test_from_pretrained_dynamic_model_distant(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model")
# If remote code is disabled, we can't load this config.
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=False)
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel") self.assertEqual(model.__class__.__name__, "NewModel")
...@@ -416,6 +427,34 @@ class AutoModelTest(unittest.TestCase): ...@@ -416,6 +427,34 @@ class AutoModelTest(unittest.TestCase):
if CustomConfig in mapping._extra_content: if CustomConfig in mapping._extra_content:
del mapping._extra_content[CustomConfig] del mapping._extra_content[CustomConfig]
def test_from_pretrained_dynamic_model_conflict(self):
class NewModelConfigLocal(BertConfig):
model_type = "new-model"
class NewModel(BertModel):
config_class = NewModelConfigLocal
try:
AutoConfig.register("new-model", NewModelConfigLocal)
AutoModel.register(NewModelConfigLocal, NewModel)
# If remote code is not set, the default is to use local
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model")
self.assertEqual(model.config.__class__.__name__, "NewModelConfigLocal")
# If remote code is disabled, we load the local one.
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=False)
self.assertEqual(model.config.__class__.__name__, "NewModelConfigLocal")
# If remote is enabled, we load from the Hub
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.config.__class__.__name__, "NewModelConfig")
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if NewModelConfigLocal in MODEL_MAPPING._extra_content:
del MODEL_MAPPING._extra_content[NewModelConfigLocal]
def test_repo_not_found(self): def test_repo_not_found(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier" EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
......
...@@ -24,6 +24,7 @@ from shutil import copyfile ...@@ -24,6 +24,7 @@ from shutil import copyfile
from huggingface_hub import HfFolder, Repository, create_repo, delete_repo from huggingface_hub import HfFolder, Repository, create_repo, delete_repo
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
import transformers
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING, FEATURE_EXTRACTOR_MAPPING,
...@@ -33,6 +34,8 @@ from transformers import ( ...@@ -33,6 +34,8 @@ from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
AutoProcessor, AutoProcessor,
AutoTokenizer, AutoTokenizer,
BertTokenizer,
ProcessorMixin,
Wav2Vec2Config, Wav2Vec2Config,
Wav2Vec2FeatureExtractor, Wav2Vec2FeatureExtractor,
Wav2Vec2Processor, Wav2Vec2Processor,
...@@ -58,6 +61,9 @@ SAMPLE_PROCESSOR_CONFIG_DIR = get_tests_dir("fixtures") ...@@ -58,6 +61,9 @@ SAMPLE_PROCESSOR_CONFIG_DIR = get_tests_dir("fixtures")
class AutoFeatureExtractorTest(unittest.TestCase): class AutoFeatureExtractorTest(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"] vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
def setUp(self):
transformers.dynamic_module_utils.TIME_OUT_REMOTE_CODE = 0
def test_processor_from_model_shortcut(self): def test_processor_from_model_shortcut(self):
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(processor, Wav2Vec2Processor) self.assertIsInstance(processor, Wav2Vec2Processor)
...@@ -144,6 +150,15 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -144,6 +150,15 @@ class AutoFeatureExtractorTest(unittest.TestCase):
self.assertIsInstance(processor, Wav2Vec2Processor) self.assertIsInstance(processor, Wav2Vec2Processor)
def test_from_pretrained_dynamic_processor(self): def test_from_pretrained_dynamic_processor(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
processor = AutoProcessor.from_pretrained("hf-internal-testing/test_dynamic_processor")
# If remote code is disabled, we can't load this config.
with self.assertRaises(ValueError):
processor = AutoProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_processor", trust_remote_code=False
)
processor = AutoProcessor.from_pretrained("hf-internal-testing/test_dynamic_processor", trust_remote_code=True) processor = AutoProcessor.from_pretrained("hf-internal-testing/test_dynamic_processor", trust_remote_code=True)
self.assertTrue(processor.special_attribute_present) self.assertTrue(processor.special_attribute_present)
self.assertEqual(processor.__class__.__name__, "NewProcessor") self.assertEqual(processor.__class__.__name__, "NewProcessor")
...@@ -203,6 +218,58 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -203,6 +218,58 @@ class AutoFeatureExtractorTest(unittest.TestCase):
if CustomConfig in PROCESSOR_MAPPING._extra_content: if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig] del PROCESSOR_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_processor_conflict(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
special_attribute_present = False
class NewTokenizer(BertTokenizer):
special_attribute_present = False
class NewProcessor(ProcessorMixin):
feature_extractor_class = "AutoFeatureExtractor"
tokenizer_class = "AutoTokenizer"
special_attribute_present = False
try:
AutoConfig.register("custom", CustomConfig)
AutoFeatureExtractor.register(CustomConfig, NewFeatureExtractor)
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=NewTokenizer)
AutoProcessor.register(CustomConfig, NewProcessor)
# If remote code is not set, the default is to use local classes.
processor = AutoProcessor.from_pretrained("hf-internal-testing/test_dynamic_processor")
self.assertEqual(processor.__class__.__name__, "NewProcessor")
self.assertFalse(processor.special_attribute_present)
self.assertFalse(processor.feature_extractor.special_attribute_present)
self.assertFalse(processor.tokenizer.special_attribute_present)
# If remote code is disabled, we load the local ones.
processor = AutoProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_processor", trust_remote_code=False
)
self.assertEqual(processor.__class__.__name__, "NewProcessor")
self.assertFalse(processor.special_attribute_present)
self.assertFalse(processor.feature_extractor.special_attribute_present)
self.assertFalse(processor.tokenizer.special_attribute_present)
# If remote is enabled, we load from the Hub.
processor = AutoProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_processor", trust_remote_code=True
)
self.assertEqual(processor.__class__.__name__, "NewProcessor")
self.assertTrue(processor.special_attribute_present)
self.assertTrue(processor.feature_extractor.special_attribute_present)
self.assertTrue(processor.tokenizer.special_attribute_present)
finally:
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content:
del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig]
if CustomConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
def test_auto_processor_creates_tokenizer(self): def test_auto_processor_creates_tokenizer(self):
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert") processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(processor.__class__.__name__, "BertTokenizerFast") self.assertEqual(processor.__class__.__name__, "BertTokenizerFast")
......
...@@ -22,6 +22,7 @@ from pathlib import Path ...@@ -22,6 +22,7 @@ from pathlib import Path
import pytest import pytest
import transformers
from transformers import ( from transformers import (
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
...@@ -65,6 +66,9 @@ if is_tokenizers_available(): ...@@ -65,6 +66,9 @@ if is_tokenizers_available():
class AutoTokenizerTest(unittest.TestCase): class AutoTokenizerTest(unittest.TestCase):
def setUp(self):
transformers.dynamic_module_utils.TIME_OUT_REMOTE_CODE = 0
@slow @slow
def test_tokenizer_from_pretrained(self): def test_tokenizer_from_pretrained(self):
for model_name in (x for x in BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys() if "japanese" not in x): for model_name in (x for x in BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys() if "japanese" not in x):
...@@ -298,6 +302,15 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -298,6 +302,15 @@ class AutoTokenizerTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig] del TOKENIZER_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_tokenizer(self): def test_from_pretrained_dynamic_tokenizer(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer")
# If remote code is disabled, we can't load this config.
with self.assertRaises(ValueError):
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=False
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True)
self.assertTrue(tokenizer.special_attribute_present) self.assertTrue(tokenizer.special_attribute_present)
# Test tokenizer can be reloaded. # Test tokenizer can be reloaded.
...@@ -326,6 +339,57 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -326,6 +339,57 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer") self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer")
@require_tokenizers
def test_from_pretrained_dynamic_tokenizer_conflict(self):
class NewTokenizer(BertTokenizer):
special_attribute_present = False
class NewTokenizerFast(BertTokenizerFast):
slow_tokenizer_class = NewTokenizer
special_attribute_present = False
try:
AutoConfig.register("custom", CustomConfig)
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=NewTokenizer)
AutoTokenizer.register(CustomConfig, fast_tokenizer_class=NewTokenizerFast)
# If remote code is not set, the default is to use local
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer")
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
self.assertFalse(tokenizer.special_attribute_present)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", use_fast=False)
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
self.assertFalse(tokenizer.special_attribute_present)
# If remote code is disabled, we load the local one.
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=False
)
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
self.assertFalse(tokenizer.special_attribute_present)
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=False, use_fast=False
)
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
self.assertFalse(tokenizer.special_attribute_present)
# If remote is enabled, we load from the Hub
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True
)
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
self.assertTrue(tokenizer.special_attribute_present)
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, use_fast=False
)
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
self.assertTrue(tokenizer.special_attribute_present)
finally:
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_tokenizer_legacy_format(self): def test_from_pretrained_dynamic_tokenizer_legacy_format(self):
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer_legacy", trust_remote_code=True "hf-internal-testing/test_dynamic_tokenizer_legacy", trust_remote_code=True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment