Unverified commit 2dda3e35, authored by rongfu.leng and committed by GitHub
Browse files

[Bugfix] Cache the model when fetching it from object storage (#24764)


Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
parent d83f3f7c
...@@ -64,6 +64,7 @@ if TYPE_CHECKING: ...@@ -64,6 +64,7 @@ if TYPE_CHECKING:
VLLM_XLA_USE_SPMD: bool = False VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_ASSETS_CACHE_MODEL_CLEAN: bool = False
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
...@@ -699,6 +700,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -699,6 +700,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.path.join(get_default_cache_root(), "vllm", "assets"), os.path.join(get_default_cache_root(), "vllm", "assets"),
)), )),
# If this env var is set to a non-zero integer, the model files cached under
# $VLLM_ASSETS_CACHE/model_streamer/<first 8 hex chars of sha256(model URL)> will be cleaned up
"VLLM_ASSETS_CACHE_MODEL_CLEAN":
lambda: bool(int(os.getenv("VLLM_ASSETS_CACHE_MODEL_CLEAN", "0"))),
# Timeout for fetching images when serving multimodal models # Timeout for fetching images when serving multimodal models
# Default is 5 seconds # Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT": "VLLM_IMAGE_FETCH_TIMEOUT":
......
...@@ -5,9 +5,10 @@ import hashlib ...@@ -5,9 +5,10 @@ import hashlib
import os import os
import shutil import shutil
import signal import signal
import tempfile
from typing import Optional from typing import Optional
from vllm import envs
from vllm.assets.base import get_cache_dir
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
...@@ -58,20 +59,19 @@ class ObjectStorageModel: ...@@ -58,20 +59,19 @@ class ObjectStorageModel:
""" """
def __init__(self, url: str) -> None: def __init__(self, url: str) -> None:
if envs.VLLM_ASSETS_CACHE_MODEL_CLEAN:
for sig in (signal.SIGINT, signal.SIGTERM): for sig in (signal.SIGINT, signal.SIGTERM):
existing_handler = signal.getsignal(sig) existing_handler = signal.getsignal(sig)
signal.signal(sig, self._close_by_signal(existing_handler)) signal.signal(sig, self._close_by_signal(existing_handler))
dir_name = os.path.join( dir_name = os.path.join(
tempfile.gettempdir(), get_cache_dir(), "model_streamer",
hashlib.sha256(str(url).encode()).hexdigest()[:8]) hashlib.sha256(str(url).encode()).hexdigest()[:8])
if os.path.exists(dir_name): if os.path.exists(dir_name):
shutil.rmtree(dir_name) shutil.rmtree(dir_name)
os.makedirs(dir_name) os.makedirs(dir_name)
self.dir = dir_name self.dir = dir_name
logger.debug("Init object storage, model cache path is: %s", dir_name)
def __del__(self):
self._close()
def _close(self) -> None: def _close(self) -> None:
if os.path.exists(self.dir): if os.path.exists(self.dir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment