utils.py 15.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import argparse
4
import contextlib
5
import multiprocessing
6
import time
7
import weakref
8
from collections.abc import Callable, Sequence
9
from contextlib import AbstractContextManager
10
from dataclasses import dataclass
11
from multiprocessing import connection
12
from multiprocessing.process import BaseProcess
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    TypeVar,
    Union,
    overload,
)
21
22

import torch
23
from torch.autograd.profiler import record_function
24

25
import vllm.envs as envs
26
from vllm.logger import init_logger
27
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
28
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
29
from vllm.utils.system_utils import kill_process_tree
30
from vllm.v1.core.sched.output import SchedulerOutput
31

32
if TYPE_CHECKING:
33
34
    import numpy as np

35
    from vllm.v1.engine.coordinator import DPCoordinator
36
    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
37

38
# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)
39
40
41
42

T = TypeVar("T")


43
class ConstantList(Generic[T], Sequence):
44
    def __init__(self, x: list[T]) -> None:
45
46
47
        self._x = x

    def append(self, item):
48
        raise TypeError("Cannot append to a constant list")
49
50

    def extend(self, item):
51
        raise TypeError("Cannot extend a constant list")
52
53

    def insert(self, item):
54
        raise TypeError("Cannot insert into a constant list")
55
56

    def pop(self, item):
57
        raise TypeError("Cannot pop from a constant list")
58
59

    def remove(self, item):
60
        raise TypeError("Cannot remove from a constant list")
61
62

    def clear(self):
63
        raise TypeError("Cannot clear a constant list")
64

65
    def index(self, item: T, start: int = 0, stop: int | None = None) -> int:
66
        return self._x.index(item, start, stop if stop is not None else len(self._x))
67
68

    @overload
69
    def __getitem__(self, item: int) -> T: ...
70
71

    @overload
72
    def __getitem__(self, s: slice, /) -> list[T]: ...
73

74
    def __getitem__(self, item: int | slice) -> T | list[T]:
75
76
77
        return self._x[item]

    @overload
78
    def __setitem__(self, item: int, value: T): ...
79
80

    @overload
81
    def __setitem__(self, s: slice, value: T, /): ...
82

83
    def __setitem__(self, item: int | slice, value: T | list[T]):
84
        raise TypeError("Cannot set item in a constant list")
85
86

    def __delitem__(self, item):
87
        raise TypeError("Cannot delete item from a constant list")
88
89
90
91
92
93
94
95
96

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)
97

98
99
100
    def __repr__(self):
        return f"ConstantList({self._x})"

101
102
103
    def copy(self) -> list[T]:
        return self._x.copy()

104

105
class CpuGpuBuffer:
106
    """Buffer to easily copy tensors between CPU and GPU."""
107
108
109

    def __init__(
        self,
110
        *size: int | torch.SymInt,
111
112
113
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
114
115
        with_numpy: bool = True,
    ) -> None:
116
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
117
        self.gpu = torch.zeros_like(self.cpu, device=device)
118
119
120
121
122
123
124
125
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
126
127
                    "numpy array, so call CpuGpuBuffer with with_numpy=False"
                )
128
            self.np = self.cpu.numpy()
129

130
    def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
131
132
133
134
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

135
    def copy_to_cpu(self, n: int | None = None) -> torch.Tensor:
136
137
138
139
140
141
142
        """NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)


143
def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    Args:
        local_only: True when all participants are colocated; a unique IPC
            address from ``get_open_zmq_ipc_path()`` is returned and
            ``host``/``port`` are ignored.
        host: Host used to build the TCP address when ``local_only`` is False.
        port: TCP port; a falsy value (the default ``0``) means an available
            port is chosen with ``get_open_port()``.

    Returns:
        The IPC address, or the TCP URI built by ``get_tcp_uri(host, port)``.
    """
    if local_only:
        return get_open_zmq_ipc_path()
    if not port:
        port = get_open_port()
    return get_tcp_uri(host, port)
157
158


159
160
class APIServerProcessManager:
    """Starts a group of API server processes and shuts them down.

    Processes are spawned in ``__init__`` and terminated by ``close()`` or
    when this object is garbage collected. Monitoring for failures is done
    by the caller (see ``wait_for_completion_or_failure``).
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: str | None = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address

        Raises:
            Whatever starting a process raises; any servers that had already
            been started are terminated before the exception propagates.
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # Register the finalizer BEFORE starting anything. It holds a
        # reference to the same list object, so it sees every process
        # appended below. Only the API server processes are shut down here;
        # extra processes are managed by their owners. ``shutdown`` is a
        # plain function (not a bound method) so the finalizer does not keep
        # ``self`` alive.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        try:
            # zip stops at the shortest input, as before.
            for i, in_addr, out_addr in zip(
                range(num_servers), input_addresses, output_addresses
            ):
                client_config: dict[str, Any] = {
                    "input_address": in_addr,
                    "output_address": out_addr,
                    "client_count": num_servers,
                    "client_index": i,
                }
                if stats_update_address is not None:
                    client_config["stats_update_address"] = stats_update_address

                proc = spawn_context.Process(
                    target=target_server_fn,
                    name=f"ApiServer_{i}",
                    args=(listen_address, sock, args, client_config),
                )
                self.processes.append(proc)
                proc.start()
        except BaseException:
            # A later start failed (or we were interrupted): terminate the
            # servers that did start instead of leaking them.
            self._finalizer()
            raise

        logger.info("Started %d API server processes", len(self.processes))

    def close(self) -> None:
        """Shut down the API server processes (idempotent)."""
        self._finalizer()


def wait_for_completion_or_failure(
    api_server_manager: APIServerProcessManager,
    engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
    | None = None,
    coordinator: "DPCoordinator | None" = None,
) -> None:
    """Wait for all processes to complete or detect if any fail.

    Watches the API server processes, the coordinator process (if given),
    and the engines:
    - local engine processes when ``engine_manager`` is a
      ``CoreEngineProcManager``;
    - Ray run refs when it is a ``CoreEngineActorManager``.

    Raises ``RuntimeError`` as soon as a watched process exits with a
    non-zero status. ``KeyboardInterrupt`` is logged and swallowed; any other
    exception is logged and re-raised. On the way out every manager is
    closed, and a failing ``close()`` no longer prevents the remaining ones
    from being closed.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.
    """

    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager

    try:
        logger.info("Waiting for API servers to complete ...")
        # Map each process sentinel back to its process for O(1) lookup.
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        while sentinel_to_proc or actor_run_refs:
            # Block up to 5s for any watched process to terminate.
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5)

            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)
                # A clean exit (code 0) just stops watching that process.
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}"
                    )

            if actor_run_refs:
                import ray

                # Keep polling the refs that have not completed yet.
                # NOTE(review): completed refs are dropped without ray.get(),
                # so a failed actor run is not surfaced here — confirm this
                # is intended.
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s", str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        # ExitStack runs callbacks LIFO and keeps running the rest even if
        # one raises (the error is re-raised afterwards). Register in reverse
        # so the close order stays: API servers, coordinator, engines.
        with contextlib.ExitStack() as cleanup:
            if engine_manager:
                cleanup.callback(engine_manager.close)
            if coordinator:
                cleanup.callback(coordinator.close)
            cleanup.callback(api_server_manager.close)
300
# Note(rob): shutdown must stay a plain function rather than a bound method;
# otherwise the gc cannot collect the object that owns the processes.
def shutdown(procs: list[BaseProcess]):
    """Stop every process in ``procs``, escalating if needed.

    1. Send ``terminate()`` to each live process.
    2. Join the survivors, sharing a single 5-second budget across all.
    3. Call ``kill_process_tree`` on anything still alive (and with a pid).
    """
    grace_period_s = 5

    for live_proc in (p for p in procs if p.is_alive()):
        live_proc.terminate()

    deadline = time.monotonic() + grace_period_s
    for proc in procs:
        budget = deadline - time.monotonic()
        if budget <= 0:
            # Budget exhausted; leave the rest to the forceful kill below.
            break
        if proc.is_alive():
            proc.join(budget)

    for straggler_pid in (p.pid for p in procs if p.is_alive()):
        if straggler_pid is not None:
            kill_process_tree(straggler_pid)
320

321

322
323
324
def copy_slice(
    from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
) -> torch.Tensor:
    """Copy ``from_tensor[:length]`` into ``to_tensor[:length]``, non-blocking.

    Used to move pinned CPU tensor data into pre-allocated GPU tensors.

    Returns:
        The written slice of ``to_tensor``.
    """
    dst = to_tensor[:length]
    src = from_tensor[:length]
    return dst.copy_(src, non_blocking=True)
334
335


336
def report_usage_stats(
    vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT
) -> None:
    """Send a usage-stats report summarizing ``vllm_config``.

    Returns immediately when ``is_usage_stats_enabled()`` is false.
    Otherwise reports the model architecture name plus a fixed set of
    configuration values via ``usage_message.report_usage``.
    """
    if not is_usage_stats_enabled():
        return

    from vllm.model_executor.model_loader import get_architecture_class_name

    model_cfg = vllm_config.model_config
    cache_cfg = vllm_config.cache_config
    parallel_cfg = vllm_config.parallel_config
    kv_transfer_cfg = vllm_config.kv_transfer_config

    # Name of the KV connector, or None when no KV transfer is configured.
    kv_connector = None if kv_transfer_cfg is None else kv_transfer_cfg.kv_connector

    architecture = get_architecture_class_name(model_cfg)
    extra_kvs = {
        # Common configuration
        "dtype": str(model_cfg.dtype),
        "block_size": cache_cfg.block_size,
        "gpu_memory_utilization": cache_cfg.gpu_memory_utilization,
        "kv_cache_memory_bytes": cache_cfg.kv_cache_memory_bytes,
        # Quantization
        "quantization": model_cfg.quantization,
        "kv_cache_dtype": str(cache_cfg.cache_dtype),
        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prefix_caching": cache_cfg.enable_prefix_caching,
        "enforce_eager": model_cfg.enforce_eager,
        "disable_custom_all_reduce": parallel_cfg.disable_custom_all_reduce,
        # Distributed parallelism settings
        "tensor_parallel_size": parallel_cfg.tensor_parallel_size,
        "data_parallel_size": parallel_cfg.data_parallel_size,
        "pipeline_parallel_size": parallel_cfg.pipeline_parallel_size,
        "enable_expert_parallel": parallel_cfg.enable_expert_parallel,
        # All2All backend for MoE expert parallel
        "all2all_backend": parallel_cfg.all2all_backend,
        # KV connector used
        "kv_connector": kv_connector,
    }
    usage_message.report_usage(architecture, usage_context, extra_kvs=extra_kvs)
381
382


383
384
385
# Cached profiling-scope factory; chosen on the first call of
# record_function_or_nullcontext and reused afterwards.
_PROFILER_FUNC = None


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a context manager that marks a named profiling scope.

    The factory is picked once and cached in ``_PROFILER_FUNC``:
    - ``record_function`` when ``envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING`` is set;
    - otherwise ``nvtx.annotate`` when ``envs.VLLM_NVTX_SCOPES_FOR_PROFILING``
      is set;
    - otherwise ``contextlib.nullcontext`` (no profiling).
    """
    global _PROFILER_FUNC

    if _PROFILER_FUNC is None:
        if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
            factory = record_function
        elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
            import nvtx

            factory = nvtx.annotate
        else:
            factory = contextlib.nullcontext
        _PROFILER_FUNC = factory

    return _PROFILER_FUNC(name)
403
404
405
406
407
408
409
410
411
412
413
414
415


def tensor_data(tensor: torch.Tensor) -> memoryview:
    """Get the raw data of a tensor as a uint8 memoryview, useful for
    serializing and hashing.

    Bytes are produced in the tensor's logical (row-major) element order,
    whatever its strides. The tensor is detached from autograd and brought
    to host memory first, so tensors that require grad or live on a non-CPU
    device are accepted. For an ordinary CPU tensor, ``detach()`` and
    ``cpu()`` add no copy.

    Args:
        tensor: The input tensor.

    Returns:
        A memoryview of the tensor data as uint8.
    """
    # flatten + contiguous gives a dense 1-D tensor, which is required for
    # reinterpreting the element bytes as uint8 via view(dtype).
    raw = tensor.detach().flatten().contiguous().view(torch.uint8)
    # numpy() only works on CPU tensors; cpu() returns `raw` itself when it
    # is already on the host.
    return raw.cpu().numpy().data
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465


@dataclass
class IterationDetails:
    """Counts of context and generation requests/tokens for one iteration."""

    num_ctx_requests: int
    num_ctx_tokens: int
    num_generation_requests: int
    num_generation_tokens: int

    def __repr__(self) -> str:
        # Implicit string concatenation keeps the repr on one line with
        # single spaces. Backslash continuations inside the literal would
        # embed the source indentation into the output.
        return (
            f"IterationDetails(num_ctx_requests={self.num_ctx_requests}, "
            f"num_ctx_tokens={self.num_ctx_tokens}, "
            f"num_generation_requests={self.num_generation_requests}, "
            f"num_generation_tokens={self.num_generation_tokens})"
        )


def compute_iteration_details(scheduler_output: SchedulerOutput) -> IterationDetails:
    """Split this iteration's scheduled work into context and generation.

    A scheduled request counts as a context request when
    ``scheduled_cached_reqs.is_context_phase(req_id)`` is true, or when it
    is among ``scheduled_new_reqs``. This includes, e.g., a later chunk of a
    chunked prefill. Every other scheduled request counts as generation.
    Token counts come from ``num_scheduled_tokens``.

    Args:
        scheduler_output: The scheduler output for the current iteration.

    Returns:
        An IterationDetails holding the number of context/generation
        requests and tokens.
    """
    fresh_ids = {req.req_id for req in scheduler_output.scheduled_new_reqs}

    ctx_reqs = ctx_tokens = 0
    gen_reqs = gen_tokens = 0
    for req_id, n_tokens in scheduler_output.num_scheduled_tokens.items():
        in_context = (
            scheduler_output.scheduled_cached_reqs.is_context_phase(req_id)
            or req_id in fresh_ids
        )
        if in_context:
            ctx_reqs += 1
            ctx_tokens += n_tokens
        else:
            gen_reqs += 1
            gen_tokens += n_tokens

    return IterationDetails(
        num_ctx_requests=ctx_reqs,
        num_ctx_tokens=ctx_tokens,
        num_generation_requests=gen_reqs,
        num_generation_tokens=gen_tokens,
    )