"vscode:/vscode.git/clone" did not exist on "402c8b1e22b1cb161fdd9ab7fb7d9e92754ff2e6"
Unverified Commit c7f32e08 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Avoid ignored trust_remote_code warnings (#36290)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent b3546865
......@@ -24,7 +24,10 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.repo_utils import is_mistral_model_repo
from vllm.transformers_utils.utils import parse_safetensors_file_metadata
from vllm.transformers_utils.utils import (
parse_safetensors_file_metadata,
without_trust_remote_code,
)
from .config_parser_base import ConfigParserBase
from .gguf_utils import (
......@@ -140,11 +143,12 @@ class HFConfigParser(ConfigParserBase):
**kwargs,
) -> tuple[dict, PretrainedConfig]:
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
trust_remote_code |= kwargs.get("trust_remote_code", False)
kwargs = without_trust_remote_code(kwargs)
config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
code_revision=code_revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
# Use custom model class if it's in our registry
......@@ -225,7 +229,7 @@ class MistralConfigParser(ConfigParserBase):
model,
revision=revision,
code_revision=code_revision,
**kwargs,
**without_trust_remote_code(kwargs),
)
except OSError: # Not found
hf_config_dict = {}
......@@ -521,8 +525,7 @@ def maybe_override_with_speculators(
config_dict, _ = PretrainedConfig.get_config_dict(
model if gguf_model_repo is None else gguf_model_repo,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
**without_trust_remote_code(kwargs),
)
speculators_config = config_dict.get("speculators_config")
......
......@@ -5,6 +5,8 @@ import os
from transformers import AutoConfig, DeepseekV2Config, PretrainedConfig
from vllm.transformers_utils.utils import without_trust_remote_code
class EAGLEConfig(PretrainedConfig):
model_type = "eagle"
......@@ -79,7 +81,7 @@ class EAGLEConfig(PretrainedConfig):
**kwargs,
) -> "EAGLEConfig":
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
pretrained_model_name_or_path, **without_trust_remote_code(kwargs)
)
return cls.from_dict(config_dict, **kwargs)
......
......@@ -7,6 +7,8 @@ import os
from transformers import PretrainedConfig
from vllm.transformers_utils.utils import without_trust_remote_code
class ExtractHiddenStatesConfig(PretrainedConfig):
model_type = "extract_hidden_states"
......@@ -42,7 +44,7 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
**kwargs,
) -> "ExtractHiddenStatesConfig":
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
pretrained_model_name_or_path, **without_trust_remote_code(kwargs)
)
return cls.from_dict(config_dict, **kwargs)
......
......@@ -5,6 +5,8 @@ import os
from transformers import PretrainedConfig
from vllm.transformers_utils.utils import without_trust_remote_code
class MedusaConfig(PretrainedConfig):
model_type = "medusa"
......@@ -42,7 +44,7 @@ class MedusaConfig(PretrainedConfig):
**kwargs,
) -> "MedusaConfig":
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
pretrained_model_name_or_path, **without_trust_remote_code(kwargs)
)
for k in list(config_dict.keys()):
if "num" in k:
......
......@@ -11,6 +11,8 @@ from vllm.transformers_utils.configs.speculators.algos import (
__all__ = ["SpeculatorsConfig"]
from vllm.transformers_utils.utils import without_trust_remote_code
class SpeculatorsConfig(PretrainedConfig):
model_type = "speculators"
......@@ -22,7 +24,9 @@ class SpeculatorsConfig(PretrainedConfig):
**kwargs,
) -> "SpeculatorsConfig":
"""Load speculators Eagle config and convert to vLLM format."""
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
config_dict, _ = cls.get_config_dict(
pretrained_model_name_or_path, **without_trust_remote_code(kwargs)
)
vllm_config = cls.extract_transformers_pre_trained_config(config_dict)
return cls(**vllm_config)
......
......@@ -27,6 +27,13 @@ def is_cloud_storage(model_or_path: str) -> bool:
return is_s3(model_or_path) or is_gcs(model_or_path)
def without_trust_remote_code(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Return kwargs without trust_remote_code without modifying original dict."""
if "trust_remote_code" not in kwargs:
return kwargs
return {k: v for k, v in kwargs.items() if k != "trust_remote_code"}
def modelscope_list_repo_files(
repo_id: str,
revision: str | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment