Unverified Commit 5f9b825c authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Use code on the Hub from another repo (#22814)

* initial work

* Add other classes

* Refactor code

* Move warning and fix dynamic pipeline

* Issue warning when necessary

* Add test

* Do not skip auto tests

* Fix failing tests

* Refactor and address review comments

* Address review comments
parent aec10d16
...@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save ...@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save
from .utils import ( from .utils import (
CONFIG_NAME, CONFIG_NAME,
PushToHubMixin, PushToHubMixin,
add_model_info_to_auto_map,
cached_file, cached_file,
copy_func, copy_func,
download_url, download_url,
...@@ -667,6 +668,10 @@ class PretrainedConfig(PushToHubMixin): ...@@ -667,6 +668,10 @@ class PretrainedConfig(PushToHubMixin):
else: else:
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}") logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
if "auto_map" in config_dict and not is_local:
config_dict["auto_map"] = add_model_info_to_auto_map(
config_dict["auto_map"], pretrained_model_name_or_path
)
return config_dict, kwargs return config_dict, kwargs
@classmethod @classmethod
......
...@@ -29,6 +29,7 @@ from .utils import ( ...@@ -29,6 +29,7 @@ from .utils import (
extract_commit_hash, extract_commit_hash,
is_offline_mode, is_offline_mode,
logging, logging,
try_to_load_from_cache,
) )
...@@ -222,11 +223,16 @@ def get_cached_module_file( ...@@ -222,11 +223,16 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file. # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
submodule = pretrained_model_name_or_path.split(os.path.sep)[-1] submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
else: else:
submodule = pretrained_model_name_or_path.replace("/", os.path.sep) submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
cached_module = try_to_load_from_cache(
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
)
new_files = []
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_module_file = cached_file( resolved_module_file = cached_file(
...@@ -241,6 +247,8 @@ def get_cached_module_file( ...@@ -241,6 +247,8 @@ def get_cached_module_file(
revision=revision, revision=revision,
_commit_hash=_commit_hash, _commit_hash=_commit_hash,
) )
if not is_local and cached_module != resolved_module_file:
new_files.append(module_file)
except EnvironmentError: except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
...@@ -284,7 +292,7 @@ def get_cached_module_file( ...@@ -284,7 +292,7 @@ def get_cached_module_file(
importlib.invalidate_caches() importlib.invalidate_caches()
# Make sure we also have every file with relative # Make sure we also have every file with relative
for module_needed in modules_needed: for module_needed in modules_needed:
if not (submodule_path / module_needed).exists(): if not (submodule_path / f"{module_needed}.py").exists():
get_cached_module_file( get_cached_module_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
f"{module_needed}.py", f"{module_needed}.py",
...@@ -295,14 +303,24 @@ def get_cached_module_file( ...@@ -295,14 +303,24 @@ def get_cached_module_file(
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
_commit_hash=commit_hash,
)
new_files.append(f"{module_needed}.py")
if len(new_files) > 0:
new_files = "\n".join([f"- {f}" for f in new_files])
logger.warning(
f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
"\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
"versions of the code file, you can pin a revision."
) )
return os.path.join(full_submodule, module_file) return os.path.join(full_submodule, module_file)
def get_class_from_dynamic_module( def get_class_from_dynamic_module(
class_reference: str,
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
class_name: str,
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
resume_download: bool = False, resume_download: bool = False,
...@@ -323,6 +341,8 @@ def get_class_from_dynamic_module( ...@@ -323,6 +341,8 @@ def get_class_from_dynamic_module(
</Tip> </Tip>
Args: Args:
class_reference (`str`):
The full name of the class to load, including its module and optionally its repo.
pretrained_model_name_or_path (`str` or `os.PathLike`): pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either: This can be either:
...@@ -332,6 +352,7 @@ def get_class_from_dynamic_module( ...@@ -332,6 +352,7 @@ def get_class_from_dynamic_module(
- a path to a *directory* containing a configuration file saved using the - a path to a *directory* containing a configuration file saved using the
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
This is used when `class_reference` does not specify another repo.
module_file (`str`): module_file (`str`):
The name of the module file containing the class to look for. The name of the module file containing the class to look for.
class_name (`str`): class_name (`str`):
...@@ -371,12 +392,25 @@ def get_class_from_dynamic_module( ...@@ -371,12 +392,25 @@ def get_class_from_dynamic_module(
```python ```python
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
# module. # module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
# Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
# module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
```""" ```"""
# Catch the name of the repo if it's specified in `class_reference`
if "--" in class_reference:
repo_id, class_reference = class_reference.split("--")
# Invalidate revision since it's not relevant for this repo
revision = "main"
else:
repo_id = pretrained_model_name_or_path
module_file, class_name = class_reference.split(".")
# And lastly we get the class inside our newly created module # And lastly we get the class inside our newly created module
final_module = get_cached_module_file( final_module = get_cached_module_file(
pretrained_model_name_or_path, repo_id,
module_file, module_file + ".py",
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
......
...@@ -29,6 +29,7 @@ from .utils import ( ...@@ -29,6 +29,7 @@ from .utils import (
FEATURE_EXTRACTOR_NAME, FEATURE_EXTRACTOR_NAME,
PushToHubMixin, PushToHubMixin,
TensorType, TensorType,
add_model_info_to_auto_map,
cached_file, cached_file,
copy_func, copy_func,
download_url, download_url,
...@@ -469,6 +470,11 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -469,6 +470,11 @@ class FeatureExtractionMixin(PushToHubMixin):
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
) )
if "auto_map" in feature_extractor_dict and not is_local:
feature_extractor_dict["auto_map"] = add_model_info_to_auto_map(
feature_extractor_dict["auto_map"], pretrained_model_name_or_path
)
return feature_extractor_dict, kwargs return feature_extractor_dict, kwargs
@classmethod @classmethod
......
...@@ -25,6 +25,7 @@ from .feature_extraction_utils import BatchFeature as BaseBatchFeature ...@@ -25,6 +25,7 @@ from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .utils import ( from .utils import (
IMAGE_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME,
PushToHubMixin, PushToHubMixin,
add_model_info_to_auto_map,
cached_file, cached_file,
copy_func, copy_func,
download_url, download_url,
...@@ -309,6 +310,11 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -309,6 +310,11 @@ class ImageProcessingMixin(PushToHubMixin):
f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
) )
if "auto_map" in image_processor_dict and not is_local:
image_processor_dict["auto_map"] = add_model_info_to_auto_map(
image_processor_dict["auto_map"], pretrained_model_name_or_path
)
return image_processor_dict, kwargs return image_processor_dict, kwargs
@classmethod @classmethod
......
...@@ -403,8 +403,12 @@ class _BaseAutoModelClass: ...@@ -403,8 +403,12 @@ class _BaseAutoModelClass:
"no malicious code has been contributed in a newer revision." "no malicious code has been contributed in a newer revision."
) )
class_ref = config.auto_map[cls.__name__] class_ref = config.auto_map[cls.__name__]
if "--" in class_ref:
repo_id, class_ref = class_ref.split("--")
else:
repo_id = config.name_or_path
module_file, class_name = class_ref.split(".") module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs) model_class = get_class_from_dynamic_module(repo_id, module_file + ".py", class_name, **kwargs)
return model_class._from_config(config, **kwargs) return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys(): elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping) model_class = _get_model_class(config, cls._model_mapping)
...@@ -452,17 +456,10 @@ class _BaseAutoModelClass: ...@@ -452,17 +456,10 @@ class _BaseAutoModelClass:
"on your local machine. Make sure you have read the code there to avoid malicious use, then set " "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error." "the option `trust_remote_code=True` to remove this error."
) )
if hub_kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
class_ref = config.auto_map[cls.__name__] class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module( model_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
) )
model_class.register_for_auto_class(cls.__name__)
return model_class.from_pretrained( return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
) )
......
...@@ -921,17 +921,8 @@ class AutoConfig: ...@@ -921,17 +921,8 @@ class AutoConfig:
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then" " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error." " set the option `trust_remote_code=True` to remove this error."
) )
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
"ensure no malicious code has been contributed in a newer revision."
)
class_ref = config_dict["auto_map"]["AutoConfig"] class_ref = config_dict["auto_map"]["AutoConfig"]
module_file, class_name = class_ref.split(".") config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
config_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
config_class.register_for_auto_class()
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict: elif "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]] config_class = CONFIG_MAPPING[config_dict["model_type"]]
......
...@@ -333,17 +333,9 @@ class AutoFeatureExtractor: ...@@ -333,17 +333,9 @@ class AutoFeatureExtractor:
"in that repo on your local machine. Make sure you have read the code there to avoid " "in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error." "malicious use, then set the option `trust_remote_code=True` to remove this error."
) )
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
"code to ensure no malicious code has been contributed in a newer revision."
)
module_file, class_name = feature_extractor_auto_map.split(".")
feature_extractor_class = get_class_from_dynamic_module( feature_extractor_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
) )
feature_extractor_class.register_for_auto_class()
else: else:
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class) feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
......
...@@ -355,17 +355,9 @@ class AutoImageProcessor: ...@@ -355,17 +355,9 @@ class AutoImageProcessor:
"in that repo on your local machine. Make sure you have read the code there to avoid " "in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error." "malicious use, then set the option `trust_remote_code=True` to remove this error."
) )
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a image processor with custom "
"code to ensure no malicious code has been contributed in a newer revision."
)
module_file, class_name = image_processor_auto_map.split(".")
image_processor_class = get_class_from_dynamic_module( image_processor_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs image_processor_auto_map, pretrained_model_name_or_path, **kwargs
) )
image_processor_class.register_for_auto_class()
else: else:
image_processor_class = image_processor_class_from_name(image_processor_class) image_processor_class = image_processor_class_from_name(image_processor_class)
......
...@@ -254,17 +254,10 @@ class AutoProcessor: ...@@ -254,17 +254,10 @@ class AutoProcessor:
"in that repo on your local machine. Make sure you have read the code there to avoid " "in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error." "malicious use, then set the option `trust_remote_code=True` to remove this error."
) )
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
"code to ensure no malicious code has been contributed in a newer revision."
)
module_file, class_name = processor_auto_map.split(".")
processor_class = get_class_from_dynamic_module( processor_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs processor_auto_map, pretrained_model_name_or_path, **kwargs
) )
processor_class.register_for_auto_class()
else: else:
processor_class = processor_class_from_name(processor_class) processor_class = processor_class_from_name(processor_class)
......
...@@ -671,22 +671,12 @@ class AutoTokenizer: ...@@ -671,22 +671,12 @@ class AutoTokenizer:
" repo on your local machine. Make sure you have read the code there to avoid malicious use," " repo on your local machine. Make sure you have read the code there to avoid malicious use,"
" then set the option `trust_remote_code=True` to remove this error." " then set the option `trust_remote_code=True` to remove this error."
) )
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure"
" no malicious code has been contributed in a newer revision."
)
if use_fast and tokenizer_auto_map[1] is not None: if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1] class_ref = tokenizer_auto_map[1]
else: else:
class_ref = tokenizer_auto_map[0] class_ref = tokenizer_auto_map[0]
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
module_file, class_name = class_ref.split(".")
tokenizer_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
tokenizer_class.register_for_auto_class()
elif use_fast and not config_tokenizer_class.endswith("Fast"): elif use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast" tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
......
...@@ -727,9 +727,8 @@ def pipeline( ...@@ -727,9 +727,8 @@ def pipeline(
" set the option `trust_remote_code=True` to remove this error." " set the option `trust_remote_code=True` to remove this error."
) )
class_ref = targeted_task["impl"] class_ref = targeted_task["impl"]
module_file, class_name = class_ref.split(".")
pipeline_class = get_class_from_dynamic_module( pipeline_class = get_class_from_dynamic_module(
model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token class_ref, model, revision=revision, use_auth_token=use_auth_token
) )
else: else:
normalized_task, targeted_task, task_options = check_task(task) normalized_task, targeted_task, task_options = check_task(task)
......
...@@ -40,6 +40,7 @@ from .utils import ( ...@@ -40,6 +40,7 @@ from .utils import (
PushToHubMixin, PushToHubMixin,
TensorType, TensorType,
add_end_docstrings, add_end_docstrings,
add_model_info_to_auto_map,
cached_file, cached_file,
copy_func, copy_func,
download_url, download_url,
...@@ -1817,6 +1818,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1817,6 +1818,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
cache_dir=cache_dir, cache_dir=cache_dir,
local_files_only=local_files_only, local_files_only=local_files_only,
_commit_hash=commit_hash, _commit_hash=commit_hash,
_is_local=is_local,
**kwargs, **kwargs,
) )
...@@ -1831,6 +1833,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1831,6 +1833,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
cache_dir=None, cache_dir=None,
local_files_only=False, local_files_only=False,
_commit_hash=None, _commit_hash=None,
_is_local=False,
**kwargs, **kwargs,
): ):
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
...@@ -1861,7 +1864,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1861,7 +1864,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers. # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
config_tokenizer_class = init_kwargs.get("tokenizer_class") config_tokenizer_class = init_kwargs.get("tokenizer_class")
init_kwargs.pop("tokenizer_class", None) init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("auto_map", None)
saved_init_inputs = init_kwargs.pop("init_inputs", ()) saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs: if not init_inputs:
init_inputs = saved_init_inputs init_inputs = saved_init_inputs
...@@ -1869,6 +1871,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1869,6 +1871,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
config_tokenizer_class = None config_tokenizer_class = None
init_kwargs = init_configuration init_kwargs = init_configuration
if "auto_map" in init_kwargs and not _is_local:
# For backward compatibility with the old format.
if isinstance(init_kwargs["auto_map"], (tuple, list)):
init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]}
init_kwargs["auto_map"] = add_model_info_to_auto_map(
init_kwargs["auto_map"], pretrained_model_name_or_path
)
if config_tokenizer_class is None: if config_tokenizer_class is None:
from .models.auto.configuration_auto import AutoConfig # tests_ignore from .models.auto.configuration_auto import AutoConfig # tests_ignore
......
...@@ -33,6 +33,7 @@ from .generic import ( ...@@ -33,6 +33,7 @@ from .generic import (
ModelOutput, ModelOutput,
PaddingStrategy, PaddingStrategy,
TensorType, TensorType,
add_model_info_to_auto_map,
cached_property, cached_property,
can_return_loss, can_return_loss,
expand_dims, expand_dims,
...@@ -83,6 +84,7 @@ from .hub import ( ...@@ -83,6 +84,7 @@ from .hub import (
is_remote_url, is_remote_url,
move_cache, move_cache,
send_example_telemetry, send_example_telemetry,
try_to_load_from_cache,
) )
from .import_utils import ( from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_AND_AUTO_VALUES,
......
...@@ -535,3 +535,16 @@ def tensor_size(array): ...@@ -535,3 +535,16 @@ def tensor_size(array):
return array.size return array.size
else: else:
raise ValueError(f"Type not supported for expand_dims: {type(array)}.") raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
def add_model_info_to_auto_map(auto_map, repo_id):
    """
    Prefix every class reference in an `auto_map` with the repo it comes from.

    Each value `"module.Class"` becomes `"repo_id--module.Class"` so that the dynamic-module
    loader knows which repo to download the code from. Values that already carry a repo
    prefix (contain `"--"`) are left untouched.

    Args:
        auto_map (`dict`):
            Mapping of auto-class names to class references. Values are either a single
            reference string or a tuple/list of references (tokenizer entries may contain
            `None` for a missing fast-tokenizer slot).
        repo_id (`str`):
            The model repo id to prepend.

    Returns:
        `dict`: The same `auto_map` object, updated in place. Tuple values are converted
        to lists.
    """
    for key, value in auto_map.items():
        if isinstance(value, (tuple, list)):
            # Tokenizer entries can be (slow, fast) with a None slot; skip None values,
            # since `"--" in None` would raise a TypeError.
            auto_map[key] = [f"{repo_id}--{v}" if (v is not None and "--" not in v) else v for v in value]
        elif value is not None and "--" not in value:
            auto_map[key] = f"{repo_id}--{value}"
    return auto_map
...@@ -298,6 +298,34 @@ class AutoModelTest(unittest.TestCase): ...@@ -298,6 +298,34 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()): for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
def test_from_pretrained_dynamic_model_distant_with_ref(self):
    """Dynamic models whose auto_map points at code hosted in a *different* repo load correctly."""
    # The first repo references modeling code living in another repo; the second one
    # additionally uses a relative import to a util file, checking that every needed
    # module file is downloaded from the referenced repo.
    for repo_id in (
        "hf-internal-testing/ref_to_test_dynamic_model",
        "hf-internal-testing/ref_to_test_dynamic_model_with_util",
    ):
        model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        self.assertEqual(model.__class__.__name__, "NewModel")

        # The model should survive a save_pretrained/from_pretrained round trip.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
            self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
            for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))
def test_new_model_registration(self): def test_new_model_registration(self):
AutoConfig.register("custom", CustomConfig) AutoConfig.register("custom", CustomConfig)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment