utils.py 7.84 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import os
import weakref
5
from collections import defaultdict
6
from collections.abc import Sequence
7
from multiprocessing import Process
8
9
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)
10
11

import torch
12
13

from vllm.logger import init_logger
14
from vllm.model_executor.models.utils import extract_layer_index
15
16
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
17
from vllm.utils import get_mp_context, kill_process_tree
18

19
20
21
if TYPE_CHECKING:
    from vllm.attention.layer import Attention

22
# Module-level logger for this file, created through vLLM's init_logger helper.
logger = init_logger(__name__)
23
24
25
26

T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """Read-only, list-like view over an existing Python list.

    Read operations (indexing, slicing, iteration, ``len``, ``in``,
    ``index`` and the ``Sequence`` mixin methods such as ``count``) are
    forwarded to the wrapped list. Every mutating ``list`` method raises
    ``Exception``.

    The wrapped list is stored by reference, not copied. Mutations made to
    it through another reference are therefore visible through this view.
    """

    def __init__(self, x: list[T]) -> None:
        self._x = x

    # The mutators accept whatever arguments their ``list`` counterparts
    # take (e.g. ``insert(index, item)``, ``pop()`` / ``pop(index)``), so
    # every call reports the read-only error rather than a TypeError about
    # the number of arguments.
    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, *args: Any, **kwargs: Any):
        raise Exception("Cannot insert into a constant list")

    def pop(self, *args: Any, **kwargs: Any):
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        """Return the first index of ``item`` within ``[start, stop)``.

        ``stop=None`` means "until the end of the list". Raises
        ``ValueError`` (from ``list.index``) when ``item`` is not found.
        """
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        # Slicing returns a new plain list (a copy of that range).
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: list[T], /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"

94

95
96
97
98
99
100
101
102
103
104
105
106
class BackgroundProcHandle:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.

    The target function is started in a new process with ``process_kwargs``
    plus three injected keyword arguments:
      - ``ready_pipe``: the write end of a one-way pipe. The read end is
        exposed as ``self.reader``.
      - ``input_path`` / ``output_path``: forwarded as given.

    The process is stopped and the socket paths are cleaned up by the
    module-level ``shutdown`` function. That happens when ``shutdown()`` is
    called on this handle, or when the handle is garbage-collected.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        process_name: str,
        target_fn: Callable,
        process_kwargs: dict[Any, Any],
    ):
        context = get_mp_context()
        self.reader, writer = context.Pipe(duplex=False)

        # These keys are injected below; refuse to silently overwrite them.
        assert ("ready_pipe" not in process_kwargs
                and "input_path" not in process_kwargs
                and "output_path" not in process_kwargs)
        # Work on a copy so the caller's dict is left untouched. The same
        # dict can then be reused (e.g. for another handle) without tripping
        # the assertion above.
        process_kwargs = dict(process_kwargs)
        process_kwargs["ready_pipe"] = writer
        process_kwargs["input_path"] = input_path
        process_kwargs["output_path"] = output_path

        # Run busy loop in background process.
        self.proc: Process = context.Process(target=target_fn,
                                             kwargs=process_kwargs,
                                             name=process_name)
        # The finalizer callback and its arguments must not reference
        # ``self``, otherwise the handle could never be collected. It is
        # therefore given the process and paths explicitly.
        self._finalizer = weakref.finalize(self, shutdown, self.proc,
                                           input_path, output_path)
        self.proc.start()
        # The child holds its own copy of the write end once start() returns.
        # Close the parent's copy so that reads on ``self.reader`` hit EOF if
        # the child exits without writing, instead of blocking forever.
        writer.close()

    def fileno(self):
        """Return the process sentinel, which becomes ready when it exits.

        This lets the handle be passed to select/poll-style waiters.
        """
        return self.proc.sentinel

    def shutdown(self):
        """Stop the process and remove its socket files.

        Delegates to the weakref finalizer, which runs its callback at most
        once. Repeated calls, or calls after garbage collection, are no-ops.
        """
        self._finalizer()


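# Note(rob): shutdown must be a module-level function, not a bound method.
# weakref.finalize keeps a strong reference to its callback, and a bound
# method references its instance, so the handle could never be gc'd.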
136
def shutdown(proc: Process, input_path: str, output_path: str):
    """Stop a background process and remove its zmq IPC socket files.

    Registered by ``BackgroundProcHandle`` as its ``weakref.finalize``
    callback, so it may run late (possibly during interpreter teardown).

    Args:
        proc: Process to stop. It is sent ``terminate()`` first. If it is
            still alive 5 seconds later, its whole process tree is killed.
        input_path: Input socket address. The ``ipc://`` prefix, if any, is
            stripped to get the socket file path.
        output_path: Output socket address, handled the same way.
    """
    # Shutdown the process.
    if proc.is_alive():
        proc.terminate()
        proc.join(5)

        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)

    # Remove zmq ipc socket files.
    for ipc_socket in (output_path, input_path):
        # Only strip the scheme prefix, not any later "ipc://" substring.
        socket_file = ipc_socket.removeprefix("ipc://")
        # NOTE(review): the ``os`` truthiness check presumably guards
        # against module globals already being cleared when this runs as a
        # finalizer at interpreter exit.
        if os and os.path.exists(socket_file):
            try:
                os.remove(socket_file)
            except FileNotFoundError:
                # Removed by someone else between the check and the unlink;
                # the goal (no leftover file) is already met.
                pass
151
152
153


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches, ordered by the layer index extracted from each layer
         name.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache list declared by ModelRunner. Must be
            empty on entry; it is filled in place.

    Raises:
        NotImplementedError: If more than one layer name maps to the same
            layer index.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # have different layer names but the same layer_index.
            raise NotImplementedError(
                f"Layers {layer_names} share layer index {layer_index}; "
                "binding KV caches for multiple layers with the same index "
                "is not supported.")
        runner_kv_caches.append(kv_caches[layer_names[0]])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]
196
197
198


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the leading ``length`` elements of ``from_tensor`` into the
    leading ``length`` elements of ``to_tensor`` with a non-blocking copy.

    Intended for moving pinned CPU data into pre-allocated GPU tensors.

    Returns the ``to_tensor[:length]`` view that was written to.
    """
    destination = to_tensor[:length]
    source = from_tensor[:length]
    destination.copy_(source, non_blocking=True)
    return destination
209
210


211
212
213
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled.

    When usage-stats collection is disabled, this returns immediately; the
    deferred model-loader import below is not executed either. Otherwise it
    sends three things via ``usage_message.report_usage``:
      - the model's architecture class name,
      - ``usage_context``,
      - a fixed set of configuration values read from ``vllm_config``.
    """
    if is_usage_stats_enabled():
        from vllm.model_executor.model_loader import (
            get_architecture_class_name)

        architecture = get_architecture_class_name(vllm_config.model_config)

        model_cfg = vllm_config.model_config
        cache_cfg = vllm_config.cache_config
        parallel_cfg = vllm_config.parallel_config

        # Key names below are sent verbatim as part of the report.
        extra_kvs = {
            # Common configuration
            "dtype": str(model_cfg.dtype),
            "tensor_parallel_size": parallel_cfg.tensor_parallel_size,
            "block_size": cache_cfg.block_size,
            "gpu_memory_utilization": cache_cfg.gpu_memory_utilization,
            # Quantization
            "quantization": model_cfg.quantization,
            "kv_cache_dtype": str(cache_cfg.cache_dtype),
            # Feature flags
            "enable_lora": bool(vllm_config.lora_config),
            "enable_prompt_adapter": bool(vllm_config.prompt_adapter_config),
            "enable_prefix_caching": cache_cfg.enable_prefix_caching,
            "enforce_eager": model_cfg.enforce_eager,
            "disable_custom_all_reduce":
            parallel_cfg.disable_custom_all_reduce,
        }

        usage_message.report_usage(architecture,
                                   usage_context,
                                   extra_kvs=extra_kvs)