utils.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import argparse
4
import contextlib
5
import multiprocessing
6
import time
7
import weakref
8
from collections.abc import Sequence
9
from contextlib import AbstractContextManager
10
from multiprocessing import connection
11
12
13
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)
14
15

import torch
16
from torch.autograd.profiler import record_function
17

18
import vllm.envs as envs
19
from vllm.logger import init_logger
20
21
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
22
23
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
                        kill_process_tree)
24

25
if TYPE_CHECKING:
26
27
    import numpy as np

28
    from vllm.v1.engine.coordinator import DPCoordinator
29
30
    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)
31

32
# Module-level logger for this file.
logger = init_logger(__name__)
33
34
35
36

T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """Read-only list wrapper.

    Wraps ``x`` without copying it. Read operations (indexing, slicing,
    iteration, ``len``, ``in``, ``index`` and the ``Sequence`` mixin methods
    such as ``count``) are delegated to the wrapped list, while every
    list-style mutator raises ``TypeError``. Changes made to the underlying
    list through another reference remain visible through this wrapper.
    """

    def __init__(self, x: list[T]) -> None:
        self._x = x

    # The mutators below always reject the call. ``insert`` and ``pop``
    # accept arbitrary arguments so that list-style calls such as
    # ``insert(i, v)`` or a bare ``pop()`` reach the intended error message
    # instead of failing with an unrelated argument-count TypeError.

    def append(self, item):
        raise TypeError("Cannot append to a constant list")

    def extend(self, item):
        raise TypeError("Cannot extend a constant list")

    def insert(self, *args, **kwargs):
        raise TypeError("Cannot insert into a constant list")

    def pop(self, *args, **kwargs):
        raise TypeError("Cannot pop from a constant list")

    def remove(self, item):
        raise TypeError("Cannot remove from a constant list")

    def clear(self):
        raise TypeError("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        """Return the first index of ``item`` within ``[start, stop)``.

        Raises:
            ValueError: If ``item`` is not present in that range.
        """
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        # Slicing yields a plain list taken from the wrapped list.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: list[T], /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise TypeError("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise TypeError("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"

104

105
class CpuGpuBuffer:
    """A CPU tensor paired with a same-shaped tensor on ``device``.

    ``self.cpu`` is allocated with ``torch.zeros`` on the CPU (pinned when
    ``pin_memory`` is true) and ``self.gpu`` is ``self.cpu.to(device)``.
    With ``with_numpy=True`` (the default), ``self.np`` is
    ``self.cpu.numpy()``.
    """

    def __init__(
        self,
        *size: Union[int, torch.SymInt],
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        with_numpy: bool = True,
    ) -> None:
        staging = torch.zeros(*size,
                              dtype=dtype,
                              device="cpu",
                              pin_memory=pin_memory)
        self.cpu = staging
        self.gpu = staging.to(device)

        # `self.np` is declared for type checkers but only bound when
        # with_numpy is true (keeps the type simple: no Optional, no
        # subclass). Reading it on a buffer built with with_numpy=False
        # raises AttributeError.
        self.np: np.ndarray
        if not with_numpy:
            return
        if dtype == torch.bfloat16:
            # bfloat16 tensors cannot be cast to a NumPy array directly.
            raise ValueError(
                "Bfloat16 torch tensors cannot be directly cast to a "
                "numpy array, so call CpuGpuBuffer with with_numpy=False")
        self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy CPU data into the device tensor with ``non_blocking=True``.

        Copies the whole buffer, or only the first ``n`` entries along
        dim 0 when ``n`` is given. Returns the written destination tensor
        (or its ``[:n]`` slice).
        """
        if n is not None:
            return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu.copy_(self.cpu, non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy device data back into the CPU tensor (whole buffer, or the
        first ``n`` entries along dim 0) and return the destination.

        NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU.
        """
        dst, src = ((self.cpu, self.gpu) if n is None else
                    (self.cpu[:n], self.gpu[:n]))
        return dst.copy_(src, non_blocking=True)


145
146
147
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Return a new ZMQ socket address.

    When ``local_only`` is true, all participants are colocated, so a unique
    IPC address is returned and ``host``/``port`` are ignored.

    Otherwise a TCP address is built from ``host`` and ``port``. A ``port``
    of 0 means an available port is picked automatically.
    """
    if local_only:
        return get_open_zmq_ipc_path()
    chosen_port = port if port else get_open_port()
    return get_tcp_uri(host, chosen_port)


class APIServerProcessManager:
    """Starts and owns a group of API server processes.

    Each server is launched with the "spawn" multiprocessing context. The
    processes are shut down by ``close()`` or, failing that, when this
    manager is garbage collected. Watching the processes for failures is
    done elsewhere (e.g. ``wait_for_completion_or_failure`` reads
    ``self.processes``).
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Server ``i`` runs as process ``ApiServer_{i}`` calling
        ``target_server_fn(listen_address, sock, args, client_config)``.
        Its ``client_config`` holds its input/output address, the total
        ``client_count``, its ``client_index`` and, when given, the
        ``stats_update_address``.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address

        Servers are paired with addresses via ``zip``. If either address
        list is shorter than ``num_servers``, only that many servers are
        started. If starting any server fails, the servers that were
        already started are shut down before the exception propagates.
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # Register the shutdown hook BEFORE starting anything. It holds this
        # same list object, so every process appended below is covered, and
        # servers that already started can be cleaned up if a later start
        # fails. Shutdown only the API server processes on garbage
        # collection; the extra processes are managed by their owners.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        try:
            for i, in_addr, out_addr in zip(range(num_servers),
                                            input_addresses,
                                            output_addresses):
                client_config = {
                    "input_address": in_addr,
                    "output_address": out_addr,
                    "client_count": num_servers,
                    "client_index": i
                }
                if stats_update_address is not None:
                    client_config["stats_update_address"] = (
                        stats_update_address)

                proc = spawn_context.Process(target=target_server_fn,
                                             name=f"ApiServer_{i}",
                                             args=(listen_address, sock, args,
                                                   client_config))
                self.processes.append(proc)
                proc.start()
        except BaseException:
            # Don't leak servers that were started before the failure.
            self._finalizer()
            raise

        logger.info("Started %d API server processes", len(self.processes))

    def close(self) -> None:
        """Shut down the API server processes.

        Safe to call more than once: a ``weakref.finalize`` callback runs at
        most once.
        """
        self._finalizer()


def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union["CoreEngineProcManager",
                                       "CoreEngineActorManager"]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Monitors the API server processes, plus the DP coordinator process and
    the engines when provided, until all of them have finished. Whatever the
    outcome, the API servers, coordinator and engine manager are closed
    before returning. A KeyboardInterrupt is logged and not re-raised.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.

    Raises:
        RuntimeError: If a monitored process exits with a non-zero status.
        Exception: The error carried by a finished Ray engine actor run ref,
            re-raised via ``ray.get``.
    """

    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates
        while sentinel_to_proc or actor_run_refs:
            # Only block on sentinels when some remain: nothing in an empty
            # set can become ready, so the call would just sit out the whole
            # timeout and delay the Ray check below.
            if sentinel_to_proc:
                ready_sentinels: list[Any] = connection.wait(
                    sentinel_to_proc, timeout=5)

                # Process any terminated processes
                for sentinel in ready_sentinels:
                    proc = sentinel_to_proc.pop(sentinel)

                    # Check if process exited with error
                    if proc.exitcode != 0:
                        raise RuntimeError(
                            f"Process {proc.name} (PID: {proc.pid}) "
                            f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                import ray
                done_refs, actor_run_refs = ray.wait(actor_run_refs,
                                                     timeout=5)
                # ray.get re-raises the error of any actor run that failed,
                # so engine actor crashes are reported instead of dropped.
                if done_refs:
                    ray.get(done_refs)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()


# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
    """Stop ``procs``, escalating from terminate() to killing the tree.

    Every live process is first asked to terminate. All of them then share a
    single 5 second grace period to exit. Any that are still alive after
    that (and have a PID) are removed with ``kill_process_tree``.
    """
    for p in procs:
        if not p.is_alive():
            continue
        p.terminate()

    grace_deadline = time.monotonic() + 5
    for p in procs:
        time_left = grace_deadline - time.monotonic()
        if time_left <= 0:
            # Grace period used up; stop waiting on the rest.
            break
        if p.is_alive():
            p.join(time_left)

    for p in procs:
        if not p.is_alive():
            continue
        pid = p.pid
        if pid is not None:
            kill_process_tree(pid)


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """Copy ``from_tensor[:length]`` into ``to_tensor[:length]``.

    The copy is issued with ``non_blocking=True``. It is meant for moving
    pinned CPU data into pre-allocated GPU tensors, so a caller may need to
    synchronize before reading the destination when the copy crosses
    devices.

    Returns:
        The ``to_tensor[:length]`` view that was written.
    """
    destination = to_tensor[:length]
    source = from_tensor[:length]
    return destination.copy_(source, non_blocking=True)
333
334


335
336
337
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Send a usage report for ``vllm_config`` when usage stats are enabled.

    Returns immediately if ``is_usage_stats_enabled()`` is false. Otherwise
    reports the model architecture class name together with a fixed set of
    configuration values (dtype, parallelism, cache settings, quantization
    and feature flags).
    """
    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    model_cfg = vllm_config.model_config
    architecture = get_architecture_class_name(model_cfg)

    extra_kvs = {
        # Common configuration
        "dtype": str(model_cfg.dtype),
        "tensor_parallel_size":
        vllm_config.parallel_config.tensor_parallel_size,
        "block_size": vllm_config.cache_config.block_size,
        "gpu_memory_utilization":
        vllm_config.cache_config.gpu_memory_utilization,
        "kv_cache_memory_bytes":
        vllm_config.cache_config.kv_cache_memory_bytes,
        # Quantization
        "quantization": model_cfg.quantization,
        "kv_cache_dtype": str(vllm_config.cache_config.cache_dtype),
        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prefix_caching":
        vllm_config.cache_config.enable_prefix_caching,
        "enforce_eager": model_cfg.enforce_eager,
        "disable_custom_all_reduce":
        vllm_config.parallel_config.disable_custom_all_reduce,
    }

    usage_message.report_usage(architecture,
                               usage_context,
                               extra_kvs=extra_kvs)
376
377
378
379
380
381
382


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a profiler ``record_function(name)`` scope when
    ``envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING`` is set, else a no-op context.
    """
    return (record_function(name) if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING
            else contextlib.nullcontext())