Unverified Commit 91d7df58 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Copy code when using local trust remote code (#24785)

* Copy code when using local trust remote code

* Remote upgrade strategy

* Revert "Remote upgrade strategy"

This reverts commit 4f0392f5d747bcbbcf7211ef9f9b555a86778297.
parent f32303d5
......@@ -15,6 +15,7 @@
"""Factory function to build auto-model classes."""
import copy
import importlib
import os
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
......@@ -418,7 +419,10 @@ class _BaseAutoModelClass:
else:
repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
cls.register(config.__class__, model_class, exist_ok=True)
if os.path.isdir(config._name_or_path):
model_class.register_for_auto_class(cls.__name__)
else:
cls.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
......@@ -477,7 +481,10 @@ class _BaseAutoModelClass:
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
)
_ = hub_kwargs.pop("code_revision", None)
cls.register(config.__class__, model_class, exist_ok=True)
if os.path.isdir(pretrained_model_name_or_path):
model_class.register_for_auto_class(cls.__name__)
else:
cls.register(config.__class__, model_class, exist_ok=True)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
......
......@@ -14,6 +14,7 @@
# limitations under the License.
""" Auto Config class."""
import importlib
import os
import re
import warnings
from collections import OrderedDict
......@@ -984,6 +985,8 @@ class AutoConfig:
if has_remote_code and trust_remote_code:
class_ref = config_dict["auto_map"]["AutoConfig"]
config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
if os.path.isdir(pretrained_model_name_or_path):
config_class.register_for_auto_class()
_ = kwargs.pop("code_revision", None)
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict:
......
......@@ -340,6 +340,8 @@ class AutoFeatureExtractor:
feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_class.register_for_auto_class()
return feature_extractor_class.from_dict(config_dict, **kwargs)
elif feature_extractor_class is not None:
return feature_extractor_class.from_dict(config_dict, **kwargs)
......
......@@ -364,6 +364,8 @@ class AutoImageProcessor:
image_processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
image_processor_class.register_for_auto_class()
return image_processor_class.from_dict(config_dict, **kwargs)
elif image_processor_class is not None:
return image_processor_class.from_dict(config_dict, **kwargs)
......
......@@ -16,6 +16,7 @@
import importlib
import inspect
import json
import os
from collections import OrderedDict
# Build the list of all feature extractors
......@@ -262,6 +263,8 @@ class AutoProcessor:
processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
processor_class.register_for_auto_class()
return processor_class.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
......
......@@ -684,6 +684,8 @@ class AutoTokenizer:
class_ref = tokenizer_auto_map[0]
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
tokenizer_class.register_for_auto_class()
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif config_tokenizer_class is not None:
tokenizer_class = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment