utils.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import argparse
4
import contextlib
5
import multiprocessing
6
import time
7
import weakref
8
from collections.abc import Callable, Sequence
9
from contextlib import AbstractContextManager
10
from multiprocessing import connection
11
from multiprocessing.process import BaseProcess
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Optional,
    TypeVar,
    Union,
    overload,
)
21
22

import torch
23
from torch.autograd.profiler import record_function
24

25
import vllm.envs as envs
26
from vllm.logger import init_logger
27
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
28
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
29
from vllm.utils.system_utils import kill_process_tree
30

31
if TYPE_CHECKING:
32
33
    import numpy as np

34
    from vllm.v1.engine.coordinator import DPCoordinator
35
    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
36

37
# Module-level logger for this file, obtained from vLLM's `init_logger`.
logger = init_logger(__name__)
38
39
40
41

T = TypeVar("T")


42
class ConstantList(Generic[T], Sequence):
43
    def __init__(self, x: list[T]) -> None:
44
45
46
        self._x = x

    def append(self, item):
47
        raise TypeError("Cannot append to a constant list")
48
49

    def extend(self, item):
50
        raise TypeError("Cannot extend a constant list")
51
52

    def insert(self, item):
53
        raise TypeError("Cannot insert into a constant list")
54
55

    def pop(self, item):
56
        raise TypeError("Cannot pop from a constant list")
57
58

    def remove(self, item):
59
        raise TypeError("Cannot remove from a constant list")
60
61

    def clear(self):
62
        raise TypeError("Cannot clear a constant list")
63

64
    def index(self, item: T, start: int = 0, stop: int | None = None) -> int:
65
        return self._x.index(item, start, stop if stop is not None else len(self._x))
66
67

    @overload
68
    def __getitem__(self, item: int) -> T: ...
69
70

    @overload
71
    def __getitem__(self, s: slice, /) -> list[T]: ...
72

73
    def __getitem__(self, item: int | slice) -> T | list[T]:
74
75
76
        return self._x[item]

    @overload
77
    def __setitem__(self, item: int, value: T): ...
78
79

    @overload
80
    def __setitem__(self, s: slice, value: T, /): ...
81

82
    def __setitem__(self, item: int | slice, value: T | list[T]):
83
        raise TypeError("Cannot set item in a constant list")
84
85

    def __delitem__(self, item):
86
        raise TypeError("Cannot delete item from a constant list")
87
88
89
90
91
92
93
94
95

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)
96

97
98
99
    def __repr__(self):
        return f"ConstantList({self._x})"

100
101
102
    def copy(self) -> list[T]:
        return self._x.copy()

103

104
class CpuGpuBuffer:
105
    """Buffer to easily copy tensors between CPU and GPU."""
106
107
108

    def __init__(
        self,
109
        *size: int | torch.SymInt,
110
111
112
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
113
114
        with_numpy: bool = True,
    ) -> None:
115
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
116
        self.gpu = torch.zeros_like(self.cpu, device=device)
117
118
119
120
121
122
123
124
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
125
126
                    "numpy array, so call CpuGpuBuffer with with_numpy=False"
                )
127
            self.np = self.cpu.numpy()
128

129
    def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
130
131
132
133
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

134
    def copy_to_cpu(self, n: int | None = None) -> torch.Tensor:
135
136
137
138
139
140
141
        """NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)


142
def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
    """Return a new ZMQ socket address.

    When ``local_only`` is true the participants are colocated, so a unique
    IPC address from ``get_open_zmq_ipc_path()`` is returned.

    Otherwise a TCP address is built from ``host`` and ``port`` via
    ``get_tcp_uri``. A falsy ``port`` (e.g. the default ``0``) means "pick an
    available port", which is obtained from ``get_open_port()``.
    """
    if local_only:
        return get_open_zmq_ipc_path()
    return get_tcp_uri(host, port or get_open_port())
Rui Qiao's avatar
Rui Qiao committed
156
157


158
159
class APIServerProcessManager:
    """Manages a group of API server processes.
160

161
162
163
164
165
166
167
168
169
170
171
172
173
    Handles creation, monitoring, and termination of API server worker
    processes. Also monitors extra processes to check if they are healthy.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
174
        stats_update_address: str | None = None,
175
176
    ):
        """Initialize and start API server worker processes.
177

178
179
180
181
182
183
184
185
        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
186
            stats_update_address: Optional stats update address
187
188
189
190
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args
191

192
193
194
195
        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

196
197
198
        for i, in_addr, out_addr in zip(
            range(num_servers), input_addresses, output_addresses
        ):
199
200
201
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
202
                "client_count": num_servers,
203
                "client_index": i,
204
205
206
207
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

208
209
210
211
212
            proc = spawn_context.Process(
                target=target_server_fn,
                name=f"ApiServer_{i}",
                args=(listen_address, sock, args, client_config),
            )
213
214
215
216
217
218
219
220
221
222
223
224
225
226
            self.processes.append(proc)
            proc.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Shutdown only the API server processes on garbage collection
        # The extra processes are managed by their owners
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        self._finalizer()


def wait_for_completion_or_failure(
    api_server_manager: APIServerProcessManager,
    engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
    | None = None,
    coordinator: Optional["DPCoordinator"] = None,
) -> None:
    """Block until all monitored work finishes, failing fast on errors.

    The monitored set is:
    - the API server processes;
    - the data-parallel coordinator process, if a coordinator is given;
    - either the local engine processes (for a ``CoreEngineProcManager``) or
      the run refs returned by ``get_run_refs()`` (for a
      ``CoreEngineActorManager``).

    Whatever the outcome, the API servers, the coordinator and the engine
    manager are closed before returning, in that order.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.

    Raises:
        RuntimeError: A monitored process exited with a non-zero exit code.
            This error, like any other exception except KeyboardInterrupt,
            is logged and then re-raised. A KeyboardInterrupt is logged and
            not re-raised.
    """

    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager

    try:
        logger.info("Waiting for API servers to complete ...")

        # Sentinel -> process, so each ready sentinel maps straight back to
        # the process that exited.
        sentinel_to_proc: dict[Any, BaseProcess] = {}
        for server_proc in api_server_manager.processes:
            sentinel_to_proc[server_proc.sentinel] = server_proc

        if coordinator:
            coordinator_proc = coordinator.proc
            sentinel_to_proc[coordinator_proc.sentinel] = coordinator_proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            sentinel_to_proc.update(
                (engine_proc.sentinel, engine_proc)
                for engine_proc in engine_manager.processes
            )
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Poll until nothing is left to watch.
        while sentinel_to_proc or actor_run_refs:
            exited: list[Any] = connection.wait(sentinel_to_proc, timeout=5)
            for sentinel in exited:
                dead = sentinel_to_proc.pop(sentinel)
                if dead.exitcode != 0:
                    raise RuntimeError(
                        f"Process {dead.name} (PID: {dead.pid}) "
                        f"died with exit code {dead.exitcode}"
                    )

            if actor_run_refs:
                import ray

                # Keep only the refs that have not completed yet.
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s", str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()
297
298


Robert Shaw's avatar
Robert Shaw committed
299
# Note(rob): shutdown must stay a plain module-level function, not a bound
# method. It is handed to `weakref.finalize`, and a bound method would keep
# its owner alive, so the gc could never collect the object.
def shutdown(procs: list[BaseProcess]):
    """Stop every process in ``procs``, escalating if they do not exit.

    Live processes are first sent ``terminate()``. They are then joined
    against one shared 5-second budget. Any process still alive afterwards,
    and with a known PID, has its process tree killed via
    ``kill_process_tree``.
    """
    # Phase 1: ask every running process to exit.
    for candidate in procs:
        if candidate.is_alive():
            candidate.terminate()

    # Phase 2: wait for them, sharing a single grace period across all.
    grace_deadline = time.monotonic() + 5
    for candidate in procs:
        budget = grace_deadline - time.monotonic()
        if budget <= 0:
            break
        if candidate.is_alive():
            candidate.join(budget)

    # Phase 3: forcibly kill whatever survived the grace period.
    for candidate in procs:
        if not candidate.is_alive():
            continue
        pid = candidate.pid
        if pid is not None:
            kill_process_tree(pid)
Robert Shaw's avatar
Robert Shaw committed
319

320

321
322
323
def copy_slice(
    from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
) -> torch.Tensor:
    """
    Non-blocking copy of the first ``length`` entries (along dim 0) of
    ``from_tensor`` into ``to_tensor``.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Returns the written slice of ``to_tensor`` (a view into it).
    """
    target = to_tensor[:length]
    target.copy_(from_tensor[:length], non_blocking=True)
    return target
333
334


335
def report_usage_stats(
    vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT
) -> None:
    """Report usage statistics for ``vllm_config`` if enabled.

    Returns immediately when ``is_usage_stats_enabled()`` is false.
    Otherwise, sends the model's architecture class name, ``usage_context``
    and a flat dict of selected configuration values through
    ``usage_message.report_usage``.
    """
    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    parallel_config = vllm_config.parallel_config
    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config

    # KV connector name, only when a KV transfer config is present.
    kv_transfer_config = vllm_config.kv_transfer_config
    kv_connector = (
        kv_transfer_config.kv_connector if kv_transfer_config is not None else None
    )

    architecture = get_architecture_class_name(model_config)
    extra_kvs = {
        # Common configuration
        "dtype": str(model_config.dtype),
        "block_size": cache_config.block_size,
        "gpu_memory_utilization": cache_config.gpu_memory_utilization,
        "kv_cache_memory_bytes": cache_config.kv_cache_memory_bytes,
        # Quantization
        "quantization": model_config.quantization,
        "kv_cache_dtype": str(cache_config.cache_dtype),
        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prefix_caching": cache_config.enable_prefix_caching,
        "enforce_eager": model_config.enforce_eager,
        "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
        # Distributed parallelism settings
        "tensor_parallel_size": parallel_config.tensor_parallel_size,
        "data_parallel_size": parallel_config.data_parallel_size,
        "pipeline_parallel_size": parallel_config.pipeline_parallel_size,
        "enable_expert_parallel": parallel_config.enable_expert_parallel,
        # All2All backend for MoE expert parallel
        "all2all_backend": parallel_config.all2all_backend,
        # KV connector used
        "kv_connector": kv_connector,
    }
    usage_message.report_usage(architecture, usage_context, extra_kvs=extra_kvs)
380
381


382
383
384
# Factory for profiling-scope context managers. It is chosen on the first call
# to `record_function_or_nullcontext` and stays None until then.
_PROFILER_FUNC = None


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a context manager that marks a profiling scope called ``name``.

    The factory is picked once, on first use, and cached in
    ``_PROFILER_FUNC``:
    - torch's ``record_function`` if ``envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING``;
    - otherwise ``nvtx.annotate`` if ``envs.VLLM_NVTX_SCOPES_FOR_PROFILING``;
    - otherwise ``contextlib.nullcontext`` (no profiling).

    Because of the cache, later changes to those env flags are not observed.
    """
    global _PROFILER_FUNC

    if _PROFILER_FUNC is None:
        if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
            _PROFILER_FUNC = record_function
        elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
            import nvtx

            _PROFILER_FUNC = nvtx.annotate
        else:
            _PROFILER_FUNC = contextlib.nullcontext

    return _PROFILER_FUNC(name)
402
403
404
405
406
407
408
409
410
411
412
413
414


def tensor_data(tensor: torch.Tensor) -> memoryview:
    """Expose a tensor's raw bytes as a uint8 memoryview.

    Useful for serializing and hashing. The tensor is flattened, made
    contiguous, reinterpreted as ``uint8``, and returned through the NumPy
    array's buffer. ``Tensor.numpy()`` is used, so the tensor must be on the
    CPU.

    Args:
        tensor: The input tensor.

    Returns:
        A memoryview of the tensor data as uint8.
    """
    flat = tensor.flatten().contiguous()
    as_bytes = flat.view(torch.uint8)
    return as_bytes.numpy().data