utils.py 9.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import time
5
import weakref
6
from collections import defaultdict
7
from collections.abc import Sequence
8
9
10
from multiprocessing import Process, connection
from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union,
                    overload)
11
12

import torch
13

14
from vllm.config import VllmConfig
15
from vllm.logger import init_logger
16
from vllm.model_executor.models.utils import extract_layer_index
17
18
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
19
from vllm.utils import get_mp_context, kill_process_tree
20
from vllm.v1.executor.abstract import Executor
21

22
23
24
if TYPE_CHECKING:
    from vllm.attention.layer import Attention

25
# Module-level logger for this file.
logger = init_logger(__name__)
26
27
28
29

T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """Read-only wrapper around a list.

    Supports the read-only ``Sequence`` protocol (indexing, slicing,
    iteration, ``len``, ``in``, ``index``) and raises on every mutating
    operation. The wrapped list is not copied, so changes made through the
    original list object remain visible through this wrapper.
    """

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    # ``insert`` and ``pop`` accept any arguments so that calls shaped like
    # ``list.insert(index, obj)`` or ``list.pop()`` reach the intended error
    # instead of failing earlier with a TypeError about the signature.
    def insert(self, *args, **kwargs):
        raise Exception("Cannot insert into a constant list")

    def pop(self, *args, **kwargs):
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        """Return the first position of ``item`` within ``[start, stop)``."""
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        # Slicing returns a new plain list (a copy), not a view.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: list[T], /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"

97

98
class CoreEngineProcManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.

    One process is launched per local engine. Each receives the shared
    config kwargs plus its global (``dp_rank``) and node-local
    (``local_dp_rank``) data-parallel rank.
    """

    def __init__(
        self,
        target_fn: Callable,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        on_head_node: bool,
        input_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Create and start ``local_engine_count`` engine-core processes.

        Args:
            target_fn: Callable run in each child process.
            local_engine_count: Number of processes to launch here.
            start_index: Global DP rank of the first process.
            local_start_index: Node-local DP rank of the first process.
            vllm_config: Forwarded to every process.
            on_head_node: Forwarded to every process.
            input_address: Forwarded to every process; also handed to
                ``shutdown`` for socket-file cleanup.
            executor_class: Forwarded to every process.
            log_stats: Forwarded to every process.

        Raises:
            Whatever ``Process.start()`` raises. All processes, including
            those already started, are shut down before it propagates.
        """
        context = get_mp_context()
        common_kwargs = {
            "vllm_config": vllm_config,
            "on_head_node": on_head_node,
            "input_address": input_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        self.processes: list[Process] = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index
            # Start EngineCore in background process.
            self.processes.append(
                context.Process(target=target_fn,
                                name=f"EngineCore_{global_index}",
                                kwargs=common_kwargs | {
                                    "dp_rank": global_index,
                                    "local_dp_rank": local_index,
                                }))

        # The finalizer gets the process list directly (not ``self``) so that
        # it does not keep this object alive.
        self._finalizer = weakref.finalize(self, shutdown, self.processes,
                                           input_address)
        try:
            for proc in self.processes:
                proc.start()
        except BaseException:
            # A partial start leaves earlier processes running while none has
            # exited yet, so the finished_procs() check below would not fire.
            self.close()
            raise

        # Kill other procs if any has already exited.
        if self.finished_procs():
            self.close()

    def close(self):
        """Shutdown all procs.

        Safe to call repeatedly: ``weakref.finalize`` runs ``shutdown`` at
        most once.
        """
        self._finalizer()

    def join_first(self):
        """Wait for any process to exit."""
        connection.wait(self.sentinels())

    def sentinels(self) -> list:
        """Return each process's sentinel, usable with ``connection.wait``."""
        return [proc.sentinel for proc in self.processes]

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        return {
            proc.name: proc.exitcode
            for proc in self.processes if proc.exitcode is not None
        }
Robert Shaw's avatar
Robert Shaw committed
165
166
167


# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object (the weakref.finalize
# callback would hold a strong reference to the instance).
def shutdown(procs: list[Process], input_address: str):
    """Stop ``procs`` and remove the ZMQ IPC socket file, if any.

    Each live process is sent ``terminate()``. Together they get up to 5
    seconds to exit; the process tree of any process still alive after that
    is killed. If ``input_address`` is an ``ipc://`` address, the socket file
    it names is deleted; a file that is already gone is not an error.

    Args:
        procs: Processes to stop. Processes that are not alive (never
            started or already exited) are skipped.
        input_address: ZMQ address the engine processes were given.
    """
    # Shutdown the process.
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

    # Allow 5 seconds for remaining procs to terminate.
    deadline = time.monotonic() + 5
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    # Escalate for anything that ignored terminate().
    for proc in procs:
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)

    # Remove zmq ipc socket files.
    if input_address.startswith("ipc://"):
        socket_file = input_address[len("ipc://"):]
        # The ``os`` truthiness check presumably guards against module
        # globals being cleared during interpreter teardown.
        if os:
            # EAFP: an exists()/remove() pair races if the file disappears
            # in between, making remove() raise inside a finalizer.
            try:
                os.remove(socket_file)
            except FileNotFoundError:
                pass
193
194
195


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches, ordered by layer index.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner. Must be
            empty on entry (asserted).

    Raises:
        NotImplementedError: If two or more layer names map to the same
            layer index. In that case `runner_kv_caches` and
            `forward_context` are left unmodified.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    # Build the full ordered list before touching runner_kv_caches, so a
    # failure does not leave it partially populated.
    ordered_caches: list[torch.Tensor] = []
    for layer_index in sorted(index2name):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.
            raise NotImplementedError(
                f"Multiple KV caches share layer index {layer_index}: "
                f"{layer_names}")
        ordered_caches.append(kv_caches[layer_names[0]])
    runner_kv_caches.extend(ordered_caches)

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]
238
239
240


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the leading ``length`` elements of ``from_tensor`` into the
    leading ``length`` elements of ``to_tensor``, asynchronously when
    the devices allow it.

    Intended for moving pinned CPU staging data into pre-allocated GPU
    buffers.

    Returns the view of ``to_tensor`` that was written.
    """
    destination = to_tensor[:length]
    source = from_tensor[:length]
    # non_blocking=True lets a pinned-host -> device copy run asynchronously.
    return destination.copy_(source, non_blocking=True)
251
252


253
254
255
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Send a usage report describing this engine's configuration.

    Returns immediately when usage-stats collection is disabled.

    Args:
        vllm_config: Engine configuration; model, cache, parallel, LoRA and
            prompt-adapter settings are summarized in the report.
        usage_context: Where the engine was launched from.
    """
    if not is_usage_stats_enabled():
        return

    # Deferred import: only needed once we know a report will be sent.
    from vllm.model_executor.model_loader import get_architecture_class_name

    model_cfg = vllm_config.model_config
    cache_cfg = vllm_config.cache_config
    parallel_cfg = vllm_config.parallel_config

    architecture = get_architecture_class_name(model_cfg)

    report = {
        # Common configuration
        "dtype": str(model_cfg.dtype),
        "tensor_parallel_size": parallel_cfg.tensor_parallel_size,
        "block_size": cache_cfg.block_size,
        "gpu_memory_utilization": cache_cfg.gpu_memory_utilization,
        # Quantization
        "quantization": model_cfg.quantization,
        "kv_cache_dtype": str(cache_cfg.cache_dtype),
        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prompt_adapter": bool(vllm_config.prompt_adapter_config),
        "enable_prefix_caching": cache_cfg.enable_prefix_caching,
        "enforce_eager": model_cfg.enforce_eager,
        "disable_custom_all_reduce": parallel_cfg.disable_custom_all_reduce,
    }

    usage_message.report_usage(architecture, usage_context, extra_kvs=report)