utils.py 3.84 KB
Newer Older
Robert Shaw's avatar
Robert Shaw committed
1
import multiprocessing
2
3
import os
import weakref
4
from collections.abc import Sequence
5
6
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
                    Union, overload)
7
8

from vllm.logger import init_logger
9
from vllm.utils import get_mp_context, kill_process_tree
10
11

# Logger for this module, created via the project's init_logger helper.
logger = init_logger(__name__)
12
13
14
15

T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """Read-only ``Sequence`` view over an existing list.

    The wrapped list is stored by reference, not copied. Changes made to it
    through other references are therefore visible through this view.
    Every mutating ``list`` method, plus item assignment and deletion,
    raises an ``Exception`` instead of modifying the list.
    """

    def __init__(self, x: List[T]) -> None:
        self._x = x

    # Mutators: all rejected. ``insert``/``pop`` accept arbitrary arguments,
    # so calls such as ``pop()`` or ``insert(i, v)`` raise the intended
    # error rather than a TypeError about the call signature.

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, *args, **kwargs):
        raise Exception("Cannot insert into a constant list")

    def pop(self, *args, **kwargs):
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        """Return the first index of ``item`` within ``[start, stop)``.

        ``stop=None`` searches to the end of the list. Raises ``ValueError``
        (from ``list.index``) when ``item`` is not found.
        """
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> List[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
        # Slicing returns a new plain list, not another ConstantList.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: List[T], /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self) -> str:
        return f"ConstantList({self._x!r})"
79
80


81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
class BackgroundProcHandle:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.

    The constructor starts ``target_fn`` in a new process created from
    ``get_mp_context()``. It then blocks until the child reports its status
    over a one-way pipe. A ``weakref.finalize`` hook runs the module-level
    ``shutdown`` function to stop the process and remove its IPC socket
    files. The hook fires either when ``shutdown()`` is called or when this
    handle is garbage collected.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        process_name: str,
        target_fn: Callable,
        process_kwargs: Dict[Any, Any],
    ):
        """Start ``target_fn`` in a background process and wait for readiness.

        Args:
            input_path: address forwarded to the child as ``input_path``.
                Its socket file is removed on shutdown.
            output_path: address forwarded to the child as ``output_path``.
                Its socket file is removed on shutdown.
            process_name: name used in the startup-failure error message.
            target_fn: callable run in the child. It receives
                ``process_kwargs`` plus ``ready_pipe``, ``input_path`` and
                ``output_path`` as keyword arguments, and is expected to
                send a dict with a ``"status"`` key over ``ready_pipe``.
            process_kwargs: extra keyword arguments for ``target_fn``. Must
                not already contain ``ready_pipe``, ``input_path`` or
                ``output_path``. The caller's dict is not modified.

        Raises:
            RuntimeError: if the child reports a status other than
                ``"READY"``, or exits without reporting anything.
        """
        context = get_mp_context()
        reader, writer = context.Pipe(duplex=False)

        assert ("ready_pipe" not in process_kwargs
                and "input_path" not in process_kwargs
                and "output_path" not in process_kwargs)
        # Build a fresh dict instead of mutating the caller's.
        process_kwargs = {
            **process_kwargs,
            "ready_pipe": writer,
            "input_path": input_path,
            "output_path": output_path,
        }

        # Run busy loop in background process.
        self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
        self._finalizer = weakref.finalize(self, shutdown, self.proc,
                                           input_path, output_path)
        self.proc.start()

        # Once started, the child holds its own handle to the write end.
        # Close the parent's copy so that, if the child exits before
        # reporting, recv() raises EOFError instead of blocking forever.
        writer.close()

        # Wait for startup.
        try:
            status = reader.recv()["status"]
        except EOFError:
            status = None
        finally:
            reader.close()

        if status != "READY":
            raise RuntimeError(f"{process_name} initialization failed. "
                               "See root cause above.")

    def shutdown(self):
        """Stop the background process and remove its IPC socket files.

        Invokes the ``weakref.finalize`` hook registered in ``__init__``.
        A finalizer runs at most once, so repeated calls, or a later
        garbage collection of this handle, do nothing further.
        """
        self._finalizer()


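# Note(rob): shutdown must be a module-level function, not a bound method.
# weakref.finalize holds a strong reference to its callback, so a bound
# method of BackgroundProcHandle would keep the handle alive and the gc
# could never collect it.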
def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
    """Stop ``proc`` and delete the IPC socket files it used.

    If the process is alive, it is sent ``terminate()`` and joined for up to
    5 seconds. If it is still alive after that, its process tree is killed
    via ``kill_process_tree``. For each of ``output_path`` and
    ``input_path``, a leading ``ipc://`` scheme is stripped and the
    resulting file is deleted if it exists.

    Args:
        proc: the background process to stop.
        input_path: address whose socket file should be removed.
        output_path: address whose socket file should be removed.
    """
    # Shutdown the process.
    if proc.is_alive():
        proc.terminate()
        proc.join(5)

        if proc.is_alive():
            kill_process_tree(proc.pid)

    # Remove zmq ipc socket files.
    ipc_prefix = "ipc://"
    for ipc_socket in (output_path, input_path):
        # Strip only a leading scheme; replace() would also rewrite any
        # later occurrence of the substring inside the path.
        if ipc_socket.startswith(ipc_prefix):
            socket_file = ipc_socket[len(ipc_prefix):]
        else:
            socket_file = ipc_socket
        # NOTE(review): presumably guards against module globals already
        # being cleared (os is None) when this runs as a finalizer during
        # interpreter shutdown.
        if not os:
            continue
        # Attempt removal directly rather than exists()-then-remove(), which
        # races if the file disappears in between. A missing file simply
        # means there is nothing to clean up.
        try:
            os.remove(socket_file)
        except FileNotFoundError:
            pass