runai_utils.py 3.08 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import hashlib
5
6
7
8
import os
import shutil
import signal

9
10
from vllm import envs
from vllm.assets.base import get_cache_dir
11
from vllm.logger import init_logger
12
from vllm.utils.import_utils import PlaceholderModule
13
14
15

logger = init_logger(__name__)

# URI prefixes routed to the Run:ai Model Streamer. Keep these lower-case:
# is_runai_obj_uri() lower-cases its input before matching against them.
SUPPORTED_SCHEMES = ["s3://", "gs://"]

try:
    from runai_model_streamer import list_safetensors as runai_list_safetensors
    from runai_model_streamer import pull_files as runai_pull_files
except ImportError:
    # runai_model_streamer is an optional dependency. PlaceholderModule stands
    # in for the missing package so importing this module still succeeds;
    # presumably the placeholder attributes raise when actually used.
    runai_model_streamer = PlaceholderModule("runai_model_streamer")  # type: ignore[assignment]
    runai_pull_files = runai_model_streamer.placeholder_attr("pull_files")
    runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors")


def list_safetensors(path: str = "") -> list[str]:
    """
    List safetensors objects under an object storage path.

    This is a direct pass-through to the Run:ai Model Streamer's
    ``list_safetensors``; no additional filtering happens here.

    Args:
        path: The object storage path to list from.

    Returns:
        list[str]: The result of the streamer's ``list_safetensors`` for
        ``path`` (presumably full object storage paths — confirm against the
        streamer's documentation).
    """
    listed = runai_list_safetensors(path)
    return listed


def is_runai_obj_uri(model_or_path: str) -> bool:
    """
    Return True if ``model_or_path`` starts with one of ``SUPPORTED_SCHEMES``.

    The input is lower-cased before comparison, so the match is
    case-insensitive as long as the entries of ``SUPPORTED_SCHEMES`` are
    themselves lower-case.
    """
    normalized = model_or_path.lower()
    return any(normalized.startswith(scheme) for scheme in SUPPORTED_SCHEMES)


class ObjectStorageModel:
    """
    A class representing an ObjectStorage model mirrored into a
    temporary directory.

    Attributes:
        dir: The temporary created directory.

    Methods:
53
        pull_files(): Pull model from object storage to the temporary directory.
54
55
    """

56
    def __init__(self, url: str) -> None:
57
58
59
60
        if envs.VLLM_ASSETS_CACHE_MODEL_CLEAN:
            for sig in (signal.SIGINT, signal.SIGTERM):
                existing_handler = signal.getsignal(sig)
                signal.signal(sig, self._close_by_signal(existing_handler))
61

62
        dir_name = os.path.join(
63
64
65
66
            get_cache_dir(),
            "model_streamer",
            hashlib.sha256(str(url).encode()).hexdigest()[:8],
        )
67
        os.makedirs(dir_name, exist_ok=True)
68
        self.dir = dir_name
69
        logger.debug("Init object storage, model cache path is: %s", dir_name)
70
71
72
73
74
75
76
77
78
79
80
81
82

    def _close(self) -> None:
        if os.path.exists(self.dir):
            shutil.rmtree(self.dir)

    def _close_by_signal(self, existing_handler=None):
        def new_handler(signum, frame):
            self._close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler

83
84
85
    def pull_files(
        self,
        model_path: str = "",
86
87
        allow_pattern: list[str] | None = None,
        ignore_pattern: list[str] | None = None,
88
    ) -> None:
89
90
91
92
93
94
95
96
97
98
99
100
        """
        Pull files from object storage into the temporary directory.

        Args:
            model_path: The object storage path of the model.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.

        """
        if not model_path.endswith("/"):
            model_path = model_path + "/"
        runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern)