Unverified Commit 78d13ea9 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Handle `trust_remote_code` for transformers backend (#32194)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent a307ac07
...@@ -887,6 +887,7 @@ class _ModelRegistry: ...@@ -887,6 +887,7 @@ class _ModelRegistry:
module, module,
model_config.model, model_config.model,
revision=model_config.revision, revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
warn_on_fail=False, warn_on_fail=False,
) )
...@@ -899,6 +900,7 @@ class _ModelRegistry: ...@@ -899,6 +900,7 @@ class _ModelRegistry:
module, module,
model_config.model, model_config.model,
revision=model_config.revision, revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
warn_on_fail=True, warn_on_fail=True,
) )
if model_module is not None: if model_module is not None:
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.dynamic_module_utils import (
get_class_from_dynamic_module,
resolve_trust_remote_code,
)
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -13,6 +16,7 @@ logger = init_logger(__name__) ...@@ -13,6 +16,7 @@ logger = init_logger(__name__)
def try_get_class_from_dynamic_module( def try_get_class_from_dynamic_module(
class_reference: str, class_reference: str,
pretrained_model_name_or_path: str, pretrained_model_name_or_path: str,
trust_remote_code: bool,
cache_dir: str | os.PathLike | None = None, cache_dir: str | os.PathLike | None = None,
force_download: bool = False, force_download: bool = False,
resume_download: bool | None = None, resume_download: bool | None = None,
...@@ -30,6 +34,13 @@ def try_get_class_from_dynamic_module( ...@@ -30,6 +34,13 @@ def try_get_class_from_dynamic_module(
but ignoring any errors. but ignoring any errors.
""" """
try: try:
resolve_trust_remote_code(
trust_remote_code,
pretrained_model_name_or_path,
has_local_code=False,
has_remote_code=True,
)
return get_class_from_dynamic_module( return get_class_from_dynamic_module(
class_reference, class_reference,
pretrained_model_name_or_path, pretrained_model_name_or_path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment