utils.py 15.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import argparse
4
import contextlib
5
import multiprocessing
6
import time
7
import weakref
8
from collections.abc import Callable, Sequence
9
from contextlib import AbstractContextManager
10
from dataclasses import dataclass
11
from multiprocessing import connection
12
from multiprocessing.process import BaseProcess
13
14
15
16
17
18
19
20
21
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Optional,
    TypeVar,
    Union,
    overload,
)
22
23

import torch
24
from torch.autograd.profiler import record_function
25

26
import vllm.envs as envs
27
from vllm.logger import init_logger
28
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
29
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
30
from vllm.utils.system_utils import kill_process_tree
31
from vllm.v1.core.sched.output import SchedulerOutput
32

33
if TYPE_CHECKING:
34
35
    import numpy as np

36
    from vllm.v1.engine.coordinator import DPCoordinator
37
    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
38

39
# Module-level logger used by the helpers in this file.
logger = init_logger(__name__)
40
41
42
43

T = TypeVar("T")


44
class ConstantList(Generic[T], Sequence):
45
    def __init__(self, x: list[T]) -> None:
46
47
48
        self._x = x

    def append(self, item):
49
        raise TypeError("Cannot append to a constant list")
50
51

    def extend(self, item):
52
        raise TypeError("Cannot extend a constant list")
53
54

    def insert(self, item):
55
        raise TypeError("Cannot insert into a constant list")
56
57

    def pop(self, item):
58
        raise TypeError("Cannot pop from a constant list")
59
60

    def remove(self, item):
61
        raise TypeError("Cannot remove from a constant list")
62
63

    def clear(self):
64
        raise TypeError("Cannot clear a constant list")
65

66
    def index(self, item: T, start: int = 0, stop: int | None = None) -> int:
67
        return self._x.index(item, start, stop if stop is not None else len(self._x))
68
69

    @overload
70
    def __getitem__(self, item: int) -> T: ...
71
72

    @overload
73
    def __getitem__(self, s: slice, /) -> list[T]: ...
74

75
    def __getitem__(self, item: int | slice) -> T | list[T]:
76
77
78
        return self._x[item]

    @overload
79
    def __setitem__(self, item: int, value: T): ...
80
81

    @overload
82
    def __setitem__(self, s: slice, value: T, /): ...
83

84
    def __setitem__(self, item: int | slice, value: T | list[T]):
85
        raise TypeError("Cannot set item in a constant list")
86
87

    def __delitem__(self, item):
88
        raise TypeError("Cannot delete item from a constant list")
89
90
91
92
93
94
95
96
97

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)
98

99
100
101
    def __repr__(self):
        return f"ConstantList({self._x})"

102
103
104
    def copy(self) -> list[T]:
        return self._x.copy()

105

106
class CpuGpuBuffer:
107
    """Buffer to easily copy tensors between CPU and GPU."""
108
109
110

    def __init__(
        self,
111
        *size: int | torch.SymInt,
112
113
114
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
115
116
        with_numpy: bool = True,
    ) -> None:
117
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
118
        self.gpu = torch.zeros_like(self.cpu, device=device)
119
120
121
122
123
124
125
126
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
127
128
                    "numpy array, so call CpuGpuBuffer with with_numpy=False"
                )
129
            self.np = self.cpu.numpy()
130

131
    def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
132
133
134
135
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

136
    def copy_to_cpu(self, n: int | None = None) -> torch.Tensor:
137
138
139
140
141
142
143
        """NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)


144
def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
    """Return a fresh ZMQ address for engine-client communication.

    When ``local_only`` is True all participants share a host, so a unique
    IPC path is returned and ``host``/``port`` are ignored.

    Otherwise a TCP URI is built from ``host`` and ``port``; a ``port`` of 0
    means "pick any open port".
    """
    if local_only:
        return get_open_zmq_ipc_path()
    chosen_port = port or get_open_port()
    return get_tcp_uri(host, chosen_port)
Rui Qiao's avatar
Rui Qiao committed
158
159


160
161
class APIServerProcessManager:
    """Manages a group of API server processes.

    Starts the API server worker processes in ``__init__`` and terminates
    them on an explicit ``close()`` call or when this manager is garbage
    collected.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: str | None = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address

        Note:
            Servers are created by zipping ``range(num_servers)`` with the two
            address lists, so at most ``min(num_servers, len(input_addresses),
            len(output_addresses))`` processes are started.

        Raises:
            Any exception raised while creating/starting a process; servers
            already started are shut down before it propagates.
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # Register the shutdown hook BEFORE starting anything: the finalizer
        # holds the same list object, so it sees every process appended
        # below. Only the API server processes are shut down by it.
        # `shutdown` must not be a bound method, or `self` would never be
        # collected.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        try:
            for i, in_addr, out_addr in zip(
                range(num_servers), input_addresses, output_addresses
            ):
                client_config: dict[str, Any] = {
                    "input_address": in_addr,
                    "output_address": out_addr,
                    "client_count": num_servers,
                    "client_index": i,
                }
                if stats_update_address is not None:
                    client_config["stats_update_address"] = stats_update_address

                proc = spawn_context.Process(
                    target=target_server_fn,
                    name=f"ApiServer_{i}",
                    args=(listen_address, sock, args, client_config),
                )
                self.processes.append(proc)
                proc.start()
        except BaseException:
            # A later start failed (or we were interrupted): don't leak the
            # servers that are already running.
            self._finalizer()
            raise

        logger.info("Started %d API server processes", len(self.processes))

    def close(self) -> None:
        """Shut down the API server processes.

        Subsequent calls are no-ops, since a ``weakref.finalize`` callback
        runs at most once.
        """
        self._finalizer()


def wait_for_completion_or_failure(
    api_server_manager: APIServerProcessManager,
    engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
    | None = None,
    coordinator: Optional["DPCoordinator"] = None,
) -> None:
    """Block until every watched process has exited, failing fast on errors.

    Watches the API server processes, the data-parallel coordinator process
    (when given), and either the local engine processes (for a
    ``CoreEngineProcManager``) or the engine actors' run refs (for a
    ``CoreEngineActorManager``). Whatever happens, the API servers, the
    coordinator and the engine manager are closed (in that order) on exit.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.

    Raises:
        RuntimeError: If a watched process exits with a non-zero exit code.
        A ``KeyboardInterrupt`` is logged and swallowed; any other exception
        is logged and re-raised.
    """
    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager

    try:
        logger.info("Waiting for API servers to complete ...")

        # Index every watched process by its sentinel so that a ready
        # sentinel can be mapped back to its process in O(1).
        watched: dict[Any, BaseProcess] = {}
        for api_proc in api_server_manager.processes:
            watched[api_proc.sentinel] = api_proc
        if coordinator:
            coord_proc = coordinator.proc
            watched[coord_proc.sentinel] = coord_proc

        pending_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            watched.update({p.sentinel: p for p in engine_manager.processes})
        elif isinstance(engine_manager, CoreEngineActorManager):
            pending_refs = engine_manager.get_run_refs()

        while watched or pending_refs:
            # Poll for exited processes with a 5 s timeout so that the actor
            # refs below are also checked periodically.
            finished: list[Any] = connection.wait(watched, timeout=5)
            for sentinel in finished:
                dead = watched.pop(sentinel)
                if dead.exitcode != 0:
                    raise RuntimeError(
                        f"Process {dead.name} (PID: {dead.pid}) "
                        f"died with exit code {dead.exitcode}"
                    )

            if pending_refs:
                import ray

                # Keep only the refs that have not completed yet.
                _, pending_refs = ray.wait(pending_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s", str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()
299
300


Robert Shaw's avatar
Robert Shaw committed
301
# NOTE(rob): this must stay a module-level function rather than a bound
# method: it is used as a weakref.finalize callback, and a bound method would
# keep the owning object alive, so the gc could never collect it.
def shutdown(procs: list[BaseProcess]):
    """Stop ``procs``, escalating from terminate() to killing the tree.

    1. ``terminate()`` every process that is still alive.
    2. Wait for them to exit, sharing a single 5-second budget.
    3. ``kill_process_tree`` any process that is still alive afterwards.
    """
    for victim in procs:
        if not victim.is_alive():
            continue
        victim.terminate()

    # One shared 5 s grace period for all processes, not 5 s each.
    grace_end = time.monotonic() + 5
    for victim in procs:
        budget = grace_end - time.monotonic()
        if budget <= 0:
            break
        if victim.is_alive():
            victim.join(budget)

    for victim in procs:
        if not victim.is_alive():
            continue
        pid = victim.pid
        if pid is not None:
            kill_process_tree(pid)
Robert Shaw's avatar
Robert Shaw committed
321

322

323
324
325
def copy_slice(
    from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
) -> torch.Tensor:
    """Copy the first ``length`` elements (along dim 0) of ``from_tensor``
    into ``to_tensor``, issuing the copy with ``non_blocking=True``.

    Intended for moving pinned CPU data into pre-allocated GPU tensors.

    Returns:
        The written slice ``to_tensor[:length]`` (a view of ``to_tensor``).
    """
    destination = to_tensor[:length]
    source = from_tensor[:length]
    return destination.copy_(source, non_blocking=True)
335
336


337
def report_usage_stats(
    vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT
) -> None:
    """Send one usage report describing ``vllm_config``.

    Returns immediately when usage-stats collection is disabled. Otherwise
    reports the model architecture plus a flat dict of cache, quantization,
    feature-flag, parallelism and KV-connector settings.
    """
    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    parallel = vllm_config.parallel_config
    cache = vllm_config.cache_config
    model = vllm_config.model_config

    # Only present when a KV transfer config is configured.
    kv_transfer = vllm_config.kv_transfer_config
    kv_connector = None if kv_transfer is None else kv_transfer.kv_connector

    architecture = get_architecture_class_name(model)
    extra_kvs = {
        # Common configuration
        "dtype": str(model.dtype),
        "block_size": cache.block_size,
        "gpu_memory_utilization": cache.gpu_memory_utilization,
        "kv_cache_memory_bytes": cache.kv_cache_memory_bytes,
        # Quantization
        "quantization": model.quantization,
        "kv_cache_dtype": str(cache.cache_dtype),
        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prefix_caching": cache.enable_prefix_caching,
        "enforce_eager": model.enforce_eager,
        "disable_custom_all_reduce": parallel.disable_custom_all_reduce,
        # Distributed parallelism settings
        "tensor_parallel_size": parallel.tensor_parallel_size,
        "data_parallel_size": parallel.data_parallel_size,
        "pipeline_parallel_size": parallel.pipeline_parallel_size,
        "enable_expert_parallel": parallel.enable_expert_parallel,
        # All2All backend for MoE expert parallel
        "all2all_backend": parallel.all2all_backend,
        # KV connector used
        "kv_connector": kv_connector,
    }
    usage_message.report_usage(architecture, usage_context, extra_kvs=extra_kvs)
382
383


384
385
386
# Scope factory chosen on the first call to record_function_or_nullcontext
# and reused for every later call; None means "not chosen yet".
_PROFILER_FUNC = None


def _resolve_profiler_func() -> Callable[[str], AbstractContextManager]:
    """Pick the scope factory according to the profiling env flags.

    Custom (torch ``record_function``) scopes take precedence over NVTX
    scopes; with neither flag set, a no-op ``nullcontext`` is used.
    """
    if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
        return record_function
    if envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
        import nvtx

        return nvtx.annotate
    return contextlib.nullcontext


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a context manager that labels a profiling scope ``name``.

    The backing factory is selected once (on the first call) from the env
    flags and cached in ``_PROFILER_FUNC``; later flag changes are ignored.
    """
    global _PROFILER_FUNC

    if _PROFILER_FUNC is None:
        _PROFILER_FUNC = _resolve_profiler_func()
    return _PROFILER_FUNC(name)
404
405
406
407
408
409
410
411
412
413
414
415
416


def tensor_data(tensor: torch.Tensor) -> memoryview:
    """Get the raw data of a tensor as a uint8 memoryview, useful for
    serializing and hashing.

    The tensor is detached, flattened into a contiguous 1-D tensor and
    reinterpreted byte-wise. Detaching lets tensors that require grad be
    handled, since ``Tensor.numpy()`` rejects them. The tensor must already
    be on the CPU, because ``Tensor.numpy()`` does not accept device tensors.

    Args:
        tensor: The input tensor.

    Returns:
        A memoryview of the tensor data as uint8.
    """
    return tensor.detach().flatten().contiguous().view(torch.uint8).numpy().data
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466


@dataclass
class IterationDetails:
    """Counts of context vs. generation requests and tokens for one
    scheduler iteration (see ``compute_iteration_details``)."""

    num_ctx_requests: int
    num_ctx_tokens: int
    num_generation_requests: int
    num_generation_tokens: int

    def __repr__(self) -> str:
        # Implicit string concatenation instead of backslash continuations:
        # a backslash-newline inside the f-string literal kept the next
        # line's indentation spaces in the output.
        return (
            f"IterationDetails(num_ctx_requests={self.num_ctx_requests}, "
            f"num_ctx_tokens={self.num_ctx_tokens}, "
            f"num_generation_requests={self.num_generation_requests}, "
            f"num_generation_tokens={self.num_generation_tokens})"
        )


def compute_iteration_details(scheduler_output: SchedulerOutput) -> IterationDetails:
    """
    Split the current iteration's scheduled work into context and
    generation requests, and count the tokens scheduled for each.

    A request counts as a context request if it was newly scheduled in this
    iteration, or if ``scheduled_cached_reqs.is_context_phase`` reports it as
    still in its context phase (i.e. it has produced no output tokens yet).
    Later chunks of a chunked prefill therefore also count as context
    requests. Every other scheduled request counts as a generation request.

    Args:
        scheduler_output: The scheduler output for the current iteration.

    Returns:
        An IterationDetails with the context/generation request and token
        counts.
    """
    new_req_ids = {req.req_id for req in scheduler_output.scheduled_new_reqs}
    cached_reqs = scheduler_output.scheduled_cached_reqs

    ctx_reqs = ctx_tokens = 0
    gen_reqs = gen_tokens = 0
    for req_id, n_tokens in scheduler_output.num_scheduled_tokens.items():
        is_context = cached_reqs.is_context_phase(req_id) or req_id in new_req_ids
        if is_context:
            ctx_reqs += 1
            ctx_tokens += n_tokens
        else:
            gen_reqs += 1
            gen_tokens += n_tokens

    return IterationDetails(
        num_ctx_requests=ctx_reqs,
        num_ctx_tokens=ctx_tokens,
        num_generation_requests=gen_reqs,
        num_generation_tokens=gen_tokens,
    )