utils.py 12.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import argparse
import multiprocessing
5
import time
6
import weakref
7
from collections.abc import Sequence
8
from multiprocessing import connection
9
10
11
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)
12
13

import torch
14
15

from vllm.logger import init_logger
16
17
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
18
19
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
                        kill_process_tree)
20

21
if TYPE_CHECKING:
22
23
    import numpy as np

24
    from vllm.v1.engine.coordinator import DPCoordinator
25
26
    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)
27

28
# Module-level logger used by the process-management helpers in this file.
logger = init_logger(__name__)
29
30
31
32

# Element type of ConstantList.
T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """Read-only wrapper around a list.

    Exposes the read side of the list API (indexing, slicing, iteration,
    ``len``, ``in``, ``index``, plus the ``count``/``__reversed__`` mixins
    inherited from ``collections.abc.Sequence``) while every mutating
    method raises ``TypeError``.

    The wrapped list is stored by reference, not copied, so changes made
    through another reference to it are visible through this wrapper.
    Slicing returns a new plain ``list``.
    """

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise TypeError("Cannot append to a constant list")

    def extend(self, item):
        raise TypeError("Cannot extend a constant list")

    # ``list.insert`` takes (index, item) and ``list.pop`` takes an optional
    # index. Accepting any arguments guarantees that every way of calling
    # them (``insert(i, x)``, ``pop()``, ``pop(i)``) reaches the explicit
    # error message below instead of an argument-count TypeError.
    def insert(self, *args: Any, **kwargs: Any):
        raise TypeError("Cannot insert into a constant list")

    def pop(self, *args: Any, **kwargs: Any):
        raise TypeError("Cannot pop from a constant list")

    def remove(self, item):
        raise TypeError("Cannot remove from a constant list")

    def clear(self):
        raise TypeError("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        """Return the first index of ``item`` within ``[start, stop)``.

        Raises:
            ValueError: If ``item`` is not present in that range.
        """
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: list[T], /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise TypeError("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise TypeError("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"


class CpuGpuBuffer:
    """Buffer to easily copy tensors between CPU and GPU.

    Holds a zero-initialised host tensor (``cpu``), optionally pinned, and a
    same-shaped copy of it on ``device`` (``gpu``). When ``with_numpy`` is
    true, ``np`` is a NumPy view sharing memory with ``cpu``.

    Note: when ``device`` is the CPU, ``Tensor.to`` returns the tensor
    itself, so ``gpu`` and ``cpu`` alias the same storage.
    """

    def __init__(
        self,
        *size: Union[int, torch.SymInt],
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        with_numpy: bool = True,
    ) -> None:
        """Allocate the paired buffers.

        Args:
            *size: Shape of both buffers.
            dtype: Element type of both buffers.
            device: Device on which ``gpu`` is allocated.
            pin_memory: Whether ``cpu`` is allocated in pinned host memory.
            with_numpy: Whether to expose ``cpu`` as the ``np`` attribute.

        Raises:
            ValueError: If ``with_numpy`` is true and ``dtype`` is bfloat16,
                which ``Tensor.numpy()`` cannot represent.
        """
        # Validate before allocating anything, so an invalid request does not
        # first reserve (possibly pinned) host memory and device memory.
        if with_numpy and dtype == torch.bfloat16:
            raise ValueError(
                "Bfloat16 torch tensors cannot be directly cast to a "
                "numpy array, so call CpuGpuBuffer with with_numpy=False")

        self.cpu = torch.zeros(*size,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        self.gpu = self.cpu.to(device)
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy the first ``n`` rows (all rows if ``n`` is None) from ``cpu``
        into ``gpu`` with ``non_blocking=True`` and return the destination
        view."""
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy the first ``n`` rows (all rows if ``n`` is None) from ``gpu``
        into ``cpu`` and return the destination view.

        NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        if n is None:
            return self.cpu.copy_(self.gpu, non_blocking=True)
        return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)


def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    With ``local_only`` set, all participants share this machine, so a
    unique IPC path is returned and ``host``/``port`` are ignored.

    Otherwise a TCP URI is built from ``host`` and ``port``; a ``port`` of
    0 means an available port is picked automatically.
    """
    if local_only:
        return get_open_zmq_ipc_path()
    chosen_port = port if port else get_open_port()
    return get_tcp_uri(host, chosen_port)


class APIServerProcessManager:
    """Manages a group of API server processes.

    Spawns the API server worker processes on construction. They are
    terminated on :meth:`close`, or when the manager is garbage collected.
    This class does not itself monitor process health; callers can watch
    the processes listed in ``processes`` (for example via their
    sentinels).
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function run in each API server process. It is
                called as
                ``target_server_fn(listen_address, sock, args, client_config)``.
            listen_address: Address to listen for client connections.
            sock: Socket for client connections.
            args: Command line arguments.
            num_servers: Number of API server processes to start.
            input_addresses: Input address for each API server.
            output_addresses: Output address for each API server.
            stats_update_address: Optional stats update address. When given,
                it is added to every server's ``client_config``.

        Servers are paired with addresses via ``zip``, so at most
        ``min(num_servers, len(input_addresses), len(output_addresses))``
        processes are started.

        Raises:
            Any exception raised while creating or starting a process is
            re-raised after the already-started processes are shut down.
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # Shutdown only the API server processes on garbage collection.
        # The extra processes are managed by their owners.
        # The finalizer is registered *before* any process is started and
        # holds the (mutable) list itself. It therefore sees every process
        # appended below, and it can tear them down if startup fails halfway.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        try:
            for i, in_addr, out_addr in zip(range(num_servers),
                                            input_addresses,
                                            output_addresses):
                client_config: dict[str, Any] = {
                    "input_address": in_addr,
                    "output_address": out_addr,
                    "client_count": num_servers,
                    "client_index": i
                }
                if stats_update_address is not None:
                    client_config["stats_update_address"] = (
                        stats_update_address)

                proc = spawn_context.Process(target=target_server_fn,
                                             name=f"ApiServer_{i}",
                                             args=(listen_address, sock, args,
                                                   client_config))
                self.processes.append(proc)
                proc.start()
        except BaseException:
            # Don't leave the servers that did start running as orphans.
            self._finalizer()
            raise

        logger.info("Started %d API server processes", len(self.processes))

    def close(self) -> None:
        """Shut down the API server processes.

        Safe to call more than once: the finalizer only runs the first time.
        """
        self._finalizer()


def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union["CoreEngineProcManager",
                                       "CoreEngineActorManager"]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    The following are watched until every one of them has finished:

    - the API server processes;
    - the DP coordinator process, if any;
    - the engines owned by ``engine_manager``: local engine processes, or
      the run tasks of Ray engine actors.

    Whatever the outcome, the API servers, the coordinator and the engine
    manager are closed before returning. A ``KeyboardInterrupt`` is logged
    and swallowed; any other exception is logged and re-raised.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.

    Raises:
        RuntimeError: If any monitored process exits with a non-zero status.
        Exception: Whatever ``ray.get`` raises for a Ray engine actor whose
            run task failed.
    """

    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Loop until every monitored process and actor task has finished.
        while sentinel_to_proc or actor_run_refs:
            # Wait (up to 5s) for any process to terminate
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
                                                         timeout=5)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error (exitcode is negative
                # when the process was killed by a signal).
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                import ray

                # ray.wait only reports which run tasks have finished. Unless
                # the finished refs are fetched, a crashed engine actor is
                # just dropped from the list. ray.get re-raises the actor's
                # error, so the failure surfaces and everything is torn down
                # in the finally block below.
                done_refs, actor_run_refs = ray.wait(actor_run_refs,
                                                     timeout=5)
                ray.get(done_refs)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()


# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object. (It is passed to weakref.finalize,
# and a bound method would hold a strong reference to its owner.)
def shutdown(procs: list[BaseProcess]):
    """Stop every process in ``procs``, escalating to a forced kill.

    Each live process is first asked to stop via ``terminate()``. The group
    then shares a single 5 second grace period, during which each
    still-alive process is joined. Any process alive after that has its
    process tree killed with ``kill_process_tree``.
    """
    # Phase 1: ask every running process to stop.
    for p in procs:
        if not p.is_alive():
            continue
        p.terminate()

    # Phase 2: wait for them, sharing one 5 second budget across the group.
    grace_deadline = time.monotonic() + 5
    for p in procs:
        budget = grace_deadline - time.monotonic()
        if budget <= 0:
            break
        if p.is_alive():
            p.join(budget)

    # Phase 3: forcibly kill whatever survived the grace period.
    for p in procs:
        if not p.is_alive():
            continue
        pid = p.pid
        if pid is not None:
            kill_process_tree(pid)


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """Copy the leading ``length`` elements of ``from_tensor`` into
    ``to_tensor``, issuing the copy with ``non_blocking=True``.

    Intended for moving pinned CPU tensor data into pre-allocated GPU
    tensors.

    Returns:
        The ``[:length]`` view of ``to_tensor`` that was written.
    """
    dst = to_tensor[:length]
    src = from_tensor[:length]
    return dst.copy_(src, non_blocking=True)


def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled.

    Returns immediately when usage stats collection is disabled. Otherwise
    reports the model architecture name, ``usage_context`` and a fixed set
    of configuration values through ``usage_message.report_usage``.
    """
    if not is_usage_stats_enabled():
        return

    # Deferred import: only needed once we know stats will be reported.
    from vllm.model_executor.model_loader import get_architecture_class_name

    model_cfg = vllm_config.model_config
    architecture = get_architecture_class_name(model_cfg)
    cache_cfg = vllm_config.cache_config
    parallel_cfg = vllm_config.parallel_config

    extra_kvs = {
        # Common configuration
        "dtype": str(model_cfg.dtype),
        "tensor_parallel_size": parallel_cfg.tensor_parallel_size,
        "block_size": cache_cfg.block_size,
        "gpu_memory_utilization": cache_cfg.gpu_memory_utilization,

        # Quantization
        "quantization": model_cfg.quantization,
        "kv_cache_dtype": str(cache_cfg.cache_dtype),

        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prefix_caching": cache_cfg.enable_prefix_caching,
        "enforce_eager": model_cfg.enforce_eager,
        "disable_custom_all_reduce": parallel_cfg.disable_custom_all_reduce,
    }
    usage_message.report_usage(architecture, usage_context, extra_kvs=extra_kvs)