Unverified Commit 80de641c authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Allow Automodel to support custom model code (#12353)

* update

* update
parent 76810eca
...@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args ...@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..utils import logging from ..utils import logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -114,6 +115,8 @@ class AutoModel(ConfigMixin): ...@@ -114,6 +115,8 @@ class AutoModel(ConfigMixin):
disable_mmap ('bool', *optional*, defaults to 'False'): disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
trust_remote_cocde (`bool`, *optional*, defaults to `False`):
Whether to trust remote code
<Tip> <Tip>
...@@ -140,22 +143,22 @@ class AutoModel(ConfigMixin): ...@@ -140,22 +143,22 @@ class AutoModel(ConfigMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
load_config_kwargs = {
"cache_dir": cache_dir, hub_kwargs_names = [
"force_download": force_download, "cache_dir",
"proxies": proxies, "force_download",
"token": token, "local_files_only",
"local_files_only": local_files_only, "proxies",
"revision": revision, "resume_download",
} "revision",
"token",
]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
library = None library = None
orig_class_name = None orig_class_name = None
...@@ -189,6 +192,26 @@ class AutoModel(ConfigMixin): ...@@ -189,6 +192,26 @@ class AutoModel(ConfigMixin):
else: else:
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.") raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if has_remote_code and trust_remote_code:
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
model_cls = get_class_from_dynamic_module(
pretrained_model_or_path,
subfolder=subfolder,
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
else:
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
model_cls, _ = get_class_obj_and_candidates( model_cls, _ = get_class_obj_and_candidates(
......
...@@ -247,6 +247,7 @@ def find_pipeline_class(loaded_module): ...@@ -247,6 +247,7 @@ def find_pipeline_class(loaded_module):
def get_cached_module_file( def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str, module_file: str,
subfolder: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[Dict[str, str]] = None,
...@@ -353,6 +354,7 @@ def get_cached_module_file( ...@@ -353,6 +354,7 @@ def get_cached_module_file(
resolved_module_file = hf_hub_download( resolved_module_file = hf_hub_download(
pretrained_model_name_or_path, pretrained_model_name_or_path,
module_file, module_file,
subfolder=subfolder,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -410,6 +412,7 @@ def get_cached_module_file( ...@@ -410,6 +412,7 @@ def get_cached_module_file(
get_cached_module_file( get_cached_module_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
f"{module_needed}.py", f"{module_needed}.py",
subfolder=subfolder,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -424,6 +427,7 @@ def get_cached_module_file( ...@@ -424,6 +427,7 @@ def get_cached_module_file(
def get_class_from_dynamic_module( def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str, module_file: str,
subfolder: Optional[str] = None,
class_name: Optional[str] = None, class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
...@@ -497,6 +501,7 @@ def get_class_from_dynamic_module( ...@@ -497,6 +501,7 @@ def get_class_from_dynamic_module(
final_module = get_cached_module_file( final_module = get_cached_module_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
module_file, module_file,
subfolder=subfolder,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment