# runai_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import hashlib
5
6
7
8
9
import os
import shutil
import signal
from typing import Optional

10
11
from vllm import envs
from vllm.assets.base import get_cache_dir
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from vllm.logger import init_logger
from vllm.utils import PlaceholderModule

logger = init_logger(__name__)

# URI prefixes treated as object storage locations; is_runai_obj_uri()
# compares against these case-insensitively.
SUPPORTED_SCHEMES = ['s3://', 'gs://']

# The Run:ai model streamer is optional. If it cannot be imported, bind
# placeholder objects under the same names so this module still imports.
# Presumably using a placeholder raises an error naming the missing package
# (PlaceholderModule lives in vllm.utils; its behavior is not shown here).
try:
    from runai_model_streamer import list_safetensors as runai_list_safetensors
    from runai_model_streamer import pull_files as runai_pull_files
except (ImportError, OSError):
    # see https://github.com/run-ai/runai-model-streamer/issues/26
    # OSError will be raised on arm64 platform
    runai_model_streamer = PlaceholderModule(
        "runai_model_streamer")  # type: ignore[assignment]
    runai_pull_files = runai_model_streamer.placeholder_attr("pull_files")
    runai_list_safetensors = runai_model_streamer.placeholder_attr(
        "list_safetensors")


def list_safetensors(path: str = "") -> list[str]:
    """
    List safetensors file paths under an object storage path.

    This is a thin pass-through to the Run:ai model streamer's
    ``list_safetensors``; no pattern filtering is applied here. If the
    streamer package could not be imported, the placeholder bound at module
    import time is invoked instead.

    Args:
        path: The object storage path to list from.

    Returns:
        list[str]: Full object storage paths, exactly as returned by the
        streamer.
    """
    return runai_list_safetensors(path)


def is_runai_obj_uri(model_or_path: str) -> bool:
    """Tell whether *model_or_path* points at a supported object store.

    Args:
        model_or_path: A model name, local path, or URI.

    Returns:
        True when the value begins with one of ``SUPPORTED_SCHEMES``
        (compared case-insensitively), otherwise False.
    """
    normalized = model_or_path.lower()
    return any(normalized.startswith(prefix) for prefix in SUPPORTED_SCHEMES)


class ObjectStorageModel:
    """
    A class representing an ObjectStorage model mirrored into a
    temporary directory.

    Attributes:
        dir: The temporary created directory.

    Methods:
        pull_files(): Pull model from object storage to the temporary directory.
    """

    def __init__(self, url: str) -> None:
        """Create (or recreate) the local mirror directory for *url*.

        The directory is ``<get_cache_dir()>/model_streamer/<hash>``, where
        ``<hash>`` is the first 8 hex characters of the SHA-256 of *url*.
        Any existing directory at that path is removed first. When
        ``envs.VLLM_ASSETS_CACHE_MODEL_CLEAN`` is truthy, SIGINT and SIGTERM
        handlers are installed that delete the directory and then chain to
        the previously installed handlers.

        Args:
            url: The object storage URL of the model.
        """
        dir_name = os.path.join(
            get_cache_dir(), "model_streamer",
            hashlib.sha256(str(url).encode()).hexdigest()[:8])
        if os.path.exists(dir_name):
            shutil.rmtree(dir_name)
        os.makedirs(dir_name)
        self.dir = dir_name
        logger.debug("Init object storage, model cache path is: %s", dir_name)

        # Install cleanup handlers only after ``self.dir`` is set, so a signal
        # arriving mid-construction cannot make ``_close`` fail on a missing
        # attribute. NOTE: ``signal.signal`` only works from the main thread.
        if envs.VLLM_ASSETS_CACHE_MODEL_CLEAN:
            for sig in (signal.SIGINT, signal.SIGTERM):
                existing_handler = signal.getsignal(sig)
                signal.signal(sig, self._close_by_signal(existing_handler))

    def _close(self) -> None:
        """Remove the local mirror directory if it still exists."""
        if os.path.exists(self.dir):
            shutil.rmtree(self.dir)

    def _close_by_signal(self, existing_handler=None):
        """Build a signal handler that removes the mirror, then chains.

        Args:
            existing_handler: The handler that was installed before, as
                returned by ``signal.getsignal``: a Python callable,
                ``signal.SIG_DFL``, ``signal.SIG_IGN``, or ``None`` (the
                previous handler was not installed from Python).

        Returns:
            A ``(signum, frame)`` callable suitable for ``signal.signal``.
        """

        def new_handler(signum, frame):
            self._close()
            if callable(existing_handler):
                existing_handler(signum, frame)
            elif existing_handler == signal.SIG_DFL:
                # SIG_DFL is a non-callable (falsy) enum value. Restore it and
                # re-deliver the signal so the default action (e.g. process
                # termination on SIGTERM) still happens after cleanup.
                signal.signal(signum, signal.SIG_DFL)
                os.kill(os.getpid(), signum)
            # SIG_IGN / None: the signal was being ignored or is unknown to
            # Python; cleanup is all that is left to do. SIG_IGN must not be
            # called -- it is an int-valued enum, not a function.

        return new_handler

    def pull_files(self,
                   model_path: str = "",
                   allow_pattern: Optional[list[str]] = None,
                   ignore_pattern: Optional[list[str]] = None) -> None:
        """
        Pull files from object storage into the temporary directory.

        Args:
            model_path: The object storage path of the model. A trailing
                ``/`` is appended if missing.
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.

        """
        if not model_path.endswith("/"):
            model_path = model_path + "/"
        runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern)