Unverified Commit b2da59b1 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[Modular] Provide option to disable custom code loading globally via env variable (#12177)

* update

* update

* update

* update
parent 7aa6af11
...@@ -299,7 +299,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -299,7 +299,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def from_pretrained( def from_pretrained(
cls, cls,
pretrained_model_name_or_path: str, pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None, trust_remote_code: bool = False,
**kwargs, **kwargs,
): ):
hub_kwargs_names = [ hub_kwargs_names = [
......
...@@ -45,6 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") ...@@ -45,6 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with # Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
......
...@@ -20,7 +20,6 @@ import json ...@@ -20,7 +20,6 @@ import json
import os import os
import re import re
import shutil import shutil
import signal
import sys import sys
import threading import threading
from pathlib import Path from pathlib import Path
...@@ -34,6 +33,7 @@ from packaging import version ...@@ -34,6 +33,7 @@ from packaging import version
from .. import __version__ from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -159,52 +159,25 @@ def check_imports(filename): ...@@ -159,52 +159,25 @@ def check_imports(filename):
return get_relative_imports(filename) return get_relative_imports(filename)
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute custom code contained in the model repository on your local "
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
)
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
if trust_remote_code is None: trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
if has_remote_code and TIME_OUT_REMOTE_CODE > 0: if DIFFUSERS_DISABLE_REMOTE_CODE:
prev_sig_handler = None logger.warning(
try: "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) )
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not trust_remote_code: if has_remote_code and not trust_remote_code:
raise ValueError( error_msg = f"The repository for {model_name} contains custom code. "
f"Loading {model_name} requires you to execute the configuration file in that" error_msg += (
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then" "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
" set the option `trust_remote_code=True` to remove this error." if DIFFUSERS_DISABLE_REMOTE_CODE
else "Pass `trust_remote_code=True` to allow loading remote code modules."
)
raise ValueError(error_msg)
elif has_remote_code and trust_remote_code:
logger.warning(
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
) )
return trust_remote_code return trust_remote_code
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment