utils.py 17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import argparse
4
import contextlib
5
import multiprocessing
6
import threading
7
import time
8
import weakref
9
from collections.abc import Callable, Sequence
10
from contextlib import AbstractContextManager
11
from dataclasses import dataclass
12
from multiprocessing import connection
13
from multiprocessing.process import BaseProcess
14
from multiprocessing.queues import Queue
15
16
17
18
19
20
21
22
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    TypeVar,
    Union,
    overload,
)
23
24

import torch
25
import uvloop
26
from torch.autograd.profiler import record_function
27

28
import vllm.envs as envs
29
from vllm.logger import init_logger
30
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
31
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
32
from vllm.utils.system_utils import decorate_logs, kill_process_tree, set_process_title
33
from vllm.v1.core.sched.output import SchedulerOutput
34

35
if TYPE_CHECKING:
36
37
    import numpy as np

38
    from vllm.v1.engine.coordinator import DPCoordinator
39
    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
40

41
# Module-level logger, created through vLLM's init_logger helper.
logger = init_logger(__name__)
42
43
44
45

T = TypeVar("T")


46
class ConstantList(Generic[T], Sequence):
47
    def __init__(self, x: list[T]) -> None:
48
49
50
        self._x = x

    def append(self, item):
51
        raise TypeError("Cannot append to a constant list")
52
53

    def extend(self, item):
54
        raise TypeError("Cannot extend a constant list")
55
56

    def insert(self, item):
57
        raise TypeError("Cannot insert into a constant list")
58
59

    def pop(self, item):
60
        raise TypeError("Cannot pop from a constant list")
61
62

    def remove(self, item):
63
        raise TypeError("Cannot remove from a constant list")
64
65

    def clear(self):
66
        raise TypeError("Cannot clear a constant list")
67

68
    def index(self, item: T, start: int = 0, stop: int | None = None) -> int:
69
        return self._x.index(item, start, stop if stop is not None else len(self._x))
70
71

    @overload
72
    def __getitem__(self, item: int) -> T: ...
73
74

    @overload
75
    def __getitem__(self, s: slice, /) -> list[T]: ...
76

77
    def __getitem__(self, item: int | slice) -> T | list[T]:
78
79
80
        return self._x[item]

    @overload
81
    def __setitem__(self, item: int, value: T): ...
82
83

    @overload
84
    def __setitem__(self, s: slice, value: T, /): ...
85

86
    def __setitem__(self, item: int | slice, value: T | list[T]):
87
        raise TypeError("Cannot set item in a constant list")
88
89

    def __delitem__(self, item):
90
        raise TypeError("Cannot delete item from a constant list")
91
92
93
94
95
96
97
98
99

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)
100

101
102
103
    def __repr__(self):
        return f"ConstantList({self._x})"

104
105
106
    def copy(self) -> list[T]:
        return self._x.copy()

107

108
class CpuGpuBuffer:
109
    """Buffer to easily copy tensors between CPU and GPU."""
110
111
112

    def __init__(
        self,
113
        *size: int | torch.SymInt,
114
115
116
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
117
118
        with_numpy: bool = True,
    ) -> None:
119
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
120
        self.gpu = torch.zeros_like(self.cpu, device=device)
121
122
123
124
125
126
127
128
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
129
130
                    "numpy array, so call CpuGpuBuffer with with_numpy=False"
                )
131
            self.np = self.cpu.numpy()
132

133
    def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
134
135
136
137
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

138
    def copy_to_cpu(self, n: int | None = None) -> torch.Tensor:
139
140
141
142
143
144
145
        """NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)


146
def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    When ``local_only`` is True the participants are colocated, so a unique
    IPC address is returned and ``host``/``port`` are ignored.

    Otherwise a TCP address is built from ``host`` and ``port``. A ``port`` of
    0 (or any falsy value) means an available port is picked automatically.
    """
    if local_only:
        return get_open_zmq_ipc_path()
    chosen_port = port if port else get_open_port()
    return get_tcp_uri(host, chosen_port)
Rui Qiao's avatar
Rui Qiao committed
160
161


162
163
class APIServerProcessManager:
    """Manages a group of API server processes.

    Spawns the API server worker processes and shuts them down, either
    explicitly via ``shutdown()`` or when the manager is garbage-collected.
    The spawned processes are exposed as ``self.processes`` so callers can
    monitor them.
    """

    def __init__(
        self,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        target_server_fn: Callable | None = None,
        stats_update_address: str | None = None,
        tensor_queue: Queue | None = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            target_server_fn: Override function to call for each API server
                process (defaults to ``run_api_server_worker_proc``)
            stats_update_address: Optional stats update address
            tensor_queue: Optional tensor IPC queue for sharing MM tensors

        Raises:
            Any exception raised while creating or starting a process. Servers
            that were already started are shut down before it propagates.
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # Register the finalizer before starting anything. It holds the same
        # list object appended to below, so every started process is covered.
        # Only the API server processes are shut down on garbage collection;
        # the extra processes are managed by their owners.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        try:
            # NOTE(review): zip() truncates silently if the address lists are
            # shorter than num_servers; presumably callers keep them equal.
            for i, in_addr, out_addr in zip(
                range(num_servers), input_addresses, output_addresses
            ):
                client_config: dict[str, Any] = {
                    "input_address": in_addr,
                    "output_address": out_addr,
                    "client_count": num_servers,
                    "client_index": i,
                }
                if stats_update_address is not None:
                    client_config["stats_update_address"] = stats_update_address
                if tensor_queue is not None:
                    client_config["tensor_queue"] = tensor_queue

                proc = spawn_context.Process(
                    target=target_server_fn or run_api_server_worker_proc,
                    name=f"ApiServer_{i}",
                    args=(listen_address, sock, args, client_config),
                )
                self.processes.append(proc)
                proc.start()
        except BaseException:
            # A later server failed to spawn (or we were interrupted): do not
            # leave the servers that were already started running.
            self.shutdown()
            raise

        logger.info("Started %d API server processes", len(self.processes))

    def shutdown(self, timeout: float | None = None) -> None:
        """Shutdown API server processes with configurable timeout.

        Detaching the finalizer first makes this run at most once: later calls
        and the garbage-collection hook become no-ops.
        """
        if self._finalizer.detach() is not None:
            # Resolves to the module-level `shutdown` function, not this method.
            shutdown(self.processes, timeout=timeout)
234
235


236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def run_api_server_worker_proc(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Process entrypoint for a single API server worker.

    Labels the process title and log output with the worker's
    ``client_index`` (default 0). It then runs ``run_server_worker`` to
    completion on a uvloop event loop.
    """
    # Deferred import; presumably keeps this module's import light or avoids
    # an import cycle (TODO confirm).
    from vllm.entrypoints.openai.api_server import run_server_worker

    if not client_config:
        client_config = {}
    index_label = str(client_config.get("client_index", 0))

    # Set process title and add process-specific prefix to stdout and stderr.
    set_process_title("APIServer", index_label)
    decorate_logs()

    server_coro = run_server_worker(
        listen_address, sock, args, client_config, **uvicorn_kwargs
    )
    uvloop.run(server_coro)


255
def wait_for_completion_or_failure(
    api_server_manager: APIServerProcessManager,
    engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
    | None = None,
    coordinator: "DPCoordinator | None" = None,
) -> None:
    """Wait for all processes to complete or detect if any fail.

    Watches the API server processes and the DP coordinator process (if
    given). When an engine manager is given, it also watches a pipe that
    becomes readable once the engine-liveness monitor thread returns.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.

    Raises:
        RuntimeError: If a watched process exits with a non-zero exit code, or
            if ``engine_manager.failed_proc_name`` is set.

    A KeyboardInterrupt is logged and swallowed; any other exception is logged
    and re-raised.
    """
    # Read end of the engine-shutdown pipe. It is owned (and closed) by this
    # thread only, so it can never be closed while connection.wait() uses it.
    core_shutdown_recv: connection.Connection | None = None
    try:
        logger.info("Waiting for API servers to complete ...")
        # Map each waitable to its process for lookup once it becomes ready.
        # The engine-shutdown pipe maps to None (it is not a process).
        sentinel_to_proc: dict[Any, BaseProcess | None] = {
            proc.sentinel: proc for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        if engine_manager:
            core_shutdown_recv, core_shutdown_send = connection.Pipe(duplex=False)

            def monitor_engines():
                try:
                    engine_manager.monitor_engine_liveness()
                finally:
                    # Closing the write end makes the read end report EOF,
                    # which wakes connection.wait() below. The read end is
                    # deliberately NOT closed here: closing it from this
                    # thread races with the main thread waiting on it.
                    core_shutdown_send.close()

            # start monitor for engine liveness
            threading.Thread(target=monitor_engines, daemon=True).start()
            sentinel_to_proc[core_shutdown_recv] = None

        # Check if any process terminates
        while sentinel_to_proc:
            # Wait for any process to terminate (or engine shutdown signal)
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error
                if proc is not None and proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}"
                    )

                if engine_manager and engine_manager.failed_proc_name is not None:
                    raise RuntimeError(
                        f"Engine core process {engine_manager.failed_proc_name} "
                        "died unexpectedly."
                    )

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s", str(e))
        raise
    finally:
        if core_shutdown_recv is not None:
            core_shutdown_recv.close()


Robert Shaw's avatar
Robert Shaw committed
326
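# Note(rob): shutdown function cannot be a bound method. weakref.finalize keeps
# a strong reference to its callback, so a bound method would keep the manager
# alive and the gc could never collect it.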
328
329
330
331
332
333
334
335
336
337
338
339
340
def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None:
    """Shutdown processes with timeout.

    Args:
        procs: List of processes to shutdown
        timeout: Maximum time in seconds to wait for graceful shutdown
    """
    if timeout is None:
        timeout = 0.0

    # Allow at least 5 seconds for remaining procs to terminate.
    timeout = max(timeout, 5.0)

Robert Shaw's avatar
Robert Shaw committed
341
    # Shutdown the process.
342
343
344
345
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

346
347
    # Allow time for remaining procs to terminate.
    deadline = time.monotonic() + timeout
348
349
350
351
352
353
354
355
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    for proc in procs:
356
357
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)
Robert Shaw's avatar
Robert Shaw committed
358

359

360
361
362
def copy_slice(
    from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
) -> torch.Tensor:
    """Copy the leading ``length`` entries of ``from_tensor`` into ``to_tensor``.

    The copy is issued with ``non_blocking=True``. It is used to copy pinned
    CPU tensor data into pre-allocated GPU tensors. Slicing is along
    dimension 0.

    Returns:
        The ``to_tensor[:length]`` view that received the data.
    """
    destination = to_tensor[:length]
    source = from_tensor[:length]
    return destination.copy_(source, non_blocking=True)
372
373


374
def report_usage_stats(
    vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT
) -> None:
    """Send a usage-stats record describing this engine configuration.

    Does nothing unless ``is_usage_stats_enabled()`` returns True.
    """
    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    parallel_config = vllm_config.parallel_config

    # Name of the KV connector, or None when no KV transfer is configured.
    kv_transfer_config = vllm_config.kv_transfer_config
    kv_connector = (
        kv_transfer_config.kv_connector if kv_transfer_config is not None else None
    )

    architecture = get_architecture_class_name(model_config)
    extra_kvs = {
        # Common configuration
        "dtype": str(model_config.dtype),
        "block_size": cache_config.block_size,
        "gpu_memory_utilization": cache_config.gpu_memory_utilization,
        "kv_cache_memory_bytes": cache_config.kv_cache_memory_bytes,
        # Quantization
        "quantization": model_config.quantization,
        "kv_cache_dtype": str(cache_config.cache_dtype),
        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prefix_caching": cache_config.enable_prefix_caching,
        "enforce_eager": model_config.enforce_eager,
        "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
        # Distributed parallelism settings
        "tensor_parallel_size": parallel_config.tensor_parallel_size,
        "data_parallel_size": parallel_config.data_parallel_size,
        "pipeline_parallel_size": parallel_config.pipeline_parallel_size,
        "enable_expert_parallel": parallel_config.enable_expert_parallel,
        # All2All backend for MoE expert parallel
        "all2all_backend": parallel_config.all2all_backend,
        # KV connector used
        "kv_connector": kv_connector,
    }
    usage_message.report_usage(architecture, usage_context, extra_kvs=extra_kvs)
419
420


421
422
423
# Context-manager factory chosen on the first call to
# record_function_or_nullcontext and reused afterwards.
_PROFILER_FUNC = None


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a profiling scope named ``name``, or a no-op context manager.

    The backend is selected once, from environment flags, and cached:
    ``torch.autograd.profiler.record_function`` when
    ``VLLM_CUSTOM_SCOPES_FOR_PROFILING`` is set, otherwise ``nvtx.annotate``
    when ``VLLM_NVTX_SCOPES_FOR_PROFILING`` is set, otherwise
    ``contextlib.nullcontext``.
    """
    global _PROFILER_FUNC

    if _PROFILER_FUNC is None:
        if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
            selected = record_function
        elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
            # Imported only when requested; nvtx is an optional dependency.
            import nvtx

            selected = nvtx.annotate
        else:
            selected = contextlib.nullcontext
        _PROFILER_FUNC = selected

    return _PROFILER_FUNC(name)
441
442
443
444
445
446
447
448
449
450
451
452


def tensor_data(tensor: torch.Tensor) -> memoryview:
    """Return the tensor's raw bytes as a uint8 memoryview.

    The tensor is flattened, moved to CPU and made contiguous before being
    reinterpreted as bytes, so the result follows the tensor's logical element
    order. Useful for serializing and hashing.

    Args:
        tensor: The input tensor.

    Returns:
        A memoryview of the tensor data as uint8.
    """
    host_flat = tensor.flatten().cpu().contiguous()
    as_bytes = host_flat.view(torch.uint8)
    return as_bytes.numpy().data
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503


@dataclass
class IterationDetails:
    """Request and token counts for one scheduler iteration.

    The counts are split into context work and generation work, as classified
    by ``compute_iteration_details``.

    ``__repr__`` is generated by ``@dataclass`` and prints all four fields on a
    single line. The previous hand-written version used backslash
    continuations inside an f-string, which leaked the continuation lines'
    indentation (long runs of spaces) into the output.
    """

    # Number of requests counted as context (prefill) work.
    num_ctx_requests: int
    # Total scheduled tokens across those context requests.
    num_ctx_tokens: int
    # Number of requests counted as generation (decode) work.
    num_generation_requests: int
    # Total scheduled tokens across those generation requests.
    num_generation_tokens: int


def compute_iteration_details(scheduler_output: SchedulerOutput) -> IterationDetails:
    """
    Count context and generation requests/tokens in one scheduler output.

    A request is counted as a context request in either of two cases:

    - ``scheduled_cached_reqs.is_context_phase`` reports it as such, i.e. it
      has no output tokens yet. A later chunk of a chunked prefill falls into
      this category.
    - It was newly scheduled in this iteration.

    Every other scheduled request is counted as a generation request.

    Args:
        scheduler_output: The scheduler output for the current iteration.

    Returns:
        An IterationDetails object containing the number of
        context/generation requests and tokens.
    """
    fresh_ids = {req.req_id for req in scheduler_output.scheduled_new_reqs}

    ctx_reqs = ctx_tokens = 0
    gen_reqs = gen_tokens = 0
    for req_id, token_count in scheduler_output.num_scheduled_tokens.items():
        cached = scheduler_output.scheduled_cached_reqs
        if cached.is_context_phase(req_id) or req_id in fresh_ids:
            ctx_reqs += 1
            ctx_tokens += token_count
        else:
            gen_reqs += 1
            gen_tokens += token_count

    return IterationDetails(
        num_ctx_requests=ctx_reqs,
        num_ctx_tokens=ctx_tokens,
        num_generation_requests=gen_reqs,
        num_generation_tokens=gen_tokens,
    )