Unverified Commit de895f16 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[misc] improve model support check in another process (#9208)

parent cf25b93b
...@@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2 ...@@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2
myst-parser==2.0.0 myst-parser==2.0.0
sphinx-argparse==0.4.0 sphinx-argparse==0.4.0
msgspec msgspec
cloudpickle
# packages to install to build the documentation # packages to install to build the documentation
pydantic >= 2.8 pydantic >= 2.8
......
import importlib import importlib
import string import pickle
import subprocess import subprocess
import sys import sys
import uuid import tempfile
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import cloudpickle
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -282,36 +283,28 @@ class ModelRegistry: ...@@ -282,36 +283,28 @@ class ModelRegistry:
raise raise
valid_name_characters = string.ascii_letters + string.digits + "._" with tempfile.NamedTemporaryFile() as output_file:
if any(s not in valid_name_characters for s in mod_name): # `cloudpickle` allows pickling lambda functions directly
raise ValueError(f"Unsafe module name detected for {model_arch}") input_bytes = cloudpickle.dumps(
if any(s not in valid_name_characters for s in cls_name): (mod_name, cls_name, func, output_file.name))
raise ValueError(f"Unsafe class name detected for {model_arch}") # cannot use `sys.executable __file__` here because the script
if any(s not in valid_name_characters for s in func.__module__): # contains relative imports
raise ValueError(f"Unsafe module name detected for {func}") returned = subprocess.run(
if any(s not in valid_name_characters for s in func.__name__): [sys.executable, "-m", "vllm.model_executor.models.registry"],
raise ValueError(f"Unsafe class name detected for {func}") input=input_bytes,
err_id = uuid.uuid4()
stmts = ";".join([
f"from {mod_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])
result = subprocess.run([sys.executable, "-c", stmts],
capture_output=True) capture_output=True)
if result.returncode != 0: # check if the subprocess is successful
err_lines = [line.decode() for line in result.stderr.splitlines()] try:
if err_lines and err_lines[-1] != f"AssertionError: {err_id}": returned.check_returncode()
err_str = "\n".join(err_lines) except Exception as e:
raise RuntimeError( # wrap raised exception to provide more information
"An unexpected error occurred while importing the model in " raise RuntimeError(f"Error happened when testing "
f"another process. Error log:\n{err_str}") f"model support for{mod_name}.{cls_name}:\n"
f"{returned.stderr.decode()}") from e
return result.returncode == 0 with open(output_file.name, "rb") as f:
result = pickle.load(f)
return result
@staticmethod @staticmethod
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool: def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
...@@ -364,3 +357,13 @@ class ModelRegistry: ...@@ -364,3 +357,13 @@ class ModelRegistry:
default=False) default=False)
return any(is_pp(arch) for arch in architectures) return any(is_pp(arch) for arch in architectures)
if __name__ == "__main__":
(mod_name, cls_name, func,
output_file) = pickle.loads(sys.stdin.buffer.read())
mod = importlib.import_module(mod_name)
klass = getattr(mod, cls_name)
result = func(klass)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment