__init__.py 120 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import asyncio
7
import concurrent
8
import contextlib
9
import datetime
Woosuk Kwon's avatar
Woosuk Kwon committed
10
import enum
11
import gc
12
import getpass
13
import hashlib
14
import importlib
15
import importlib.metadata
16
import importlib.util
17
import inspect
18
import ipaddress
19
import json
20
import multiprocessing
21
import os
22
import pickle
23
import signal
24
import socket
25
import subprocess
26
import sys
27
import tempfile
28
import textwrap
29
import threading
30
import time
31
import traceback
32
import types
Zhuohan Li's avatar
Zhuohan Li committed
33
import uuid
34
import warnings
35
import weakref
36
37
38
39
40
41
42
43
from argparse import (
    Action,
    ArgumentDefaultsHelpFormatter,
    ArgumentParser,
    ArgumentTypeError,
    RawDescriptionHelpFormatter,
    _ArgumentGroup,
)
44
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
45
from collections import UserDict, defaultdict
46
47
48
49
50
51
52
53
54
55
56
57
from collections.abc import (
    AsyncGenerator,
    Awaitable,
    Collection,
    Generator,
    Hashable,
    Iterable,
    Iterator,
    KeysView,
    Mapping,
    Sequence,
)
58
from concurrent.futures import ThreadPoolExecutor
59
from concurrent.futures.process import ProcessPoolExecutor
60
from dataclasses import dataclass, field
61
from functools import cache, lru_cache, partial, wraps
62
from pathlib import Path
63
from types import MappingProxyType
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    Literal,
    NamedTuple,
    Optional,
    TextIO,
    TypeVar,
    Union,
    cast,
    overload,
)
78
from urllib.parse import urlparse
79
from uuid import uuid4
Zhuohan Li's avatar
Zhuohan Li committed
80

81
import cachetools
82
import cbor2
83
import cloudpickle
84
import numpy as np
85
import numpy.typing as npt
86
import psutil
87
import regex as re
88
import setproctitle
Zhuohan Li's avatar
Zhuohan Li committed
89
import torch
90
import torch.types
91
import yaml
92
93
import zmq
import zmq.asyncio
94
from packaging import version
95
from packaging.version import Version
96
from torch.library import Library
97
from transformers.tokenization_utils_base import BatchEncoding
98
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
99

100
import vllm.envs as envs
101
from vllm.logger import enable_trace_function_call, init_logger
102
from vllm.ray.lazy_utils import is_in_ray_actor
103

104
if TYPE_CHECKING:
105
106
    from argparse import Namespace

107
    from vllm.config import ModelConfig, VllmConfig
108
    from vllm.sequence import IntermediateTensors
109

110
111
logger = init_logger(__name__)

# Default token budgets. The general-purpose value balances inter-token
# latency (ITL) against time-to-first-token (TTFT); it is deliberately not
# tuned for raw throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2_048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32_768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5_120

# --- Forcing the attention backend selection ---

# Name of the variable that, when set, overrides the automatic attention
# backend selection performed by the Attention wrapper.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Recognised values for the STR_BACKEND_ENV_VAR variable, one per backend.
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

# --- Byte-size units (decimal SI vs. binary IEC) ---

MB_bytes = 10**6
"""The number of bytes in one megabyte (MB)."""

MiB_bytes = 2**20
"""The number of bytes in one mebibyte (MiB)."""

GB_bytes = 10**9
"""The number of bytes in one gigabyte (GB)."""

GiB_bytes = 2**30
"""The number of bytes in one gibibyte (GiB)."""

# --- ANSI terminal escape sequences ---
CYAN = "\033[1;36m"
RESET = "\033[0;0m"
148

149
STR_DTYPE_TO_TORCH_DTYPE = {
150
    "float32": torch.float32,
151
152
153
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
154
    "fp8": torch.uint8,
155
156
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
157
    "int8": torch.int8,
158
    "fp8_inc": torch.float8_e4m3fn,
159
    "fp8_ds_mla": torch.uint8,
160
}
Zhuohan Li's avatar
Zhuohan Li committed
161

162
163
164
165
166
167
168
169
170
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}

171
172
173
174
175
176
177
178
179
180

@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
    """Temporarily set the number of threads PyTorch uses.

    The value reported by ``torch.get_num_threads()`` on entry is restored
    when the ``with`` block exits, including when it exits via an exception.

    Args:
        num_threads: Thread count passed to ``torch.set_num_threads`` for the
            duration of the block.
    """
    old_num_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        yield
    finally:
        # Restore even on error so a failing block cannot leak the override.
        torch.set_num_threads(old_num_threads)


181
P = ParamSpec("P")
182
T = TypeVar("T")
183
U = TypeVar("U")
184

185
186
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
187
_T = TypeVar("_T")
188

Woosuk Kwon's avatar
Woosuk Kwon committed
189

190
class _Sentinel: ...
191
192
193
194
195


ALL_PINNED_SENTINEL = _Sentinel()


Woosuk Kwon's avatar
Woosuk Kwon committed
196
197
198
199
200
class Device(enum.Enum):
    """Kind of device: GPU or CPU."""

    GPU = 1
    CPU = 2


201
202
203
204
205
class LayerBlockType(enum.Enum):
    """Kind of layer block; each member's value is its own name."""

    attention = "attention"
    mamba = "mamba"


Woosuk Kwon's avatar
Woosuk Kwon committed
206
207
208
209
class Counter:
    """Simple integer counter that works with the built-in ``next()``."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        # Hand out the current value, then advance by one.
        current, self.counter = self.counter, self.counter + 1
        return current

    def reset(self) -> None:
        # Always rewinds to 0, not to the `start` given to __init__.
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
217

218

219
220
221
222
223
224
225
226
227
228
229
230
class _MappingOrderCacheView(UserDict[_K, _V]):
    """Dict-like snapshot of cache data that iterates in a supplied order.

    Item lookups go to the copy of ``data`` held by ``UserDict``, while
    iteration and ``keys()`` follow ``ordered_keys``.
    """

    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
        super().__init__(data)
        # Mapping whose key order defines iteration order of this view.
        self.ordered_keys = ordered_keys

    def __iter__(self) -> Iterator[_K]:
        order = self.ordered_keys
        return iter(order)

    def keys(self) -> KeysView[_K]:
        order = self.ordered_keys
        return KeysView(order)


231
232
233
234
235
236
237
238
239
240
241
class CacheInfo(NamedTuple):
    """Hit statistics of a cache: ``hits`` out of ``total`` queries."""

    hits: int
    total: int

    @property
    def hit_ratio(self) -> float:
        """Fraction of queries that were hits; 0 when there were no queries."""
        return self.hits / self.total if self.total else 0

    def __sub__(self, other: CacheInfo):
        # Element-wise difference, e.g. to compute stats over an interval.
        return CacheInfo(self.hits - other.hits, self.total - other.total)

248

249
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
    """LRU cache with key pinning and hit-rate statistics.

    Builds on ``cachetools.LRUCache`` and adds:

    * Pinning: keys in ``pinned_items`` are skipped by :meth:`popitem` /
      :meth:`remove_oldest` unless ``remove_pinned=True`` is passed.
    * Statistics: successful ``cache[key]`` lookups and every :meth:`get`
      call are counted; see :meth:`stat`.
    * An :meth:`_on_remove` hook called whenever an entry is deleted. It is a
      no-op here, presumably meant to be overridden by subclasses.
    """

    def __init__(
        self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None
    ):
        super().__init__(capacity, getsizeof)

        # Keys that popitem() skips unless called with remove_pinned=True.
        self.pinned_items = set[_K]()

        # Lifetime counters: `_hits` successful lookups out of `_total`.
        self._hits = 0
        self._total = 0

        # Snapshot taken by the most recent `stat(delta=True)` call.
        self._last_info = CacheInfo(hits=0, total=0)

    def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
        value = super().__getitem__(key)

        if update_info:
            self._hits += 1
            self._total += 1

        return value

    def __delitem__(self, key: _K) -> None:
        run_on_remove = key in self
        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        super().__delitem__(key)
        if key in self.pinned_items:
            # Todo: add warning to inform that del pinned item
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)

    @property
    def cache(self) -> Mapping[_K, _V]:
        """Return the internal cache dictionary in order (read-only)."""
        return _MappingOrderCacheView(
            self._Cache__data,  # type: ignore
            self.order,
        )

    @property
    def order(self) -> Mapping[_K, None]:
        """Return the internal order dictionary (read-only)."""
        return MappingProxyType(self._LRUCache__order)  # type: ignore

    @property
    def capacity(self) -> float:
        return self.maxsize

    @property
    def usage(self) -> float:
        """Current size as a fraction of capacity (0 when capacity is 0)."""
        if self.maxsize == 0:
            return 0

        return self.currsize / self.maxsize

    def stat(self, *, delta: bool = False) -> CacheInfo:
        """
        Gets the cumulative number of hits and queries against this cache.

        If `delta=True`, instead gets these statistics
        since the last call that also passed `delta=True`.
        """
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    def touch(self, key: _K) -> None:
        """Mark `key` as most recently used in the LRU order.

        If `key` is absent from the order it is appended to it.
        """
        try:
            self._LRUCache__order.move_to_end(key)  # type: ignore
        except KeyError:
            self._LRUCache__order[key] = None  # type: ignore

    @overload
    def get(self, key: _K, /) -> Optional[_V]: ...

    @overload
    def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...

    def get(
        self, key: _K, /, default: Optional[Union[_V, _T]] = None
    ) -> Optional[Union[_V, _T]]:
        """Return the value for `key` or `default`; counts a query either way."""
        value: Optional[Union[_V, _T]]
        if key in self:
            value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]

            self._hits += 1
        else:
            value = default

        self._total += 1
        return value

    @overload
    def pop(self, key: _K) -> _V: ...

    @overload
    def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...

    def pop(
        self, key: _K, default: Optional[Union[_V, _T]] = None
    ) -> Optional[Union[_V, _T]]:
        """Remove `key` and return its value, or return `default` if absent."""
        value: Optional[Union[_V, _T]]
        if key not in self:
            return default

        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        self.__delitem__(key)
        return value

    def put(self, key: _K, value: _V) -> None:
        self.__setitem__(key, value)

    def pin(self, key: _K) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: _K) -> None:
        """
        Unpins a key in the cache allowing it to be
        evicted in the LRU order.
        """
        self.pinned_items.remove(key)

    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
        # Hook invoked after an entry is deleted; intentionally a no-op here.
        pass

    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
        """Evict the least recently used entry; no-op on an empty cache."""
        if len(self) == 0:
            return

        self.popitem(remove_pinned=remove_pinned)

    def _remove_old_if_needed(self) -> None:
        while self.currsize > self.capacity:
            self.remove_oldest()

    def popitem(self, remove_pinned: bool = False):
        """Remove and return the `(key, value)` pair least recently used.

        Args:
            remove_pinned: If False (default), pinned keys are skipped and the
                oldest unpinned entry is removed. If True, the oldest entry is
                removed regardless of pinning.

        Raises:
            KeyError: If the cache holds no entries.
            RuntimeError: If `remove_pinned` is False and every entry is
                pinned.
        """
        if not self.order:
            # Report emptiness like dict/cachetools instead of a misleading
            # "all pinned" error or a leaked StopIteration.
            raise KeyError(f"{type(self).__name__} is empty")

        if not remove_pinned:
            # pop the oldest item in the cache that is not pinned
            lru_key = next(
                (key for key in self.order if key not in self.pinned_items),
                ALL_PINNED_SENTINEL,
            )
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError(
                    "All items are pinned, cannot remove oldest from the cache."
                )
        else:
            lru_key = next(iter(self.order))
        value = self.pop(cast(_K, lru_key))
        return (lru_key, value)

    def clear(self) -> None:
        """Remove every entry (pinned or not) and reset the statistics."""
        while len(self) > 0:
            self.remove_oldest(remove_pinned=True)

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

420

421
class PyObjectCache:
    """Pool of pre-built Python objects reused across scheduler iterations.

    Objects are handed out in order by :meth:`get_object` and all become
    available again after :meth:`reset`. This avoids allocating fresh objects
    on every iteration.
    """

    def __init__(self, obj_builder):
        self._obj_builder = obj_builder
        self._index = 0

        # Start with a pool of 128 pre-built objects.
        self._obj_cache = [self._obj_builder() for _ in range(128)]

    def _grow_cache(self):
        # Double the pool by building as many new objects as already exist.
        current_size = len(self._obj_cache)
        self._obj_cache.extend([self._obj_builder() for _ in range(current_size)])

    def get_object(self):
        """Return the next pre-built object, doubling the pool when exhausted."""
        slot = self._index
        if slot >= len(self._obj_cache):
            self._grow_cache()
            assert slot < len(self._obj_cache)

        self._index = slot + 1
        return self._obj_cache[slot]

    def reset(self):
        """Make every pooled object available again for the next iteration."""
        self._index = 0


458
@cache
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Return the maximum shared memory per thread block on `gpu`, in bytes.

    The result is memoized per `gpu` index.
    """
    from vllm import _custom_ops as ops

    limit = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # A zero value would make MAX_SEQ_LEN negative and break
    # test_attention.py, so reject it here.
    assert limit > 0, "max_shared_mem can not be zero"
    return int(limit)


470
def get_cpu_memory() -> int:
    """Return the node's total CPU memory in bytes, as reported by psutil."""
    memory = psutil.virtual_memory()
    return memory.total
Zhuohan Li's avatar
Zhuohan Li committed
473
474
475
476


def random_uuid() -> str:
    """Return a random UUID4 as a 32-character lowercase hex string."""
    # `.hex` is already a str, so no conversion is needed.
    return uuid.uuid4().hex
477

478

479
480
481
class AsyncMicrobatchTokenizer:
    """Asynchronous tokenizer with micro-batching.

    Pulls pending encode/decode requests from a queue and batches them
    up to reduce overhead. A single-thread ThreadPoolExecutor is used
    so the event loop stays responsive.

    Requests are routed to one queue per normalized key (see
    :meth:`_queue_key`). Each queue is drained by its own batcher task,
    created on the event loop that was running when this object was built.
    """

    def __init__(
        self,
        tokenizer,
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        """
        Args:
            tokenizer: Tokenizer object. ``tokenizer(prompt_or_prompts, **kw)``
                is used for encoding and ``tokenizer.batch_decode`` for
                decoding.
            max_batch_size: Maximum number of requests merged into one call.
            batch_wait_timeout_s: Seconds to wait for further requests after
                the first request of a batch arrives.

        Must be called while an event loop is running.
        """
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s

        self._loop = asyncio.get_running_loop()
        self._queues: dict[
            tuple,
            asyncio.Queue[
                Union[
                    tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future]
                ]
            ],
        ] = {}
        self._batcher_tasks: list[asyncio.Task] = []

        # Single-thread executor for blocking tokenizer calls.
        self._executor = ThreadPoolExecutor(max_workers=1)

    # === Public async API ===
    async def __call__(self, prompt, **kwargs):
        """Encode `prompt` as ``tokenizer(prompt, **kwargs)``, possibly batched."""
        result_future: asyncio.Future = self._loop.create_future()
        key = self._queue_key("encode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((prompt, kwargs, result_future))
        return await result_future

    async def decode(self, token_ids, **kwargs):
        """Decode `token_ids` via ``tokenizer.batch_decode``, possibly batched.

        `kwargs` (e.g. ``skip_special_tokens``) are forwarded to
        ``batch_decode``. Requests with different kwargs use different queues,
        so they are never batched together.
        """
        result_future: asyncio.Future = self._loop.create_future()
        key = self._queue_key("decode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((token_ids, result_future))
        return await result_future

    # === Internal helpers ===
    def _get_queue(
        self, loop: asyncio.AbstractEventLoop, key: tuple
    ) -> asyncio.Queue[
        Union[tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future]]
    ]:
        """Get the request queue for the given operation key, creating a new
        queue and batcher task if needed."""
        queue = self._queues.get(key)
        if queue is None:
            self._queues[key] = queue = asyncio.Queue()
            if key[0] == "encode":
                can_batch = key[1] != "other"
                coro = self._batch_encode_loop(queue, can_batch)
            else:
                assert key[0] == "decode", f"Unknown operation type: {key[0]}."
                # Decode keys carry their kwargs as sorted (name, value) pairs.
                decode_kwargs = dict(key[1]) if len(key) > 1 else {}
                coro = self._batch_decode_loop(queue, decode_kwargs)
            self._batcher_tasks.append(loop.create_task(coro))
        return queue

    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
        """Batch incoming encode requests for efficiency."""
        while True:
            prompt, kwargs, result_future = await queue.get()
            prompts = [prompt]
            kwargs_list = [kwargs]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            # Collect more requests until the batch is full or the wait
            # window closes.
            while len(prompts) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    prompt, kwargs, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    prompts.append(prompt)
                    result_futures.append(result_future)
                    if not can_batch:
                        kwargs_list.append(kwargs)
                except asyncio.TimeoutError:
                    break

            try:
                # If every request uses identical kwargs we can run a single
                # batched tokenizer call for a big speed-up. The kwargs of the
                # last dequeued request stand in for the whole batch (requests
                # in this queue share the same normalized key).
                if can_batch and len(prompts) > 1:
                    batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
                    results = await self._loop.run_in_executor(
                        self._executor, batch_encode_fn
                    )

                    # Split the batched output back into per-request results.
                    for i, fut in enumerate(result_futures):
                        if not fut.done():
                            data = {k: v[i] for k, v in results.items()}
                            fut.set_result(BatchEncoding(data))
                else:
                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
                        self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
                    ]
                    results = await self._loop.run_in_executor(
                        self._executor, encode_fn
                    )

                    for fut, res in zip(result_futures, results):
                        if not fut.done():
                            fut.set_result(res)
            except Exception as e:
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    async def _batch_decode_loop(
        self, queue: asyncio.Queue, decode_kwargs: Optional[dict] = None
    ):
        """Batch incoming decode requests for efficiency.

        Args:
            queue: Queue of ``(token_ids, future)`` requests.
            decode_kwargs: Keyword arguments forwarded to
                ``tokenizer.batch_decode`` for every batch from this queue.
        """
        batch_kwargs = decode_kwargs or {}
        while True:
            token_ids, result_future = await queue.get()
            token_ids_list = [token_ids]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(token_ids_list) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    token_ids, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    token_ids_list.append(token_ids)
                    result_futures.append(result_future)
                except asyncio.TimeoutError:
                    break

            try:
                # Perform a single batched decode call for all requests
                decode_fn = partial(
                    self.tokenizer.batch_decode, token_ids_list, **batch_kwargs
                )
                results = await self._loop.run_in_executor(self._executor, decode_fn)
                for fut, res in zip(result_futures, results):
                    if not fut.done():
                        fut.set_result(res)
            except Exception as e:
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    def _queue_key(self, op: str, kwargs: dict) -> tuple:
        """
        Return a normalized key describing operation + kwargs.

        - `add_special_tokens`: {True/False}
        - `truncation`: {True/False}
          - If `truncation` is False (`max_length` is None),
            returns a key for a can_batch queue.
          - If `truncation` is True and `max_length` is None or equals
            `tokenizer.model_max_length`, returns a key for a can_batch queue.
          - Otherwise, returns a key for a cannot_batch queue.

        Decode keys include all decode kwargs, so differently-configured
        decodes never share a batch. Their values must be hashable.

        Examples:
          - Decode without kwargs: ("decode",)
          - Decode with kwargs: ("decode", (("skip_special_tokens", True),))
          - Encode typical:
            ("encode", add_special_tokens, bool_truncation, max_length_label)
          - Fallback: ("encode", "other")
        """

        if op == "decode":
            if not kwargs:
                return ("decode",)
            # Sorted by name so equal kwargs always map to the same queue.
            return ("decode", tuple(sorted(kwargs.items())))

        add_special_tokens = kwargs.get("add_special_tokens", True)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        if not truncation:
            return "encode", add_special_tokens, False, None

        model_max = getattr(self.tokenizer, "model_max_length", None)
        if max_length is None or (model_max is not None and max_length == model_max):
            return "encode", add_special_tokens, True, "model_max"

        return "encode", "other"

    def __del__(self):
        # Cancel the batcher tasks on their own loop, if it is still open.
        if (
            (tasks := getattr(self, "_batcher_tasks", None))
            and (loop := getattr(self, "_loop", None))
            and not loop.is_closed()
        ):

            def cancel_tasks():
                for task in tasks:
                    task.cancel()

            loop.call_soon_threadsafe(cancel_tasks)


def cancel_task_threadsafe(task: Task):
    """Cancel `task` from any thread; no-op if it is falsy or already done."""
    if not task or task.done():
        return
    run_in_loop(task.get_loop(), task.cancel)


def close_sockets(sockets: Sequence[Union[zmq.Socket, zmq.asyncio.Socket]]):
    """Close every non-None socket in `sockets` immediately (``linger=0``)."""
    for open_sock in (s for s in sockets if s is not None):
        open_sock.close(linger=0)


def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
    """Run ``function(*args)`` on `loop`.

    The call happens directly when the caller is already running inside
    `loop`. Otherwise it is scheduled thread-safely, or dropped if `loop` is
    closed.
    """
    if not in_loop(loop):
        if not loop.is_closed():
            loop.call_soon_threadsafe(function, *args)
        return
    function(*args)


def in_loop(event_loop: AbstractEventLoop) -> bool:
    """Return True if the calling code is running inside `event_loop`."""
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread.
        return False
    return running == event_loop
705
706


707
def make_async(
    func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None
) -> Callable[P, Awaitable[T]]:
    """Wrap a blocking function so that calling it runs it in an executor.

    The returned wrapper schedules ``func(*args, **kwargs)`` on `executor`
    (the event loop's default executor when None) and returns an
    ``asyncio.Future`` to await. This keeps the blocking call from stalling
    the asyncio event loop. Because `func` runs on another thread, it must be
    thread-safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        # Prefer the running loop. Fall back to get_event_loop() (deprecated
        # when no loop is running) only for callers outside a coroutine.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=executor, func=p_func)

    return _async_wrapper


725
def _next_task(iterator: AsyncGenerator[T, None], loop: AbstractEventLoop) -> Task:
726
727
728
729
    # Can use anext() in python >= 3.10
    return loop.create_task(iterator.__anext__())  # type: ignore[arg-type]


730
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
    """Merge several async iterators into a single async iterator.

    Yields ``(i, item)``, where ``i`` is the position in `iterators` of the
    iterator that produced ``item``. Iterators may finish at different times;
    the merged stream ends once all are exhausted. On exit, any still-pending
    fetch tasks are cancelled and their iterators closed. Errors raised
    during that cleanup are suppressed.
    """
    if len(iterators) == 1:
        # Fast path: nothing to merge.
        async for item in iterators[0]:
            yield 0, item
        return

    loop = asyncio.get_running_loop()

    # In-flight "fetch next item" task -> (index, iterator) it belongs to.
    pending = {
        _next_task(source, loop): (idx, source) for idx, source in enumerate(iterators)
    }
    try:
        while pending:
            finished, _ = await asyncio.wait(pending.keys(), return_when=FIRST_COMPLETED)
            for task in finished:
                entry = pending.pop(task)
                try:
                    item = await task
                    idx, source = entry
                    # Request the following item before yielding this one.
                    pending[_next_task(source, loop)] = entry
                    yield idx, item
                except StopAsyncIteration:
                    # This iterator is exhausted; it is simply not re-queued.
                    pass
    finally:
        # Cancel any remaining iterators
        for task, (_, source) in pending.items():
            with contextlib.suppress(BaseException):
                task.cancel()
                await source.aclose()
766
767


768
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
    """Drain `iterator` and return all of its items as a list, in order."""
    return [item async for item in iterator]


776
def get_ip() -> str:
    """Return the IP address other vLLM processes should use to reach this one.

    Resolution order:

    1. ``envs.VLLM_HOST_IP``, if set.
    2. The local address chosen for a UDP "connection" towards a public IPv4
       address, then towards a public IPv6 address. The target does not need
       to be reachable.
    3. ``"0.0.0.0"``, with a warning, if neither probe succeeds.

    The legacy ``HOST_IP`` variable is ignored. A warning is logged when it
    is set without ``VLLM_HOST_IP``.
    """
    host_ip = envs.VLLM_HOST_IP
    if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
        logger.warning(
            "The environment variable HOST_IP is deprecated and ignored, as"
            " it is often used by Docker and other software to"
            " interact with the container's network stack. Please "
            "use VLLM_HOST_IP instead to set the IP address for vLLM processes"
            " to communicate with each other."
        )
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface. Each probe
    # socket is closed by the `with` block, whether or not it succeeds.
    probes = (
        # try ipv4
        (socket.AF_INET, "8.8.8.8"),
        # try ipv6 -- Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        (socket.AF_INET6, "2001:4860:4860::8888"),
    )
    for family, target in probes:
        try:
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect((target, 80))  # Doesn't need to be reachable
                return s.getsockname()[0]
        except Exception:
            # Best effort: fall through to the next probe / default.
            pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable VLLM_HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"
816
817


818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
def test_loopback_bind(address, family):
    """Return True if a UDP socket of `family` can bind to `address`.

    Used to probe whether a loopback address (``127.0.0.1`` / ``::1``) is
    usable on this host. The probe socket is always closed, including when
    the bind fails.
    """
    try:
        with socket.socket(family, socket.SOCK_DGRAM) as s:
            s.bind((address, 0))  # Port 0 = auto assign
    except OSError:
        return False
    return True


# Despite its name this is a helper, not a test. This keeps pytest from
# collecting it when it is imported into a test module.
test_loopback_bind.__test__ = False  # type: ignore[attr-defined]


def get_loopback_ip() -> str:
    """Return the loopback address vLLM processes should use.

    ``VLLM_LOOPBACK_IP`` wins when set. Otherwise the first of ``127.0.0.1``
    and ``::1`` that can be bound is returned.

    Raises:
        RuntimeError: If neither loopback address can be bound.
    """
    configured = envs.VLLM_LOOPBACK_IP
    if configured:
        return configured

    # VLLM_LOOPBACK_IP is not set: probe IPv4 first, then IPv6.
    for candidate, family in (("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)):
        if test_loopback_bind(candidate, family):
            return candidate

    raise RuntimeError(
        "Neither 127.0.0.1 nor ::1 are bound to a local interface. "
        "Set the VLLM_LOOPBACK_IP environment variable explicitly."
    )
844
845


846
847
848
849
850
851
852
853
def is_valid_ipv6_address(address: str) -> bool:
    """Return True if `address` parses as an IPv6 address."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    return True


854
def split_host_port(host_port: str) -> tuple[str, int]:
    """Split ``"host:port"`` or ``"[ipv6]:port"`` into ``(host, port)``."""
    if not host_port.startswith("["):
        host, port = host_port.split(":")
        return host, int(port)

    # Bracketed IPv6 literal: "[addr]:port".
    bracketed, tail = host_port.rsplit("]", 1)
    return bracketed[1:], int(tail.split(":")[1])


def join_host_port(host: str, port: int) -> str:
    """Format `host` and `port` as ``host:port``, bracketing IPv6 literals."""
    template = "[{}]:{}" if is_valid_ipv6_address(host) else "{}:{}"
    return template.format(host, port)


873
def get_distributed_init_method(ip: str, port: int) -> str:
    """Return the ``tcp://`` URI used as the distributed init method."""
    init_uri = get_tcp_uri(ip, port)
    return init_uri


def get_tcp_uri(ip: str, port: int) -> str:
    """Return ``tcp://ip:port``, wrapping IPv6 addresses in brackets."""
    host = f"[{ip}]" if is_valid_ipv6_address(ip) else ip
    return f"tcp://{host}:{port}"
882
883


884
885
886
887
888
def get_open_zmq_ipc_path() -> str:
    """Return a unique ZMQ ``ipc://`` endpoint under ``VLLM_RPC_BASE_PATH``."""
    return "ipc://{}/{}".format(envs.VLLM_RPC_BASE_PATH, uuid4())


889
890
891
892
def get_open_zmq_inproc_path() -> str:
    """Return a unique ZMQ ``inproc://`` endpoint name."""
    endpoint = uuid4()
    return f"inproc://{endpoint}"


893
def get_open_port() -> int:
    """
    Get an open port for the vLLM process to listen on.

    When running data parallel, the 10 ports starting at
    ``VLLM_DP_MASTER_PORT`` are reserved for the data parallel master process
    (which currently uses 2 of them). Candidates in that range are skipped.
    """
    if "VLLM_DP_MASTER_PORT" not in os.environ:
        return _get_open_port()

    master_port = envs.VLLM_DP_MASTER_PORT
    reserved = range(master_port, master_port + 10)
    candidate = _get_open_port()
    while candidate in reserved:
        candidate = _get_open_port()
    return candidate

youkaichao's avatar
youkaichao committed
911

912
913
def get_open_ports_list(count: int = 5) -> list[int]:
    """Get a list of `count` distinct open ports."""
    found: set[int] = set()
    while len(found) < count:
        found.add(get_open_port())
    return list(found)


920
def _get_open_port() -> int:
    """Find a bindable TCP port.

    If ``VLLM_PORT`` is set, probe upward from it (IPv4 only) and return the
    first port that binds, logging each port that is already taken.
    Otherwise ask the OS for an ephemeral port, trying IPv4 first and
    falling back to IPv6 if that fails.
    """

    def _ephemeral(family: socket.AddressFamily) -> int:
        # Binding to port 0 lets the OS pick a free port.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    candidate = envs.VLLM_PORT
    if candidate is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind(("", candidate))
                    return candidate
            except OSError:
                logger.info(
                    "Port %d is already in use, trying port %d",
                    candidate,
                    candidate + 1,
                )
                candidate += 1

    try:
        return _ephemeral(socket.AF_INET)
    except OSError:
        return _ephemeral(socket.AF_INET6)
941
942


943
def find_process_using_port(port: int) -> Optional[psutil.Process]:
    """Return the process (other than this one) with a connection on ``port``.

    Only the first matching connection is considered. ``None`` is returned
    in three cases: no other process uses the port, that process exited
    before it could be inspected, or we are on macOS.
    """
    # TODO: We can not check for running processes with network
    # port on macOS. Therefore, we can not have a full graceful shutdown
    # of vLLM. For now, let's not look for processes in this case.
    # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
    if sys.platform.startswith("darwin"):
        return None

    self_pid = os.getpid()
    owner_pid = next(
        (
            conn.pid
            for conn in psutil.net_connections()
            if conn.laddr.port == port
            and conn.pid is not None
            and conn.pid != self_pid
        ),
        None,
    )
    if owner_pid is None:
        return None
    try:
        return psutil.Process(owner_pid)
    except psutil.NoSuchProcess:
        return None


961
def update_environment_variables(envs: dict[str, str]):
    """Copy every ``name -> value`` pair of ``envs`` into ``os.environ``.

    A warning is logged whenever an existing variable is overwritten with a
    different value.
    """
    for name, value in envs.items():
        previous = os.environ.get(name)
        if previous is not None and previous != value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                name,
                previous,
                value,
            )
        os.environ[name] = value
971
972


973
def chunk_list(lst: list[T], chunk_size: int):
    """Lazily yield consecutive slices of ``lst`` of length ``chunk_size``.

    The last slice is shorter when ``len(lst)`` is not a multiple of
    ``chunk_size``.
    """
    offsets = range(0, len(lst), chunk_size)
    yield from (lst[offset : offset + chunk_size] for offset in offsets)
977
978
979
980
981
982
983


def cdiv(a: int, b: int) -> int:
    """Return ``ceil(a / b)`` using exact integer arithmetic."""
    # Floor-dividing the negated dividend and negating back rounds toward +inf.
    return -((-a) // b)


984
985
986
987
988
989
990
def next_power_of_2(n) -> int:
    """Return the smallest power of two ``>= n`` (1 for any ``n < 1``)."""
    return 1 << (n - 1).bit_length() if n >= 1 else 1


991
992
993
994
995
996
997
def prev_power_of_2(n: int) -> int:
    """Return the largest power of two ``<= n`` (0 for any ``n <= 0``)."""
    if n > 0:
        return 1 << (n.bit_length() - 1)
    return 0


998
999
1000
1001
def round_up(x: int, y: int) -> int:
    """Round ``x`` up to the next multiple of ``y`` (for positive ``y``)."""
    num_multiples = (x + y - 1) // y
    return num_multiples * y


1002
1003
1004
1005
def round_down(x: int, y: int) -> int:
    """Round ``x`` down to a multiple of ``y`` (floor-division semantics)."""
    return y * (x // y)


1006
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill the fp8 ``tensor`` with random values drawn from U(low, high).

    Sampling raw fp8 bit patterns directly (e.g. with ``torch.randint``) can
    produce Inf/NaN encodings, because of how fp8 represents them:

        ----|-------------|-------------------
            | E4M3        | E5M2
        ----|-------------|-------------------
        Inf | N/A         | s.11111.00
        NaN | s.1111.111  | s.11111.{01,10,11}

    For example, ``s.11111.00`` in fp8e5m2 is Inf. Values are therefore
    sampled in float16 and handed to ``ops.convert_fp8``. That call
    presumably writes the converted values into ``tensor`` (destination
    first); confirm against the op's signature.
    """
    from vllm import _custom_ops as ops

    staging = torch.empty_like(tensor, dtype=torch.float16).uniform_(low, high)
    ops.convert_fp8(tensor, staging)
    del staging


1027
def get_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
) -> torch.dtype:
    """Resolve the torch dtype a KV cache should be allocated with.

    Args:
        cache_dtype: A ``torch.dtype``, a key of ``STR_DTYPE_TO_TORCH_DTYPE``,
            or ``"auto"`` to reuse the model dtype.
        model_dtype: The model dtype (``torch.dtype`` or string key). Only
            consulted when ``cache_dtype == "auto"``.

    Raises:
        ValueError: if the cache dtype (or, for ``"auto"``, the model dtype)
            cannot be resolved.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype != "auto":
        if cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    # "auto": follow the model dtype.
    if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
        return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
    if isinstance(model_dtype, torch.dtype):
        return model_dtype
    raise ValueError(f"Invalid model dtype: {model_dtype}")


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: Optional[int] = None,
    device: Optional[str] = "cuda",
    cache_layout: Optional[str] = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Allocate randomly initialised per-layer KV caches in flash layout.

    Each layer gets one tensor with logical shape
    ``(num_blocks, 2, block_size, num_heads, head_size)``. Index 0 of dim 1
    is returned as that layer's key cache and index 1 as its value cache.
    With ``cache_layout="HND"`` the storage is allocated with the
    ``block_size`` and ``num_heads`` dims swapped and then permuted back, so
    only the memory layout differs.

    Values are uniform in ``[-head_size**-0.5, head_size**-0.5]``; fp8
    caches are filled via `_generate_random_fp8`.

    Raises:
        ValueError: on the first layer, if ``cache_dtype`` is not one of
            ``"auto"``, ``"half"``, ``"bfloat16"``, ``"float"`` or ``"fp8"``.
            That includes ``torch.dtype`` values, which this check rejects.
    """
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)
    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    logical_shape = (num_blocks, 2, block_size, num_heads, head_size)
    assert cache_layout in ("NHD", "HND")
    order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
    physical_shape = tuple(logical_shape[dim] for dim in order)
    bound = head_size**-0.5

    key_caches: list[torch.Tensor] = []
    value_caches: list[torch.Tensor] = []
    for _ in range(num_layers):
        kv = torch.empty(
            size=physical_shape, dtype=torch_dtype, device=device
        ).permute(*order)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            kv.uniform_(-bound, bound)
        elif cache_dtype == "fp8":
            _generate_random_fp8(kv, -bound, bound)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(kv[:, 0])
        value_caches.append(kv[:, 1])
    return key_caches, value_caches


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: Optional[int] = None,
    device: Optional[str] = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Allocate randomly initialised per-layer key and value caches.

    Shapes:
    - Key caches: ``(num_blocks, num_heads, head_size // x, block_size, x)``,
      where ``x`` is how many elements of the cache dtype fit in 16 bytes.
    - Value caches: ``(num_blocks, num_heads, head_size, block_size)``.

    Values are uniform in ``[-head_size**-0.5, head_size**-0.5]``; fp8
    caches are filled via `_generate_random_fp8`. All key caches are
    created before any value cache.

    Raises:
        ValueError: for fp8 with a ``head_size`` that is not a multiple of
            16, or (when ``num_layers > 0``) for an unsupported
            ``cache_dtype``.
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)
    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    bound = head_size**-0.5

    def _random_cache(shape: tuple[int, ...], kind: str) -> torch.Tensor:
        # Allocate one cache tensor and fill it in place.
        cache = torch.empty(size=shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            cache.uniform_(-bound, bound)
        elif cache_dtype == "fp8":
            _generate_random_fp8(cache, -bound, bound)
        else:
            raise ValueError(f"Does not support {kind} cache of type {cache_dtype}")
        return cache

    pack = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_shape = (num_blocks, num_heads, head_size // pack, block_size, pack)
    value_shape = (num_blocks, num_heads, head_size, block_size)

    key_caches: list[torch.Tensor] = [
        _random_cache(key_shape, "key") for _ in range(num_layers)
    ]
    value_caches: list[torch.Tensor] = [
        _random_cache(value_shape, "value") for _ in range(num_layers)
    ]
    return key_caches, value_caches
1141
1142


1143
@cache
def is_pin_memory_available() -> bool:
    """Whether pinned host memory can be used on the current platform.

    The platform is asked once; the result is cached for later calls.
    """
    from vllm.platforms import current_platform as platform

    return platform.is_pin_memory_available()
1148
1149


1150
1151
1152
1153
1154
1155
1156
1157
@cache
def is_uva_available() -> bool:
    """Whether Unified Virtual Addressing (UVA) can be used.

    UVA relies on pinned host memory, which is currently the only
    requirement checked. The result is cached after the first call.
    """
    # TODO: Add more requirements for UVA if needed.
    pinned_ok = is_pin_memory_available()
    return pinned_ok


1158
class DeviceMemoryProfiler:
    """Context manager measuring device memory consumed inside a ``with`` block.

    After the block exits, ``initial_memory``, ``final_memory`` and
    ``consumed_memory`` (final minus initial) are set on the instance. The
    readings are whatever ``current_platform`` reports.
    """

    def __init__(self, device: Optional[torch.types.Device] = None):
        # Device to query; ``None`` is forwarded unchanged to the platform.
        self.device = device

    def current_memory_usage(self) -> float:
        """Return the memory usage of ``self.device`` in bytes.

        A garbage collection runs first, presumably so that objects which
        are unreachable but not yet collected do not skew the reading.
        """
        from vllm.platforms import current_platform

        gc.collect()
        return current_platform.get_current_memory_usage(self.device)

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # Return the profiler so callers can use ``with ... as profiler``.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end_usage = self.current_memory_usage()
        self.final_memory = end_usage
        self.consumed_memory = end_usage - self.initial_memory

        # Force garbage collection
        gc.collect()
1180
1181


1182
def make_ndarray_with_pad(
    x: list[list[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Build a 2D array from ragged rows, right-padding each row with ``pad``.

    The result has shape ``(len(x), max_len)``. ``max_len`` defaults to the
    longest row (0 when ``x`` is empty), and every row must fit within it.
    """
    if max_len is None:
        # `map` is faster than a generator expression over `len` here.
        max_len = max(map(len, x), default=0)

    out = np.full((len(x), max_len), pad, dtype=dtype)
    for row_idx, row in enumerate(x):
        width = len(row)
        assert width <= max_len
        out[row_idx, :width] = row

    return out


def make_tensor_with_pad(
    x: list[list[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Build a padded 2D tensor from ragged rows.

    Rows are right-padded with ``pad`` up to ``max_len`` (default: the
    longest row) via `make_ndarray_with_pad`. The result is converted to a
    tensor of ``dtype``, moved to ``device``, and pinned if ``pin_memory``
    is true.
    """
    padded = make_ndarray_with_pad(
        x, pad, TORCH_DTYPE_TO_NUMPY_DTYPE[dtype], max_len=max_len
    )
    result = torch.from_numpy(padded).to(device)
    return result.pin_memory() if pin_memory else result
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from ``data`` and issue a non-blocking copy to
    ``target_device``.

    ``pin_memory`` controls whether the host staging tensor is pinned. Per
    PyTorch semantics, non-blocking host-to-device copies generally need
    pinned memory to actually overlap with host work.
    """
    host_tensor = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return host_tensor.to(device=target_device, non_blocking=True)


1243
1244
1245
1246
1247
def get_dtype_size(dtype: torch.dtype) -> int:
    """Return the number of bytes one element of ``dtype`` occupies."""
    probe = torch.empty(0, dtype=dtype)
    return probe.element_size()


1248
1249
1250
# bool = 0, int = 1, float = 2, complex = 3
def _get_precision_level(dtype: torch.dtype) -> int:
    # NOTE: Complex dtypes return `is_floating_point=False`
1251
    return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278


def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
    """
    Return whether casting from ``src_dtype`` to ``tgt_dtype`` is lossless.

    Dtypes are first ranked by kind (bool < integer < floating < complex).
    A cast to a higher kind is always reported as lossless, and a cast to a
    lower kind as lossy. NOTE(review): because of the kind ranking,
    e.g. ``int64 -> float32`` counts as lossless even though large integers
    are not exactly representable. Confirm this is the intended contract.

    Within the same kind, the target's range must cover the source's range.
    For floating/complex types the target's resolution must also be at
    least as fine.
    """
    if src_dtype == tgt_dtype:
        return True

    src_level = _get_precision_level(src_dtype)
    tgt_level = _get_precision_level(tgt_dtype)
    if src_level != tgt_level:
        return src_level < tgt_level

    if src_dtype.is_floating_point or src_dtype.is_complex:
        src_f = torch.finfo(src_dtype)
        tgt_f = torch.finfo(tgt_dtype)
        return (
            tgt_f.min <= src_f.min
            and tgt_f.max >= src_f.max
            and tgt_f.resolution <= src_f.resolution
        )

    src_i = torch.iinfo(src_dtype)
    tgt_i = torch.iinfo(tgt_dtype)
    return tgt_i.min <= src_i.min and tgt_i.max >= src_i.max
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296


def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
    """
    Pick the member of ``dtypes`` that the most members can be cast to
    losslessly (as judged by `is_lossless_cast`).

    Ties go to the earliest candidate, as with ``max``. An empty collection
    raises ``ValueError``.
    """

    def _num_lossless_sources(candidate: torch.dtype) -> int:
        return sum(is_lossless_cast(src, candidate) for src in dtypes)

    return max(dtypes, key=_num_lossless_sources)


1297
1298
1299
1300
1301
def as_list(maybe_list: Iterable[T]) -> list[T]:
    """Return ``maybe_list`` itself if it is a list, else a new list of its items."""
    if isinstance(maybe_list, list):
        return maybe_list
    return list(maybe_list)


1302
1303
def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
    """Wrap a single value in a list; pass other iterables through unchanged.

    Strings are iterable but are treated as single values.
    """
    is_single = isinstance(obj, str) or not isinstance(obj, Iterable)
    if not is_single:
        return obj
    return [obj]  # type: ignore[list-item]


1308
1309
1310
# `collections` helpers
def is_list_of(
    value: object,
    typ: Union[type[T], tuple[type[T], ...]],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
    """Return whether ``value`` is a list whose elements are ``typ`` instances.

    With ``check="first"`` only the first element is inspected, and an empty
    list passes. With ``check="all"`` every element is inspected.
    """
    if not isinstance(value, list):
        return False

    if check == "all":
        return all(isinstance(item, typ) for item in value)
    if check == "first":
        return not value or isinstance(value[0], typ)

    assert_never(check)


1326
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
    """Concatenate the inner iterables of ``lists`` into one flat list."""
    flat: list[T] = []
    for sublist in lists:
        flat.extend(sublist)
    return flat


1331
1332
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
    """
    Group ``values`` by ``key(value)`` and return the groups' items view.

    Unlike [`itertools.groupby`][], groups are not broken by
    non-contiguous data: every value with the same key lands in one group,
    and groups appear in order of first occurrence.
    """
    groups: defaultdict[_K, list[_V]] = defaultdict(list)
    for item in values:
        groups[key(item)].append(item)
    return groups.items()


1344
1345
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
1346
def init_cached_hf_modules() -> None:
    """
    Lazily initialise Hugging Face's dynamic-module cache.

    ``transformers`` is imported only when this runs. Its
    ``init_hf_modules`` helper does the actual work.
    """
    from transformers.dynamic_module_utils import init_hf_modules as _init

    _init()
1353
1354


1355
@cache
def find_library(lib_name: str) -> str:
    """
    Resolve a shared-library file name to a full path.

    ``lib_name`` is the complete file name, prefix and suffix included
    (e.g. ``libcuda.so.1``). The lookup has two stages:

    1. The ``ldconfig -p`` listing is searched. An entry matches when
       ``lib_name`` occurs anywhere in its line.
    2. Only if nothing matched, each directory in ``LD_LIBRARY_PATH`` is
       checked for the file.

    The first hit is returned and the result is cached.

    Raises:
        ValueError: if the library cannot be found.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    ldconfig_listing = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    # Each entry looks like:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    candidates = [
        entry.split()[-1]
        for entry in ldconfig_listing.splitlines()
        if lib_name in entry
    ]
    # Fall back to the user-configured `LD_LIBRARY_PATH` directories.
    search_path = envs.LD_LIBRARY_PATH
    if not candidates and search_path:
        for directory in search_path.split(":"):
            full_path = os.path.join(directory, lib_name)
            if os.path.exists(full_path):
                candidates.append(full_path)
    if not candidates:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return candidates[0]


1383
def find_nccl_library() -> str:
    """
    Return the NCCL/RCCL shared library to load.

    ``VLLM_NCCL_SO_PATH`` takes precedence when set. Otherwise the library
    shipped with PyTorch is used: ``libnccl.so.2`` for CUDA builds or
    ``librccl.so.1`` for ROCm builds. After importing ``torch``, these
    names can be found by ``ctypes`` without a full path.

    Raises:
        ValueError: if PyTorch was built for neither CUDA nor ROCm.
    """
    so_file = envs.VLLM_NCCL_SO_PATH
    if so_file:
        # manually load the nccl library
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file
        )
        return so_file

    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info("Found nccl from library %s", so_file)
    return so_file
1406
1407


1408
1409
1410
def find_nccl_include_paths() -> Optional[list[str]]:
    """
    Collect include directories that contain ``nccl.h``.

    Candidates are checked in this order:

    1. ``VLLM_NCCL_INCLUDE_PATH``, if it is an existing directory.
    2. The ``include`` directory of each ``nvidia.nccl`` package location
       (e.g. from nvidia-nccl-cuXX), if it contains ``nccl.h``.

    Duplicates and empty entries are dropped, preserving order. ``None`` is
    returned when nothing is found; ``load_inline`` then uses
    ``torch.utils.cpp_extension.include_paths`` by default.
    """
    candidates: list[str] = []
    env_include = envs.VLLM_NCCL_INCLUDE_PATH
    if env_include and os.path.isdir(env_include):
        candidates.append(env_include)

    try:
        import importlib.util

        spec = importlib.util.find_spec("nvidia.nccl")
        search_locations = (
            getattr(spec, "submodule_search_locations", None) if spec else None
        )
        for location in search_locations or ():
            include_dir = os.path.join(location, "include")
            if os.path.exists(os.path.join(include_dir, "nccl.h")):
                candidates.append(include_dir)
    except Exception:
        # Best effort: a missing or broken nvidia.nccl install is not an error.
        pass

    unique = [path for path in dict.fromkeys(candidates) if path]
    return unique or None


youkaichao's avatar
youkaichao committed
1441
1442
# The original `torch.cuda.set_stream`, kept so the patched version can
# forward to it.
prev_set_stream = torch.cuda.set_stream

# Per-thread record of the stream most recently set through
# `torch.cuda.set_stream`; `current_stream()` reads `.value` from it.
_current_stream_tls = threading.local()


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    """Remember ``stream`` for the calling thread, then delegate to the
    original ``torch.cuda.set_stream``."""
    _current_stream_tls.value = stream
    prev_set_stream(stream)


# Install the tracking wrapper process-wide.
torch.cuda.set_stream = _patched_set_stream


1454
1455
1456
1457
1458
class _StreamPlaceholder:
    def __init__(self):
        self.synchronize = lambda: None


youkaichao's avatar
youkaichao committed
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
def current_stream() -> torch.cuda.Stream:
    """
    Cheap replacement for ``torch.cuda.current_stream()``.

    ``torch.cuda.current_stream()`` is expensive because it constructs a new
    stream object on every call. Instead, ``torch.cuda.set_stream`` is
    patched to remember the stream per thread, and this function returns
    that remembered stream.

    If no stream has been recorded for the calling thread yet:
    - on ROCm, a dedicated stream is created and set, because using the
      default stream 0 together with RCCL hurts performance;
    - on CPU, a no-op placeholder stream is stored;
    - otherwise, the platform's ``current_stream`` is queried.

    Assumes ``torch._C._cuda_setStream`` is never called directly from
    C/C++ code, which would bypass the tracking.

    Raises:
        ValueError: if the platform exposes no ``current_stream``.
    """
    from vllm.platforms import current_platform

    recorded = getattr(_current_stream_tls, "value", None)
    if recorded is not None:
        return recorded

    if current_platform.is_rocm():
        # `torch.cuda.set_stream` is the patched `_patched_set_stream` (unless
        # re-patched elsewhere), so this also records the new stream.
        torch.cuda.set_stream(torch.cuda.Stream())
    elif current_platform.is_cpu():
        _current_stream_tls.value = _StreamPlaceholder()
    else:
        platform_stream_fn = current_platform.current_stream
        if platform_stream_fn is None:
            raise ValueError(
                "Fail to set current stream, current platform "
                "may not support current_stream with torch API"
            )
        _current_stream_tls.value = platform_stream_fn()
    return _current_stream_tls.value
youkaichao's avatar
youkaichao committed
1493
1494


1495
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
    """Turn on function-call tracing for the calling thread.

    Does nothing unless ``VLLM_TRACE_FUNCTION`` is enabled. The trace log is
    written under ``<tmpdir>/<user>/vllm/vllm-instance-<instance_id>/``.
    Its file name encodes the process id, thread id and current timestamp,
    with spaces replaced by underscores.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    # A per-user subdirectory avoids permission issues in the shared tmp dir.
    user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
    log_name = (
        f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
        f"_thread_{threading.get_ident()}_"
        f"at_{datetime.datetime.now()}.log"
    ).replace(" ", "_")
    instance_dir = os.path.join(
        user_tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}"
    )
    log_path = os.path.join(instance_dir, log_name)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    enable_trace_function_call(log_path)
1514
1515


1516
# `functools` helpers
1517
1518
def identity(value: T, **kwargs) -> T:
    """Return ``value`` unchanged; any keyword arguments are ignored."""
    del kwargs  # unused
    return value


1522
# Type variable for an arbitrary callable. Decorator factories below use it
# (e.g. `Callable[[F], F]`) so the decorated function keeps its static type.
F = TypeVar("F", bound=Callable[..., Any])
1523
1524


1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
def deprecate_args(
    start_index: int,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator factory: warn when later parameters are passed positionally.

    Args:
        start_index: Index, among the positional-only and
            positional-or-keyword parameters, of the first parameter whose
            positional use is deprecated.
        is_deprecated: A bool, or a zero-argument callable evaluated on every
            call, that enables the warning.
        additional_message: Extra text appended to the warning.

    When enabled and any such parameter is passed positionally, a
    ``DeprecationWarning`` naming those parameters is emitted. The wrapped
    function is always called with the original arguments.
    """
    if callable(is_deprecated):
        check_deprecated = is_deprecated
    else:
        check_deprecated = partial(identity, is_deprecated)

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    def wrapper(fn: F) -> F:
        positional_names = [
            name
            for name, param in inspect.signature(fn).parameters.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            if check_deprecated():
                flagged = positional_names[start_index : len(args)]
                if flagged:
                    msg = (
                        f"The positional arguments {flagged} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper


1565
def deprecate_kwargs(
    *kws: str,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator factory: warn when any of the keyword arguments ``kws`` is used.

    Args:
        kws: Names of the deprecated keyword arguments.
        is_deprecated: A bool, or a zero-argument callable evaluated on every
            call, that enables the warning.
        additional_message: Extra text appended to the warning.

    When enabled and any deprecated keyword is passed, a
    ``DeprecationWarning`` listing them is emitted. The wrapped function is
    always called with the original arguments.
    """
    deprecated_kws = set(kws)

    if callable(is_deprecated):
        check_deprecated = is_deprecated
    else:
        check_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            if check_deprecated():
                used = kwargs.keys() & deprecated_kws
                if used:
                    parts = [
                        f"The keyword arguments {used} are "
                        "deprecated and will be removed in a future update."
                    ]
                    if additional_message is not None:
                        parts.append(additional_message)

                    warnings.warn(
                        DeprecationWarning(" ".join(parts)),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
1598
1599
1600


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
    """Count visible GPUs, preferring the NVML / amdsmi ("stateless") queries.

    ``cuda_visible_devices`` is not used by the body. It exists only so the
    LRU cache is keyed on the current ``CUDA_VISIBLE_DEVICES`` value.
    Returns 0 if torch was built without CUDA support.
    """
    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    from vllm.platforms import current_platform

    if not torch.cuda._is_compiled():
        return 0

    if not current_platform.is_rocm():
        raw_count = torch.cuda._device_count_nvml()
    elif hasattr(torch.cuda, "_device_count_amdsmi"):
        # ROCm uses amdsmi instead of NVML for a stateless device count;
        # this requires a sufficiently modern torch (2.4.0+).
        raw_count = torch.cuda._device_count_amdsmi()
    else:
        raw_count = -1

    # A negative count (the sentinel above, or presumably a failed NVML
    # query) falls back to torch's regular device-count call.
    if raw_count < 0:
        return torch._C._cuda_getDeviceCount()
    return raw_count


def cuda_device_count_stateless() -> int:
    """Return the number of CUDA devices, cached per ``CUDA_VISIBLE_DEVICES``.

    The cached result is keyed on the variable's value at call time, so a
    later change to ``CUDA_VISIBLE_DEVICES`` is picked up. Use this instead
    of ``torch.cuda.device_count()`` unless ``CUDA_VISIBLE_DEVICES`` has
    already been set to the desired value.
    """
    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    visible_devices = envs.CUDA_VISIBLE_DEVICES
    return _cuda_device_count_stateless(visible_devices)
1641
1642


1643
1644
1645
1646
1647
1648
1649
def cuda_is_initialized() -> bool:
    """Return ``True`` if torch was built with CUDA and CUDA is initialized."""
    return torch.cuda._is_compiled() and torch.cuda.is_initialized()


1650
1651
1652
1653
1654
1655
1656
def xpu_is_initialized() -> bool:
    """Check if XPU is initialized."""
    if not torch.xpu._is_compiled():
        return False
    return torch.xpu.is_initialized()


1657
1658
1659
def cuda_get_device_properties(
    device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
    """Return the values of the named properties of CUDA ``device``.

    If CUDA is already initialized, or ``init_cuda`` is true, the properties
    are read in this process. Otherwise the lookup runs in a forked
    single-worker subprocess, so this process does not initialize CUDA as a
    side effect.
    """
    if not (init_cuda or cuda_is_initialized()):
        fork_ctx = multiprocessing.get_context("fork")
        with ProcessPoolExecutor(max_workers=1, mp_context=fork_ctx) as pool:
            future = pool.submit(cuda_get_device_properties, device, names, True)
            return future.result()

    props = torch.cuda.get_device_properties(device)
    return tuple(getattr(props, prop_name) for prop_name in names)
1670
1671


1672
1673
1674
def weak_bind(
    bound_method: Callable[..., Any],
) -> Callable[..., None]:
    """Make an instance method that weakly references
    its associated instance and no-ops once that
    instance is collected.

    Args:
        bound_method: A bound method (``obj.method``). Only a weak reference
            to ``obj`` is kept, so the returned callable does not keep ``obj``
            alive.

    Returns:
        A callable that forwards its arguments to the method while the
        instance is alive and does nothing afterwards. It always returns
        ``None``.
    """
    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
    unbound = bound_method.__func__  # type: ignore[attr-defined]

    def weak_bound(*args, **kwargs) -> None:
        # Compare against None rather than testing truthiness: a live instance
        # that defines __bool__/__len__ (e.g. an empty container) can be falsy
        # and must still receive the call.
        inst = ref()
        if inst is not None:
            unbound(inst, *args, **kwargs)

    return weak_bound


1688
1689
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Decorator ensuring ``f`` runs at most once, even across threads.

    The first call runs ``f``. Later calls, including concurrent calls that
    were waiting on the lock, return ``None`` without calling it. The
    "has run" flag is set before ``f`` executes, so a first call that raises
    is not retried. ``f``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``, ...) is copied onto the wrapper with
    ``functools.wraps``.
    """

    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path without taking the lock once `f` has run.
        if wrapper.has_run:  # type: ignore[attr-defined]
            return

        with wrapper.lock:  # type: ignore[attr-defined]
            # Re-check under the lock so only one thread ever runs `f`.
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    return wrapper
1701
1702


1703
class StoreBoolean(Action):
    """argparse action that parses ``"true"`` / ``"false"``
    (case-insensitive) into a ``bool`` on the namespace."""

    def __call__(self, parser, namespace, values, option_string=None):
        normalized = values.lower()
        if normalized not in ("true", "false"):
            raise ValueError(
                f"Invalid boolean value: {values}. Expected 'true' or 'false'."
            )
        setattr(namespace, self.dest, normalized == "true")
1713
1714


1715
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
    """Help formatter that lists arguments sorted by their option strings
    and reflows help text paragraph by paragraph."""

    def _split_lines(self, text, width):
        """
        Reflow ``text`` to ``width`` columns:

        1. A single newline inside a sentence becomes a space.
        2. Runs of two or more newlines separate paragraphs.
        3. Each paragraph is wrapped independently.
        """
        # Both patterns also consume the whitespace following the newline(s).
        joined = re.sub(r"(?<!\n)\n(?!\n)\s*", " ", text)
        paragraphs = re.split(r"\n{2,}\s*", joined)
        return [
            wrapped_line
            for paragraph in paragraphs
            for wrapped_line in textwrap.wrap(paragraph, width)
        ]

    def add_arguments(self, actions):
        ordered = sorted(actions, key=lambda action: action.option_strings)
        super().add_arguments(ordered)
1734
1735


1736
class FlexibleArgumentParser(ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names.

    On top of the stock `ArgumentParser` it also:

    - supports `--help=<keyword>` to show only a matching group or the
      arguments whose option strings contain the keyword (see `format_help`);
    - merges dotted / `+` arguments such as `--json-arg.key value` into a
      single JSON value per top-level argument (see `parse_args`);
    - splices in arguments read from a YAML file given via `--config`;
    - accepts a `deprecated=` kwarg in `add_argument` on Python < 3.13.
    """

    # Actions registered with `deprecated=True` by the `add_argument` shims
    # below. This is a class attribute, so it is shared by every instance.
    _deprecated: set[Action] = set()
    _json_tip: str = (
        "When passing JSON CLI arguments, the following sets of arguments "
        "are equivalent:\n"
        '   --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
        "   --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
        "Additionally, list elements can be passed individually using +:\n"
        '   --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
        "   --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
    )
    # Set (class-wide) by `parse_args` when `--help=<keyword>` is passed.
    _search_keyword: Optional[str] = None

    def __init__(self, *args, **kwargs):
        # Set the default "formatter_class" to SortedHelpFormatter
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = SortedHelpFormatter
        # Pop kwarg "add_json_tip" to control whether to add the JSON tip
        # to `--help=<keyword>` output (see `format_help`).
        self.add_json_tip = kwargs.pop("add_json_tip", True)
        super().__init__(*args, **kwargs)

    if sys.version_info < (3, 13):
        # Enable the deprecated kwarg for Python 3.12 and below

        def parse_known_args(self, args=None, namespace=None):
            if args is not None and "--disable-log-requests" in args:
                # Special-case warning: the generic check below only fires
                # when the parsed value differs from the default, which is
                # not the case for --disable-log-requests.
                logger.warning_once(
                    "argument '--disable-log-requests' is deprecated and "
                    "replaced with '--enable-log-requests'. This will be "
                    "removed in v0.12.0."
                )
            namespace, args = super().parse_known_args(args, namespace)
            for action in FlexibleArgumentParser._deprecated:
                if (
                    hasattr(namespace, dest := action.dest)
                    and getattr(namespace, dest) != action.default
                ):
                    logger.warning_once("argument '%s' is deprecated", dest)
            return namespace, args

        def add_argument(self, *args, **kwargs):
            deprecated = kwargs.pop("deprecated", False)
            action = super().add_argument(*args, **kwargs)
            if deprecated:
                FlexibleArgumentParser._deprecated.add(action)
            return action

        class _FlexibleArgumentGroup(_ArgumentGroup):
            def add_argument(self, *args, **kwargs):
                deprecated = kwargs.pop("deprecated", False)
                action = super().add_argument(*args, **kwargs)
                if deprecated:
                    FlexibleArgumentParser._deprecated.add(action)
                return action

        def add_argument_group(self, *args, **kwargs):
            group = self._FlexibleArgumentGroup(self, *args, **kwargs)
            self._action_groups.append(group)
            return group

    def format_help(self):
        """Build the help text, honouring a `--help=<keyword>` search.

        - Parsers with subparsers use the stock argparse help.
        - `--help=all` shows the full help.
        - `--help=<group title>` shows that group only.
        - `--help=<text>` shows every argument whose option string contains
          the text.
        - Plain `--help` lists the usage, the description, and one summary
          line per non-empty group.

        The JSON tip is appended to search results only when the parser was
        created with `add_json_tip=True` (the default).
        """
        # Only use custom help formatting for bottom level parsers
        if self._subparsers is not None:
            return super().format_help()

        formatter = self._get_formatter()

        # Handle keyword search of the args
        if (search_keyword := self._search_keyword) is not None:
            # Normalise the search keyword
            search_keyword = search_keyword.lower().replace("_", "-")
            # Return full help if searching for 'all'
            if search_keyword == "all":
                if self.add_json_tip:
                    self.epilog = self._json_tip
                return super().format_help()

            # Return group help if searching for a group title
            for group in self._action_groups:
                if group.title and group.title.lower() == search_keyword:
                    formatter.start_section(group.title)
                    formatter.add_text(group.description)
                    formatter.add_arguments(group._group_actions)
                    formatter.end_section()
                    if self.add_json_tip:
                        formatter.add_text(self._json_tip)
                    return formatter.format_help()

            # Return matched args if searching for an arg name
            matched_actions = []
            for group in self._action_groups:
                for action in group._group_actions:
                    # search option name
                    if any(
                        search_keyword in opt.lower() for opt in action.option_strings
                    ):
                        matched_actions.append(action)
            if matched_actions:
                formatter.start_section(f"Arguments matching '{search_keyword}'")
                formatter.add_arguments(matched_actions)
                formatter.end_section()
                if self.add_json_tip:
                    formatter.add_text(self._json_tip)
                return formatter.format_help()

            # No match found
            formatter.add_text(
                f"No group or arguments matching '{search_keyword}'.\n"
                "Use '--help' to see available groups or "
                "'--help=all' to see all available parameters."
            )
            return formatter.format_help()

        # usage
        formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)

        # description
        formatter.add_text(self.description)

        # positionals, optionals and user-defined groups
        formatter.start_section("Config Groups")
        config_groups = ""
        for group in self._action_groups:
            if not group._group_actions:
                continue
            title = group.title
            description = group.description or ""
            config_groups += f"{title: <24}{description}\n"
        formatter.add_text(config_groups)
        formatter.end_section()

        # epilog
        formatter.add_text(self.epilog)

        # determine help from format above
        return formatter.format_help()

    def parse_args(  # type: ignore[override]
        self,
        args: list[str] | None = None,
        namespace: Namespace | None = None,
    ):
        """Pre-process `args`, then delegate to `ArgumentParser.parse_args`.

        The steps, in order:

        1. For `serve`, move a `--model <tag>` / `--model=<tag>` option into
           the positional slot right after the subcommand, with a warning.
        2. Splice in arguments from a `--config` YAML file.
        3. Normalise option names: underscores become dashes up to the first
           `.`. Record `--help=<keyword>` searches and expand the `-O`
           shorthands to `-O.level`.
        4. Merge dotted (`--arg.key value`) and list (`--arg.key+ v1,v2`)
           arguments into one JSON value per top-level argument.
        """
        if args is None:
            args = sys.argv[1:]

        # Check for --model in command line arguments first
        if args and args[0] == "serve":
            model_idx = next(
                (
                    i
                    for i, arg in enumerate(args)
                    if arg == "--model" or arg.startswith("--model=")
                ),
                None,
            )
            if model_idx is not None:
                logger.warning(
                    "With `vllm serve`, you should provide the model as a "
                    "positional argument or in a config file instead of via "
                    "the `--model` option. "
                    "The `--model` option will be removed in v0.13."
                )

                if args[model_idx] == "--model":
                    # A trailing bare `--model` has no value. Leave it in
                    # place so argparse reports the missing argument.
                    has_value = model_idx + 1 < len(args)
                    model_tag = args[model_idx + 1] if has_value else None
                    rest_start_idx = model_idx + 2
                else:
                    model_tag = args[model_idx].removeprefix("--model=")
                    rest_start_idx = model_idx + 1

                # Move <model> to the front, e.g.:
                # [Before]
                # vllm serve -tp 2 --model <model> --enforce-eager --port 8001
                # [After]
                # vllm serve <model> -tp 2 --enforce-eager --port 8001
                if model_tag is not None:
                    args = [
                        "serve",
                        model_tag,
                        *args[1:model_idx],
                        *args[rest_start_idx:],
                    ]

        if "--config" in args:
            args = self._pull_args_from_config(args)

        def repl(match: re.Match) -> str:
            """Replaces underscores with dashes in the matched string."""
            return match.group(0).replace("_", "-")

        # Everything between the first -- and the first .
        pattern = re.compile(r"(?<=--)[^\.]*")

        # Convert underscores to dashes in argument names (up to the first '.')
        processed_args = list[str]()
        for i, arg in enumerate(args):
            if arg.startswith("--help="):
                FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
                processed_args.append("--help")
            elif arg.startswith("--"):
                if "=" in arg:
                    key, value = arg.split("=", 1)
                    key = pattern.sub(repl, key, count=1)
                    processed_args.append(f"{key}={value}")
                else:
                    key = pattern.sub(repl, arg, count=1)
                    processed_args.append(key)
            elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
                # allow -O flag to be used without space, e.g. -O3 or -Odecode
                # -O.<...> handled later
                # also handle -O=<level> here
                level = arg[3:] if arg[2] == "=" else arg[2:]
                processed_args.append(f"-O.level={level}")
            elif (
                arg == "-O"
                and i + 1 < len(args)
                and args[i + 1] in {"0", "1", "2", "3"}
            ):
                # Convert -O <n> to -O.level <n>
                processed_args.append("-O.level")
            else:
                processed_args.append(arg)

        def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
            """Creates a nested dictionary from a list of keys and a value.

            For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
            `{"a": {"b": {"c": 1}}}`
            """
            nested_dict: Any = value
            for key in reversed(keys):
                nested_dict = {key: nested_dict}
            return nested_dict

        def recursive_dict_update(
            original: dict[str, Any],
            update: dict[str, Any],
        ) -> set[str]:
            """Recursively updates a dictionary with another dictionary.

            Lists under the same key are concatenated. Returns the set of
            dotted key paths whose non-dict, non-list values were overwritten.
            """
            duplicates = set[str]()
            for k, v in update.items():
                if isinstance(v, dict) and isinstance(original.get(k), dict):
                    nested_duplicates = recursive_dict_update(original[k], v)
                    duplicates |= {f"{k}.{d}" for d in nested_duplicates}
                elif isinstance(v, list) and isinstance(original.get(k), list):
                    original[k] += v
                else:
                    if k in original:
                        duplicates.add(k)
                    original[k] = v
            return duplicates

        delete = set[int]()
        dict_args = defaultdict[str, dict[str, Any]](dict)
        duplicates = set[str]()
        for i, processed_arg in enumerate(processed_args):
            if i in delete:  # skip if value from previous arg
                continue

            if processed_arg.startswith("-") and "." in processed_arg:
                if "=" in processed_arg:
                    processed_arg, value_str = processed_arg.split("=", 1)
                    if "." not in processed_arg:
                        # False positive, '.' was only in the value
                        continue
                else:
                    if i + 1 >= len(processed_args):
                        # Report a usage error rather than an IndexError.
                        self.error(f"argument {processed_arg}: expected one argument")
                    value_str = processed_args[i + 1]
                    delete.add(i + 1)

                if processed_arg.endswith("+"):
                    processed_arg = processed_arg[:-1]
                    value_str = json.dumps(list(value_str.split(",")))

                key, *keys = processed_arg.split(".")
                try:
                    value = json.loads(value_str)
                except json.decoder.JSONDecodeError:
                    value = value_str

                # Merge all values with the same key into a single dict
                arg_dict = create_nested_dict(keys, value)
                arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
                duplicates |= {f"{key}.{d}" for d in arg_duplicates}
                delete.add(i)
        # Drop the dotted args (and their values) that were merged above
        processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
        if duplicates:
            # Sorted so the warning is deterministic across runs.
            logger.warning("Found duplicate keys %s", ", ".join(sorted(duplicates)))

        # Add the dict args back as if they were originally passed as JSON
        for dict_arg, dict_value in dict_args.items():
            processed_args.append(dict_arg)
            processed_args.append(json.dumps(dict_value))

        return super().parse_args(processed_args, namespace)

    def check_port(self, value):
        """Convert `value` to an int port in [1024, 65535].

        Raises:
            ArgumentTypeError: if `value` is not an integer or out of range.
        """
        try:
            value = int(value)
        except ValueError:
            msg = "Port must be an integer"
            raise ArgumentTypeError(msg) from None

        if not (1024 <= value <= 65535):
            raise ArgumentTypeError("Port must be between 1024 and 65535")

        return value

    def _pull_args_from_config(self, args: list[str]) -> list[str]:
        """Method to pull arguments specified in the config file
        into the command-line args variable.

        The arguments in config file will be inserted between
        the argument list.

        example:
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        ```python
        $: vllm {serve,chat,complete} "facebook/opt-12B" \
            --config config.yaml -tp 2
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--config', 'config.yaml',
            '-tp', '2'
        ]
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
            '-tp', '2'
            ]
        ```

        Please note how the config args are inserted after the sub command.
        This way the order of priorities is maintained when these args are
        parsed by super().

        Raises:
            ValueError: if `--config` is the last argument, or if `serve` has
                no model either positionally or in the config file.
        """
        assert args.count("--config") <= 1, "More than one config file specified!"

        index = args.index("--config")
        if index == len(args) - 1:
            raise ValueError(
                "No config file specified! "
                "Please check your command-line arguments."
            )

        file_path = args[index + 1]

        config_args = self.load_config_file(file_path)

        # 0th index might be the sub command {serve,chat,complete,...}
        # optionally followed by model_tag (only for serve)
        # followed by config args
        # followed by rest of cli args.
        # maintaining this order will enforce the precedence
        # of cli > config > defaults
        if args[0].startswith("-"):
            # No sub command (e.g., api_server entry point)
            args = config_args + args[0:index] + args[index + 2 :]
        elif args[0] == "serve":
            model_in_cli = len(args) > 1 and not args[1].startswith("-")
            model_in_config = any(arg == "--model" for arg in config_args)

            if not model_in_cli and not model_in_config:
                raise ValueError(
                    "No model specified! Please specify model either "
                    "as a positional argument or in a config file."
                )

            if model_in_cli:
                # Model specified as positional arg, keep CLI version
                args = (
                    [args[0]]
                    + [args[1]]
                    + config_args
                    + args[2:index]
                    + args[index + 2 :]
                )
            else:
                # No model in CLI, use config if available
                args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
        else:
            args = [args[0]] + config_args + args[1:index] + args[index + 2 :]

        return args

    def load_config_file(self, file_path: str) -> list[str]:
        """Loads a yaml file and returns the key value pairs as a
        flattened list with argparse like pattern
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        returns:
            processed_args: list[str] = [
                '--port', '12323',
                '--tensor-parallel-size', '4'
            ]

        Booleans become bare flags: `--key` when true, omitted when false.
        The exception is a key naming a `StoreBoolean` argument, whose value
        is passed as text. Lists become `--key item1 item2 ...` and are
        omitted when empty. An empty file yields no arguments.

        Raises:
            ValueError: if the extension is not yaml/yml, or the document is
                not a mapping.
            Exception: any error from opening/parsing the file is logged and
                re-raised unchanged.
        """
        extension: str = file_path.split(".")[-1]
        if extension not in ("yaml", "yml"):
            raise ValueError(
                f"Config file must be of a yaml/yml type. {extension} supplied"
            )

        # only expecting a flat dictionary of atomic types
        processed_args: list[str] = []

        config: Any = {}
        try:
            with open(file_path) as config_file:
                config = yaml.safe_load(config_file)
        except Exception:
            logger.error(
                "Unable to read the config file at %s. Make sure path is correct",
                file_path,
            )
            raise

        # An empty YAML document loads as None: treat it as "no arguments".
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {file_path} must contain a mapping of argument "
                f"names to values, got {type(config).__name__}"
            )

        # Compared against `action.dest`, which argparse derives by turning
        # the dashes of `--foo-bar` into underscores (`foo_bar`).
        store_boolean_arguments = {
            action.dest for action in self._actions if isinstance(action, StoreBoolean)
        }

        for key, value in config.items():
            is_store_boolean = str(key).replace("-", "_") in store_boolean_arguments
            if isinstance(value, bool) and not is_store_boolean:
                if value:
                    processed_args.append("--" + key)
            elif isinstance(value, list):
                if value:
                    processed_args.append("--" + key)
                    processed_args.extend(str(item) for item in value)
            else:
                processed_args.append("--" + key)
                processed_args.append(str(value))

        return processed_args

2184

2185
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
2186
2187
2188
    """Utility function to run async task in a lock"""
    async with lock:
        return await task(*args, **kwargs)
2189
2190


2191
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check if a keyword is a valid kwarg for a callable; if requires_kw_only
    disallows kwargs names that can also be positional arguments.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: The keyword name to test.
        requires_kw_only: Only accept `kw_name` if it names a keyword-only
            parameter, or if it can be absorbed by a trailing `**kwargs`
            (subject to `allow_var_kwargs`).
        allow_var_kwargs: Treat a name absorbed by a trailing `**kwargs`
            parameter as supported (unless it is that parameter's own name).

    Returns:
        True if `callable(**{kw_name: value})` would bind `kw_name`.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param_val = params.get(kw_name)

    if param_val is not None:
        kind = param_val.kind
        if requires_kw_only:
            # Only a true keyword-only parameter qualifies. Anything that can
            # also be bound positionally is rejected outright.
            if kind == inspect.Parameter.KEYWORD_ONLY:
                return True
            if kind in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                return False
        elif kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
        # Remaining cases cannot bind `kw_name` to this parameter by keyword:
        # positional-only parameters (`f(x=1)` raises TypeError for
        # `def f(x, /)`) and the names of *args/**kwargs themselves. At most
        # the value can land in a trailing **kwargs, which is checked below.

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False


2246
2247
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Optional[Mapping[str, object]],
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
    mapping param names to values, drop values that cannot be kwarg
    expanded to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable whose signature is inspected (via `supports_kw`)
            to decide which override names it accepts.
        overrides: Potential overrides to be used when invoking the callable.
            `None` or an empty mapping yields an empty dict.
        requires_kw_only: Only keep names that are keyword-only parameters
            (or that fit a trailing `**kwargs` when `allow_var_kwargs`).
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable. A warning is logged listing any dropped names.
    """
    if not overrides:
        return {}

    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless they can fit into a var_kwargs param
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if supports_kw(
            callable,
            kwarg_name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        )
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        arg_kind = "keyword-only" if requires_kw_only else "keyword"
        logger.warning(
            "The following intended overrides are not %s args "
            "and will be dropped: %s",
            arg_kind,
            dropped_keys,
        )

    return filtered_overrides


2307
2308
2309
2310
2311
2312
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
    """Return True when the installed PyTorch release is at least 2.4.0.

    Pre-release and local suffixes are stripped first (via `base_version`),
    so e.g. a 2.4.0 dev/rc build compares as 2.4.0.
    """
    release_only = Version(torch.__version__).base_version
    minimum = Version("2.4.0")
    return Version(release_only) >= minimum
2313
2314


2315
# XCCL (XPU platform) requires PyTorch >= 2.8.0.dev.
def supports_xccl() -> bool:
    """Return True if PyTorch is >= 2.8.0.dev and reports XCCL as available.

    The version check runs first, so `torch.distributed.is_xccl_available`
    is only looked up on new-enough PyTorch. Presumably this is because older
    builds do not provide it; confirm.
    """
    return (
        is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
    )
2320
2321


2322
2323
2324
2325
2326
2327
# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    """Report whether `torch.library` exposes a `custom_op` attribute."""
    library = torch.library
    return hasattr(library, "custom_op")


2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
class AtomicCounter:
    """Thread-safe counter: every mutation happens while holding one lock."""

    def __init__(self, initial=0):
        """Create the counter, starting at `initial`."""
        self._lock = threading.Lock()
        self._value = initial

    def inc(self, num=1):
        """Add `num` under the lock and return the value just written."""
        with self._lock:
            self._value += num
            current = self._value
        return current

    def dec(self, num=1):
        """Subtract `num` under the lock and return the value just written."""
        with self._lock:
            self._value -= num
            current = self._value
        return current

    @property
    def value(self):
        """Current value (read without taking the lock)."""
        return self._value
2351
2352
2353


# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping[str, T], Generic[T]):
    """Mapping whose values are produced lazily by per-key factories.

    Each key maps to a zero-argument factory. The factory is called the first
    time its key is looked up, and the result is cached for later lookups.
    Iteration and `len` reflect the registered factories, not the values
    computed so far.
    """

    def __init__(self, factory: dict[str, Callable[[], T]]):
        # Stored by reference: `__setitem__` mutates the caller's dict.
        self._factory = factory
        self._dict: dict[str, T] = {}

    def __getitem__(self, key: str) -> T:
        if key not in self._dict:
            if key not in self._factory:
                raise KeyError(key)
            self._dict[key] = self._factory[key]()
        return self._dict[key]

    def __setitem__(self, key: str, value: Callable[[], T]):
        """Register (or replace) the factory for `key`."""
        self._factory[key] = value
        # Discard any value cached from a previous factory. Otherwise later
        # lookups would keep returning the stale value.
        self._dict.pop(key, None)

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
2374
2375


2376
2377
class ClassRegistry(UserDict[type[T], _V]):
    """Dict keyed by classes, where lookups fall back along the key's MRO.

    `registry[cls]` returns the value registered for the first class in
    `cls.__mro__` that has an entry. This means a subclass inherits the
    entry of its closest registered base.
    """

    def __getitem__(self, key: type[T]) -> _V:
        # `__mro__` rather than `key.mro()`: on a metaclass key such as
        # `type`, `key.mro` is an unbound method and calling it raises
        # TypeError.
        for cls in key.__mro__:
            if cls in self.data:
                return self.data[cls]

        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        return self.contains(key)

    def contains(self, key: object, *, strict: bool = False) -> bool:
        """Return whether `key` (a class) resolves to a registered entry.

        Non-class keys are never contained. With `strict=True` only an exact
        registration counts; otherwise any class in the key's MRO does.
        """
        if not isinstance(key, type):
            return False

        if strict:
            return key in self.data

        return any(cls in self.data for cls in key.__mro__)


2397
def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.

    Non-tensor inputs are returned unchanged (the same object).
    """
    if isinstance(tensor, torch.Tensor):
        # Custom op from vLLM's compiled `_C` extension.
        return torch.ops._C.weak_ref_tensor(tensor)
    return tensor
2407
2408
2409


def weak_ref_tensors(
    tensors: Union[
        torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], IntermediateTensors
    ],
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.

    `IntermediateTensors` is also supported: a new instance is built whose
    `.tensors` values are weak references.

    Raises:
        ValueError: if `tensors` is none of the supported types.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)

    # For IntermediateTensors used in pipeline parallelism.
    # Imported locally, presumably to avoid an import cycle with
    # vllm.sequence; confirm before hoisting.
    from vllm.sequence import IntermediateTensors

    if isinstance(tensors, IntermediateTensors):
        return IntermediateTensors(
            {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()}
        )
    raise ValueError("Invalid type for tensors")
2434
2435


2436
2437
2438
2439
2440
2441
2442
2443
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
    """
    Return a CUDA view of `cpu_tensor` via Unified Virtual Addressing (UVA).

    The input must be in pinned (page-locked) host memory. This precondition
    is checked with an `assert`, which is skipped under `python -O`.
    """
    assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
    view_op = torch.ops._C.get_cuda_view_from_cpu_tensor
    return view_op(cpu_tensor)


2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
    """
    Import a Python file according to its file path.

    Based on the official recipe:
    https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly

    The module is registered in `sys.modules` under `module_name` before it
    executes, as in the recipe. If executing it raises, that registration is
    rolled back: any module previously registered under the name is
    restored, otherwise the entry is removed. This way a half-initialised
    module is never left behind.

    Raises:
        ModuleNotFoundError: if no import spec can be built for `file_path`.
        BaseException: whatever executing the module raises, re-raised as-is.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ModuleNotFoundError(f"No module named '{module_name}'")

    assert spec.loader is not None

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module


2463
@cache
def get_vllm_optional_dependencies():
    """Map each extra of the installed `vllm` distribution to package names.

    Reads `Provides-Extra` and `Requires-Dist` from the distribution
    metadata. Each requirement whose marker ends with `extra == "<name>"` is
    reduced to its leading name part: the text before the first `;`, version
    operator (`<`, `>`, `=`, `!`, `~`) or `@`, with whitespace stripped. The
    result is cached for the process lifetime.

    Raises:
        importlib.metadata.PackageNotFoundError: if `vllm` is not installed.
    """
    metadata = importlib.metadata.metadata("vllm")
    requirements = metadata.get_all("Requires-Dist", [])
    extras = metadata.get_all("Provides-Extra", [])

    return {
        extra: [
            # None of these characters can occur in a PEP 508 project name.
            re.split(r"[;<>=!~@]", req, maxsplit=1)[0].strip()
            for req in requirements
            if req.endswith(f'extra == "{extra}"')
        ]
        for extra in extras
    }


2479
2480
2481
2482
2483
class _PlaceholderBase:
    """
    Disallows downstream usage of placeholder modules.

    We need to explicitly override each dunder method because
    [`__getattr__`][vllm.utils._PlaceholderBase.__getattr__]
    is not called when they are accessed.

    Every override below just forwards to `__getattr__` with the dunder's
    own name. A subclass therefore only has to implement `__getattr__` to
    control the error raised for any kind of use: comparison, hashing,
    truth testing, calling, indexing, arithmetic or `with`.

    Info:
        [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
    """

    def __getattr__(self, key: str) -> Never:
        """
        The main class should implement this to throw an error
        for attribute accesses representing downstream usage.
        """
        raise NotImplementedError

    # [Basic customization]

    def __lt__(self, other: object):
        return self.__getattr__("__lt__")

    def __le__(self, other: object):
        return self.__getattr__("__le__")

    def __eq__(self, other: object):
        return self.__getattr__("__eq__")

    def __ne__(self, other: object):
        return self.__getattr__("__ne__")

    def __gt__(self, other: object):
        return self.__getattr__("__gt__")

    def __ge__(self, other: object):
        return self.__getattr__("__ge__")

    # Overriding __eq__ alone would implicitly set __hash__ to None. This
    # explicit override keeps hashing routed through __getattr__ too.
    def __hash__(self):
        return self.__getattr__("__hash__")

    def __bool__(self):
        return self.__getattr__("__bool__")

    # [Callable objects]

    def __call__(self, *args: object, **kwargs: object):
        return self.__getattr__("__call__")

    # [Container types]

    def __len__(self):
        return self.__getattr__("__len__")

    def __getitem__(self, key: object):
        return self.__getattr__("__getitem__")

    def __setitem__(self, key: object, value: object):
        return self.__getattr__("__setitem__")

    def __delitem__(self, key: object):
        return self.__getattr__("__delitem__")

    # __missing__ is optional according to __getitem__ specification,
    # so it is skipped

    # __iter__ and __reversed__ are skipped because of their fallbacks:
    # iter() uses the __getitem__ sequence protocol, and reversed() uses
    # __len__ + __getitem__. Both of those are overridden above.

    # [Numeric Types]

    def __add__(self, other: object):
        return self.__getattr__("__add__")

    def __sub__(self, other: object):
        return self.__getattr__("__sub__")

    def __mul__(self, other: object):
        return self.__getattr__("__mul__")

    def __matmul__(self, other: object):
        return self.__getattr__("__matmul__")

    def __truediv__(self, other: object):
        return self.__getattr__("__truediv__")

    def __floordiv__(self, other: object):
        return self.__getattr__("__floordiv__")

    def __mod__(self, other: object):
        return self.__getattr__("__mod__")

    def __divmod__(self, other: object):
        return self.__getattr__("__divmod__")

    def __pow__(self, other: object, modulo: object = ...):
        return self.__getattr__("__pow__")

    def __lshift__(self, other: object):
        return self.__getattr__("__lshift__")

    def __rshift__(self, other: object):
        return self.__getattr__("__rshift__")

    def __and__(self, other: object):
        return self.__getattr__("__and__")

    def __xor__(self, other: object):
        return self.__getattr__("__xor__")

    def __or__(self, other: object):
        return self.__getattr__("__or__")

    # Reflected (r*) and in-place (i*) variants are skipped. In-place
    # operators fall back to the binary methods above, and the left
    # operand's method is always tried before a reflected one.

    def __neg__(self):
        return self.__getattr__("__neg__")

    def __pos__(self):
        return self.__getattr__("__pos__")

    def __abs__(self):
        return self.__getattr__("__abs__")

    def __invert__(self):
        return self.__getattr__("__invert__")

    # __complex__, __int__ and __float__ fall back to __index__
    # (Python 3.8+), so they are skipped.

    def __index__(self):
        return self.__getattr__("__index__")

    def __round__(self, ndigits: object = ...):
        return self.__getattr__("__round__")

    def __trunc__(self):
        return self.__getattr__("__trunc__")

    def __floor__(self):
        return self.__getattr__("__floor__")

    def __ceil__(self):
        return self.__getattr__("__ceil__")

    # [Context managers]

    def __enter__(self):
        return self.__getattr__("__enter__")

    def __exit__(self, *args: object, **kwargs: object):
        return self.__getattr__("__exit__")


class PlaceholderModule(_PlaceholderBase):
    """
    A placeholder object to use when a module does not exist.

    This enables more informative errors when trying to access attributes
    of a module that does not exist. Any attribute access re-attempts the
    import. If the import still fails and the module belongs to one of
    vLLM's optional-dependency extras, the error tells the user which
    ``vllm[extra]`` to install; otherwise the original ``ImportError``
    propagates.
    """

    def __init__(self, name: str) -> None:
        super().__init__()

        # Apply name mangling to avoid conflicting with module attributes
        self.__name = name

    def placeholder_attr(self, attr_path: str):
        """Return a placeholder standing in for ``<module>.<attr_path>``."""
        return _PlaceholderModuleAttr(self, attr_path)

    def __getattr__(self, key: str):
        name = self.__name

        try:
            importlib.import_module(name)
        except ImportError as exc:
            try:
                optional_deps = get_vllm_optional_dependencies()
            except importlib.metadata.PackageNotFoundError:
                # vllm distribution metadata is unavailable. Do not let that
                # secondary failure mask the real import error.
                optional_deps = {}

            for extra, names in optional_deps.items():
                if name in names:
                    msg = f"Please install vllm[{extra}] for {extra} support"
                    raise ImportError(msg) from exc

            raise exc

        raise AssertionError(
            "PlaceholderModule should not be used "
            "when the original module can be imported"
        )
2669
2670


2671
2672
2673
2674
2675
2676
2677
class _PlaceholderModuleAttr(_PlaceholderBase):
    """
    Placeholder for a (possibly dotted) attribute path of a missing module.

    Real use of the object is redirected to the owning ``PlaceholderModule``
    so that the module's import error surfaces.
    """

    def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
        super().__init__()

        # Name-mangled so these fields cannot collide with the attribute
        # names callers look up on the placeholder.
        self.__module = module
        self.__attr_path = attr_path

    def placeholder_attr(self, attr_path: str):
        """Return a placeholder for a deeper attribute below this one."""
        nested_path = f"{self.__attr_path}.{attr_path}"
        return _PlaceholderModuleAttr(self.__module, nested_path)

    def __getattr__(self, key: str):
        # Looking the full path up on the owning placeholder re-attempts the
        # import, which is expected to raise before the next statement.
        full_path = f"{self.__attr_path}.{key}"
        getattr(self.__module, full_path)

        raise AssertionError(
            "PlaceholderModule should not be used "
            "when the original module can be imported"
        )
2689
2690


2691
2692
2693
2694
2695
# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT")  # noqa


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Optional[list[str]] = None,
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: Optional[str] = None,
    tags: tuple[torch.Tag, ...] = (),
):
    """
    Register ``op_func`` as operator ``op_name`` directly on a
    ``torch.library.Library``, bypassing ``torch.library.custom_op``.

    ``torch.library.custom_op`` can have significant overhead because it
    needs to consider complicated dispatching logic. Registering directly
    dispatches the op to a single backend. See
    https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    Args:
        op_name: Operator name inside the library namespace.
        op_func: Implementation; its signature is used to infer the schema.
        mutates_args: Names of arguments mutated in place (default: none).
        fake_impl: Optional fake implementation registered alongside.
        target_lib: Library to register into. Defaults to the module-level
            vLLM library; pass another library object to register there.
        dispatch_key: Dispatch key for ``op_func``. Defaults to the current
            platform's dispatch key.
        tags: Operator tags forwarded to ``Library.define``.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    If ``supports_custom_op()`` is false nothing is registered. On CUDA-like
    platforms that situation is an assertion failure.
    """
    if not supports_custom_op():
        from vllm.platforms import current_platform

        # Skipping registration is only tolerable off CUDA-like platforms.
        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies."
        )
        return

    mutated = [] if mutates_args is None else mutates_args

    if dispatch_key is None:
        from vllm.platforms import current_platform

        dispatch_key = current_platform.dispatch_key

    import torch.library

    if not hasattr(torch.library, "infer_schema"):
        # pytorch 2.4 only exposes schema inference in a private module
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutated)
    else:
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutated)

    lib = target_lib or vllm_lib
    lib.define(op_name + schema_str, tags=tags)
    lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is None:
        return
    lib._register_fake(op_name, fake_impl)
2753
2754
2755
2756


def resolve_obj_by_qualname(qualname: str) -> Any:
    """
    Resolve an object by its fully-qualified name.

    Everything before the last dot is imported as a module, and the final
    component is looked up on it, e.g. ``"os.path.join"``.

    Args:
        qualname: Dotted path of the form ``"package.module.attr"``.

    Returns:
        The resolved object.

    Raises:
        ValueError: if ``qualname`` lacks a module part or an attribute part.
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module_name, sep, obj_name = qualname.rpartition(".")
    if not (sep and module_name and obj_name):
        raise ValueError(
            f"Expected a fully-qualified name like 'module.attr', got {qualname!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786


def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid, and then the pid
    itself, by sending SIGKILL.

    Processes that are already gone, or that exit while this runs, are
    silently skipped.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
        # Get all children recursively. This can also raise NoSuchProcess
        # if the parent exits between the lookup above and this call.
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    # Send SIGKILL to all children first
    for child in children:
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)
2787
2788
2789
2790
2791


@dataclass
class MemorySnapshot:
    """Point-in-time record of device memory usage (sizes in bytes).

    With ``auto_measure=True`` (the default) the fields are filled from the
    current device state on construction. Otherwise they keep the values
    passed in, which is how ``__sub__`` builds difference snapshots.
    """

    # Value of torch's "allocated_bytes.all.peak" counter.
    torch_peak: int = 0
    # From torch.cuda.mem_get_info(); free_memory is replaced by available
    # system memory on the unified-memory devices listed in measure().
    free_memory: int = 0
    total_memory: int = 0
    # total_memory - free_memory
    cuda_memory: int = 0
    # torch.cuda.memory_reserved()
    torch_memory: int = 0
    # cuda_memory - torch_memory
    non_torch_memory: int = 0
    # time.time() at measurement
    timestamp: float = 0.0
    auto_measure: bool = True

    def __post_init__(self):
        if not self.auto_measure:
            return
        self.measure()

    def measure(self):
        """Refresh every field from the current device state."""
        from vllm.platforms import current_platform

        # Peak usage is read from allocated bytes instead of
        # `torch.cuda.memory_reserved()`: after
        # `torch.cuda.reset_peak_memory_stats()` the reserved figure keeps
        # growing and only shrinks on `torch.cuda.empty_cache()` or an OOM.
        allocator_stats = torch.cuda.memory_stats()
        self.torch_peak = allocator_stats.get("allocated_bytes.all.peak", 0)

        free_bytes, total_bytes = torch.cuda.mem_get_info()
        self.free_memory = free_bytes
        self.total_memory = total_bytes

        uma_capabilities = ((8, 7), (11, 0), (12, 1))  # Orin, Thor, Spark
        on_uma_device = (
            current_platform.is_cuda()
            and current_platform.get_device_capability() in uma_capabilities
        )
        if on_uma_device:
            # On UMA (Orin, Thor and Spark) platforms CPU and GPU share
            # system memory. There cudaMemGetInfo / torch.cuda.mem_get_info()
            # reports only "free" memory, which can be lower than what is
            # actually available because it excludes cache memory. See
            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
            self.free_memory = psutil.virtual_memory().available

        self.cuda_memory = self.total_memory - self.free_memory

        # Bytes PyTorch obtained from CUDA (cudaMalloc etc.); whatever part
        # of cuda_memory it does not cover is attributed to non-torch users.
        self.torch_memory = torch.cuda.memory_reserved()
        self.non_torch_memory = self.cuda_memory - self.torch_memory

        self.timestamp = time.time()

    def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
        """Field-wise ``self - other`` as a new, un-measured snapshot."""
        numeric_fields = (
            "torch_peak",
            "free_memory",
            "total_memory",
            "cuda_memory",
            "torch_memory",
            "non_torch_memory",
            "timestamp",
        )
        deltas = {
            name: getattr(self, name) - getattr(other, name)
            for name in numeric_fields
        }
        return MemorySnapshot(**deltas, auto_measure=False)
2856
2857
2858
2859


@dataclass
class MemoryProfilingResult:
    """Outcome of a memory profiling run. All numbers are in bytes."""

    non_kv_cache_memory: int = 0
    torch_peak_increase: int = 0
    non_torch_increase: int = 0
    weights_memory: float = 0
    # NOTE: with the default factories each snapshot is constructed (and,
    # because MemorySnapshot auto-measures by default, measured) eagerly.
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    profile_time: float = 0.0

    def __repr__(self) -> str:
        def as_gib(num_bytes: float) -> str:
            return f"{(num_bytes / GiB_bytes):.2f}GiB"

        return (
            f"Memory profiling takes {self.profile_time:.2f} seconds. "
            f"Total non KV cache memory: {as_gib(self.non_kv_cache_memory)}; "
            f"torch peak memory increase: {as_gib(self.torch_peak_increase)}; "
            f"non-torch forward increase memory: "
            f"{as_gib(self.non_torch_increase)}; "
            f"weights memory: {as_gib(self.weights_memory)}."
        )
2882

2883
2884
2885

@contextlib.contextmanager
def memory_profiling(
    baseline_snapshot: MemorySnapshot, weights_memory: int
) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.

    baseline_snapshot: the memory snapshot before the current vLLM instance.
    weights_memory: memory used by PyTorch when loading the model weights.
        Note that, before loading the model weights, we also initialize the device
        and distributed environment, which may consume some memory. This part is not
        included in the weights_memory because PyTorch does not control it.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is given directly by the argument `weights_memory`.

    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).

    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).

    If the body of the `with` block raises, the exception propagates and the
    yielded result is left incomplete.
    """  # noqa
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Pass the snapshots explicitly. `before_create` comes from the caller
    # and the other two are measured explicitly below, so letting the
    # default factories measure all three would only add device queries
    # whose values are thrown away.
    result = MemoryProfilingResult(
        before_create=baseline_snapshot,
        before_profile=MemorySnapshot(auto_measure=False),
        after_profile=MemorySnapshot(auto_measure=False),
        # the part of memory used for holding the model weights
        weights_memory=weights_memory,
    )

    result.before_profile.measure()

    yield result

    gc.collect()
    torch.cuda.empty_cache()

    result.after_profile.measure()

    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp

    non_torch_memory = result.non_torch_increase
    peak_activation_memory = result.torch_peak_increase
    result.non_kv_cache_memory = (
        non_torch_memory + peak_activation_memory + result.weights_memory
    )  # noqa
2964
2965


2966
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
2967
def set_ulimit(target_soft_limit=65535):
    """
    Raise the soft limit on open file descriptors (RLIMIT_NOFILE) to at
    least ``target_soft_limit``, as far as the hard limit allows.

    The limit is never lowered. This is best-effort: it is skipped on
    Windows, and if the OS rejects the change a warning is logged instead
    of raising.

    Args:
        target_soft_limit: Desired minimum soft limit. A value above a finite
            hard limit is capped to the hard limit, because an unprivileged
            process cannot raise its soft limit beyond it.
    """
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft == resource.RLIM_INFINITY:
        # Already unlimited. RLIM_INFINITY can be reported as -1, so the
        # comparison below would otherwise *lower* the limit.
        return

    desired_soft = target_soft_limit
    if current_hard != resource.RLIM_INFINITY:
        desired_soft = min(desired_soft, current_hard)

    if current_soft < desired_soft:
        try:
            resource.setrlimit(resource_type, (desired_soft, current_hard))
        except (ValueError, OSError) as e:
            logger.warning(
                "Found ulimit of %s and failed to automatically increase "
                "with error %s. This can cause fd limit errors like "
                "`OSError: [Errno 24] Too many open files`. Consider "
                "increasing with ulimit -n",
                current_soft,
                e,
            )
2989
2990
2991
2992
2993
2994
2995
2996
2997


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
    """Return the formatted traceback of the exception currently being
    handled, as one string. Intended to be called from an ``except`` block."""
    return traceback.format_exc()


2998
def split_zmq_path(path: str) -> tuple[str, str, str]:
    """Split a ZMQ endpoint such as ``tcp://host:port`` into its parts.

    Returns:
        ``(scheme, host, port)``; ``host`` and ``port`` are ``""`` when absent.

    Raises:
        ValueError: if the scheme is missing, a ``tcp`` endpoint lacks a host
            or port, or a non-``tcp`` endpoint carries a port.
    """
    parts = urlparse(path)
    scheme = parts.scheme
    if not scheme:
        raise ValueError(f"Invalid zmq path: {path}")

    host = parts.hostname or ""
    port = str(parts.port or "")

    is_tcp = scheme == "tcp"
    # tcp needs both host and port; any other transport must not have a port.
    if (is_tcp and not (host and port)) or (not is_tcp and port):
        raise ValueError(f"Invalid zmq path: {path}")

    return scheme, host, port


3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str:
    """Build a ZMQ endpoint string from its parts.

    Args:
        scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
        host: The host - an IPv4 address, IPv6 address, or hostname (or a
            path / name for non-tcp transports).
        port: Optional port number, only used for TCP sockets.

    Returns:
        ``scheme://host`` without a port, otherwise ``scheme://host:port``
        with IPv6 hosts wrapped in brackets.
    """
    if port is None:
        return f"{scheme}://{host}"
    authority = f"[{host}]" if is_valid_ipv6_address(host) else host
    return f"{scheme}://{authority}:{port}"


3037
3038
3039
3040
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: Union[zmq.asyncio.Context, zmq.Context],  # type: ignore[name-defined]
    path: str,
    socket_type: Any,
    bind: Optional[bool] = None,
    identity: Optional[bytes] = None,
    linger: Optional[int] = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    Args:
        ctx: ZMQ context (sync or asyncio) that creates the socket.
        path: Endpoint such as ``tcp://host:port`` or ``ipc:///path``.
        socket_type: ZMQ socket type constant (``zmq.PUSH``, ``zmq.ROUTER``...).
        bind: ``True`` to bind, ``False`` to connect. When ``None``,
            PUSH/SUB/XSUB sockets connect and all other types bind.
        identity: Optional ``zmq.IDENTITY`` for the socket.
        linger: Optional ``zmq.LINGER`` value.

    Returns:
        The configured socket, already bound or connected.

    Raises:
        Whatever option setting, endpoint parsing, bind or connect raise.
        The socket is closed before the error propagates so it does not stay
        open in ``ctx``.
    """

    mem = psutil.virtual_memory()
    sock = ctx.socket(socket_type)

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * 1024**3)  # 0.5GB in bytes
    else:
        buf_size = -1  # Use system default buffer size

    if bind is None:
        bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)

    try:
        # A high-water mark of 0 removes the limit on queued messages.
        if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
            sock.setsockopt(zmq.RCVHWM, 0)
            sock.setsockopt(zmq.RCVBUF, buf_size)

        if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
            sock.setsockopt(zmq.SNDHWM, 0)
            sock.setsockopt(zmq.SNDBUF, buf_size)

        if identity is not None:
            sock.setsockopt(zmq.IDENTITY, identity)

        if linger is not None:
            sock.setsockopt(zmq.LINGER, linger)

        if socket_type == zmq.XPUB:
            sock.setsockopt(zmq.XPUB_VERBOSE, True)

        # Determine if the path is a TCP socket with an IPv6 address.
        # Enable IPv6 on the zmq socket if so.
        scheme, host, _ = split_zmq_path(path)
        if scheme == "tcp" and is_valid_ipv6_address(host):
            sock.setsockopt(zmq.IPV6, 1)

        if bind:
            sock.bind(path)
        else:
            sock.connect(path)
    except BaseException:
        # Do not leak a half-configured socket: an open socket would keep
        # its context from terminating.
        sock.close(linger=0)
        raise

    return sock


@contextlib.contextmanager
def zmq_socket_ctx(
    path: str,
    socket_type: Any,
    bind: Optional[bool] = None,
    linger: int = 0,
    identity: Optional[bytes] = None,
) -> Iterator[zmq.Socket]:
    """Context manager yielding a ZMQ socket on a private, fresh context.

    The context is destroyed on exit with the given ``linger``, which also
    closes the socket. A ``KeyboardInterrupt`` raised inside the block is
    logged at debug level and suppressed.
    """
    context = zmq.Context()  # type: ignore[attr-defined]
    try:
        sock = make_zmq_socket(context, path, socket_type, bind=bind, identity=identity)
        yield sock
    except KeyboardInterrupt:
        logger.debug("Got Keyboard Interrupt.")
    finally:
        context.destroy(linger=linger)
3115
3116


3117
3118
3119
3120
3121
3122
3123
def _maybe_force_spawn():
    """Force VLLM_WORKER_MULTIPROC_METHOD to ``spawn`` when forking is not
    an option.

    Spawn is required inside a Ray actor and once CUDA or XPU has been
    initialized in this process. If the variable is already ``spawn``, this
    does nothing. Otherwise, when any reason applies, a warning listing the
    reasons is logged and the environment variable is overwritten.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons: list[str] = []

    if is_in_ray_actor():
        # Spawned children still inherit environment variables, so export
        # the Ray address they need to connect to the cluster.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if not reasons:
        return

    logger.warning(
        "We must use the `spawn` multiprocessing start method. "
        "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
        "See https://docs.vllm.ai/en/latest/usage/"
        "troubleshooting.html#python-multiprocessing "
        "for more information. Reasons: %s",
        "; ".join(reasons),
    )
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    """Return the ``multiprocessing`` context used for vLLM workers.

    The start method follows VLLM_WORKER_MULTIPROC_METHOD (fork by default).
    `_maybe_force_spawn` runs first and may override it to ``spawn`` when
    the current process cannot safely fork.
    """
    # Must run before the environment setting is read below.
    _maybe_force_spawn()
    start_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(start_method)
3161
3162
3163


def bind_kv_cache(
    ctx: dict[str, Any],
    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
    shared_kv_cache_layers: Optional[dict[str, str]] = None,
) -> None:
    """
    Bind the kv_cache tensors to the Attention modules in ``ctx``, i.e.
    ``ctx[layer_name].kv_cache[ve] = kv_cache[ve][slot]``, where ``slot`` is
    the rank of the layer's index among all cache-owning layers.

    Special things handled here:
    1. Some models have non-attention layers, e.g., Jamba.
    2. Pipeline parallelism: each rank only has a subset of layers.
    3. Encoder attention has no kv cache.
    4. Encoder-decoder models: encoder-decoder attention and decoder-only
       attention of the same layer (e.g., bart's decoder.layers.1.self_attn
       and decoder.layers.1.encoder_attn) are mapped to the same kv cache
       tensor.
    5. Some models have attention layers that share kv cache with previous
       layers; this is specified through ``shared_kv_cache_layers``
       (layer name -> target layer name).

    Raises:
        AssertionError: if a layer's ``kv_cache`` list length differs from
            the number of virtual engines, or a layer shares the cache of a
            layer that is not earlier in the model.
    """
    if shared_kv_cache_layers is None:
        shared_kv_cache_layers = {}

    from vllm.attention import AttentionType
    from vllm.model_executor.models.utils import extract_layer_index

    layer_need_kv_cache = [
        layer_name
        for layer_name in ctx
        if (
            hasattr(ctx[layer_name], "attn_type")
            and ctx[layer_name].attn_type
            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
        )
        and ctx[layer_name].kv_sharing_target_layer_name is None
    ]
    # Cache-owning layers are packed densely in `kv_cache`: the i-th smallest
    # distinct layer index uses slot i. A dict lookup replaces the previous
    # per-layer `list.index` scan.
    sorted_layer_indices = sorted(
        {extract_layer_index(layer_name) for layer_name in layer_need_kv_cache}
    )
    slot_of_layer_index = {
        layer_idx: slot for slot, layer_idx in enumerate(sorted_layer_indices)
    }
    for layer_name in layer_need_kv_cache:
        kv_cache_idx = slot_of_layer_index[extract_layer_index(layer_name)]
        forward_ctx = ctx[layer_name]
        assert len(forward_ctx.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        assert extract_layer_index(target_layer_name) < extract_layer_index(
            layer_name
        ), "v0 doesn't support interleaving kv sharing"
        ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
3210
3211


3212
3213
3214
3215
3216
3217
def run_method(
    obj: Any,
    method: Union[str, bytes, Callable],
    args: tuple[Any],
    kwargs: dict[str, Any],
) -> Any:
    """
    Invoke ``method`` on ``obj`` with ``args`` and ``kwargs``.

    ``method`` may be given as:
    - a ``str``: looked up on ``obj`` with ``getattr``; a missing attribute
      is reported as ``NotImplementedError``;
    - ``bytes``: a cloudpickle-serialized function, deserialized and called
      with ``obj`` as its first argument;
    - a callable: called directly with ``obj`` as its first argument.
    """
    if isinstance(method, str):
        try:
            bound_method = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
        return bound_method(*args, **kwargs)

    fn = cloudpickle.loads(method) if isinstance(method, bytes) else method
    bound = partial(fn, obj)  # type: ignore
    return bound(*args, **kwargs)
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257


def import_pynvml():
    """
    Return vLLM's vendored copy of the official ``pynvml`` module.

    Background: libnvml.so is the library behind nvidia-smi, and pynvml is a
    Python wrapper around it. vLLM uses it to query GPU status without
    initializing a CUDA context in the current process.

    Two packages have historically provided a ``pynvml`` module:
    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): the official
      wrapper, a vLLM dependency that installs a standalone ``pynvml`` file.
    - `pynvml` (https://pypi.org/project/pynvml/): an unofficial wrapper.
      Before version 12.0 it also shipped a ``pynvml`` module. Because that
      module is a package, it takes priority over the official standalone
      file and breaks things when both are installed. From 12.0 on it
      moved to ``pynvml_utils`` to avoid the conflict.

    Many community images still ship the unofficial one by mistake (e.g.
    `nvcr.io/nvidia/pytorch:24.12-py3`, see
    https://github.com/vllm-project/vllm/issues/12847). To sidestep the
    problem, the official `pynvml` module is copied into vLLM's codebase
    and used directly.
    """
    return importlib.import_module("vllm.third_party.pynvml")
3269
3270


3271
def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
    """
    A lightweight replacement for `abc.ABC`.

    When we use `abc.ABC`, subclasses fail to instantiate
    if they do not implement all abstract methods.
    Here, we only require `raise NotImplementedError` in the
    base class. After each instance is constructed, a debug message lists
    the public methods whose source still mentions `NotImplementedError`.

    Args:
        cls: The class to decorate. Its `__init__` is replaced in place.

    Returns:
        The same class object, with the wrapped `__init__`.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object):
        unimplemented_methods = []
        for attr_name in dir(self):
            # Skip private and dunder attributes.
            if attr_name.startswith("_"):
                continue

            try:
                attr = getattr(self, attr_name)
            except AttributeError:
                continue

            # Only bound Python methods expose `__func__`. Plain attributes,
            # static methods and builtins are skipped. They must never reach
            # `inspect.getsource` with a function left over from an earlier
            # iteration.
            if not callable(attr):
                continue
            attr_func = getattr(attr, "__func__", None)
            if attr_func is None:
                continue

            try:
                src = inspect.getsource(attr_func)
            except (OSError, TypeError):
                # Source is unavailable (e.g. defined dynamically or in C).
                continue
            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)
        if unimplemented_methods:
            method_names = ",".join(unimplemented_methods)
            msg = f"Methods {method_names} not implemented in {self}"
            logger.debug(msg)

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    type.__setattr__(cls, "__init__", wrapped_init)
    return cls
3312
3313
3314
3315
3316
3317


class LazyLoader(types.ModuleType):
    """
    Defer importing a module until one of its attributes is first used.

    Adapted from TensorFlow's LazyLoader
    https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
    with an addition of "module caching".

    Useful for large optional dependencies: modules such as `xgrammar` may
    perform side effects on import, so importing is postponed until the
    module is actually needed.

    Args:
        local_name: Name under which the real module is published into
            `parent_module_globals` (and `sys.modules`) once loaded.
        parent_module_globals: Namespace dict that will receive the module.
        name: Fully qualified name passed to `importlib.import_module`.
    """

    def __init__(
        self,
        local_name: str,
        parent_module_globals: dict[str, Any],
        name: str,
    ):
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        self._module: types.ModuleType | None = None

        super().__init__(str(name))

    def _load(self) -> types.ModuleType:
        """Import the real module, publish it, and mirror its namespace."""
        try:
            loaded = importlib.import_module(self.__name__)
            # Replace this placeholder in the parent's namespace.
            self._parent_module_globals[self._local_name] = loaded
            # Also register the module in sys.modules under the local name so
            # the library is treated as loaded.
            sys.modules[self._local_name] = loaded
        except ModuleNotFoundError as exc:
            raise exc from None

        # Copy the module's attributes onto this proxy. Anyone still holding a
        # reference to the loader then finds them through normal attribute
        # lookup; `__getattr__` only runs when that lookup fails.
        self.__dict__.update(loaded.__dict__)
        return loaded

    def __getattr__(self, item: Any) -> Any:
        module = self._module
        if module is None:
            module = self._module = self._load()
        return getattr(module, item)

    def __dir__(self) -> list[str]:
        module = self._module
        if module is None:
            module = self._module = self._load()
        return dir(module)
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378


def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
    """
    Swap the values stored under `key1` and `key2` in `obj`, in place.

    A key that is missing, or maps to `None`, counts as "no value". Its
    counterpart is then removed from `obj` instead of being set to `None`.
    """
    first, second = obj.get(key1), obj.get(key2)
    # key2 receives key1's old value first, then key1 receives key2's.
    for target, value in ((key2, first), (key1, second)):
        if value is None:
            obj.pop(target, None)
        else:
            obj[target] = value
3379
3380
3381
3382
3383
3384
3385
3386


@contextlib.contextmanager
def cprofile_context(save_file: Optional[str] = None):
    """Profile the body of the `with` statement using cProfile.

    Args:
        save_file: Path where the raw profile data is dumped. If it is
            None, "" or "1", the stats are printed to stdout instead,
            sorted by cumulative time.
    """
    import cProfile

    profiler = cProfile.Profile()
    profiler.enable()

    try:
        yield
    finally:
        profiler.disable()
        # Only dump to disk for a real path; "1" means "just print".
        write_to_disk = bool(save_file) and save_file != "1"
        if write_to_disk:
            profiler.dump_stats(save_file)
        else:
            profiler.print_stats(sort="cumtime")


def cprofile(save_file: Optional[str] = None, enabled: bool = True):
    """Decorator that profiles each call of the wrapped function with cProfile.

    Args:
        save_file: Path to save the profile result.
            If "1", None, or "", results will be printed to stdout.
        enabled: When False, the wrapper simply forwards the call without
            profiling.
    """

    def wrap(fn: Callable):
        @wraps(fn)
        def profiled(*args, **kwargs):
            if enabled:
                with cprofile_context(save_file):
                    return fn(*args, **kwargs)
            # Profiling switched off: plain passthrough.
            return fn(*args, **kwargs)

        return profiled

    return wrap
3426
3427


3428
3429
# Only relevant for models using ALiBi (e.g., MPT).
def check_use_alibi(model_config: ModelConfig) -> bool:
    """Return True if the model's HF config indicates ALiBi positional bias.

    Recognized markers:
    - Falcon: a truthy `alibi` attribute on the text config.
    - Bloom: `BloomForCausalLM` listed in `hf_config.architectures`.
    - codellm_1b_alibi: `position_encoding_type == "alibi"`.
    - MPT: `attn_config["alibi"]` (dict) or `attn_config.alibi` (object).

    Args:
        model_config: Model configuration exposing `hf_text_config` and
            `hf_config`.

    Returns:
        A plain bool.
    """
    cfg = model_config.hf_text_config

    # Falcon
    if getattr(cfg, "alibi", False):
        return True

    # Bloom. HF configs may carry `architectures=None`, which would make a
    # direct membership test raise TypeError, so normalize it first.
    architectures = getattr(model_config.hf_config, "architectures", None) or []
    if "BloomForCausalLM" in architectures:
        return True

    # codellm_1b_alibi
    if getattr(cfg, "position_encoding_type", "") == "alibi":
        return True

    # MPT: attn_config is either a plain dict or a config object; a missing
    # or None attn_config yields False.
    attn_config = getattr(cfg, "attn_config", None)
    if isinstance(attn_config, dict):
        return bool(attn_config.get("alibi", False))
    return bool(getattr(attn_config, "alibi", False))
3451
3452


3453
def sha256(input: Any) -> bytes:
    """Return the SHA-256 digest of any picklable Python object.

    The object is first serialized with `pickle` (highest protocol), so
    arbitrary Python objects are accepted. No hash seed is mixed in; if you
    need one, prepend it to the input yourself.

    Args:
        input: Any picklable Python object.

    Returns:
        The 32-byte SHA-256 digest of the pickled bytes.
    """
    hasher = hashlib.sha256()
    hasher.update(pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL))
    return hasher.digest()
3468
3469


3470
def sha256_cbor(input: Any) -> bytes:
    """
    Return the SHA-256 digest of `input` serialized as canonical CBOR.

    Unlike pickle-based hashing, CBOR encoding does not depend on Python, so
    the hash can be reproduced from other languages.

    Args:
        input: Object to serialize and hash. Supported types include basic
            Python types and containers such as lists, tuples and dicts.
            Custom classes must implement CBOR serialization methods.

    Returns:
        The 32-byte SHA-256 digest of the CBOR-encoded input.
    """
    return hashlib.sha256(cbor2.dumps(input, canonical=True)).digest()
3487
3488


3489
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
    """Look up a hash function by its name.

    Args:
        hash_fn_name: Either "sha256" or "sha256_cbor".

    Returns:
        The matching hash function.

    Raises:
        ValueError: If the name is not recognized.
    """
    registry: dict[str, Callable[[Any], bytes]] = {
        "sha256": sha256,
        "sha256_cbor": sha256_cbor,
    }
    hash_fn = registry.get(hash_fn_name)
    if hash_fn is None:
        raise ValueError(f"Unsupported hash function: {hash_fn_name}")
    return hash_fn


3505
3506
3507
3508
3509
3510
3511
3512
3513
3514
def is_torch_equal_or_newer(target: str) -> bool:
    """Return whether the installed torch version is at least `target`.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        True if the installed torch version is >= `target`.
    """
    try:
        installed = str(torch.__version__)
        return _is_torch_equal_or_newer(installed, target)
    except Exception:
        # Fall back to the package metadata (PKG-INFO); the doc build
        # relies on this path.
        installed_dist = Version(importlib.metadata.version("torch"))
        return installed_dist >= Version(target)
3519
3520
3521
3522
3523
3524


# Helper function used in testing.
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
    """Return True when version string `torch_version` >= `target`."""
    installed = version.parse(torch_version)
    required = version.parse(target)
    return installed >= required
3525
3526
3527
3528
3529
3530
3531
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545
3546
3547
3548
3549
3550
3551


@cache
def _has_module(module_name: str) -> bool:
    """Return True if *module_name* can be found in the current environment.

    The result is cached so that subsequent queries for the same module incur
    no additional overhead.
    """
    return importlib.util.find_spec(module_name) is not None


def has_pplx() -> bool:
    """Return True if the optional `pplx_kernels` package can be imported."""
    return _has_module("pplx_kernels")


def has_deep_ep() -> bool:
    """Return True if the optional `deep_ep` package can be imported."""
    return _has_module("deep_ep")


def has_deep_gemm() -> bool:
    """Return True if the optional `deep_gemm` package can be imported."""
    return _has_module("deep_gemm")
3553
3554


3555
3556
3557
3558
3559
3560
def has_triton_kernels() -> bool:
    """Return True if the optional `triton_kernels` package can be imported."""
    return _has_module("triton_kernels")


3561
3562
3563
3564
3565
3566
def has_tilelang() -> bool:
    """Return True if the optional `tilelang` package can be imported."""
    return _has_module("tilelang")


3567
3568
3569
def set_process_title(
    name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
) -> None:
    """
    Rename the current process via `setproctitle`.

    The resulting title is "{prefix}::{name}", or
    "{prefix}::{name}_{suffix}" when a non-empty suffix is given.

    Args:
        name: Base title for the current process.
        suffix: Optional suffix joined to `name` with an underscore.
        prefix: Text placed before the name, separated by `::`. The default
            is read from `envs` once, when this function is defined.
    """
    title = f"{name}_{suffix}" if suffix else name
    setproctitle.setproctitle(f"{prefix}::{title}")
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591
3592
3593
3594
3595


def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Patch `file.write` so that every output line starts with a process tag.

    The tag is "(worker_name pid=PID) " wrapped in the module's CYAN/RESET
    codes. A `start_new_line` flag is stored on `file` to remember, across
    separate `write` calls, whether the next character begins a new line.
    """
    tag = f"{CYAN}({worker_name} pid={pid}){RESET} "
    raw_write = file.write

    def write_with_prefix(s: str):
        if not s:
            return
        if file.start_new_line:  # type: ignore[attr-defined]
            raw_write(tag)
        # Every element except the last was terminated by "\n"; the last is
        # whatever trails the final newline (possibly "").
        *full_lines, remainder = s.split("\n")
        last_index = len(full_lines) - 1
        for index, line in enumerate(full_lines):
            raw_write(line + "\n")
            if index == last_index and not remainder:
                # Text ended exactly on a newline: defer the tag until the
                # next write so no dangling prefix is emitted.
                file.start_new_line = True  # type: ignore[attr-defined]
                return
            raw_write(tag)
        raw_write(remainder)
        file.start_new_line = False  # type: ignore[attr-defined]

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]


def decorate_logs(process_name: Optional[str] = None) -> None:
    """
    Prefix every line written to stdout and stderr with the process name
    and PID.

    Call this before initializing the api_server, engine_core, or worker
    classes so that all later output from the process carries the prefix.
    This makes interleaved logs from multiple processes easy to tell apart.

    Args:
        process_name: Name to show in the prefix. Defaults to the current
            process name reported by the multiprocessing context.
    """
    if process_name is None:
        process_name = get_mp_context().current_process().name
    pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, process_name, pid)
3631
3632
3633
3634
3635
3636


def length_from_prompt_token_ids_or_embeds(
    prompt_token_ids: Optional[list[int]],
    prompt_embeds: Optional[torch.Tensor],
) -> int:
    """Calculate the request length (in number of tokens) given either
    prompt_token_ids or prompt_embeds.

    Raises:
        ValueError: If neither input is provided, or if both are provided
            with different lengths.
    """
    token_len = len(prompt_token_ids) if prompt_token_ids is not None else None
    embeds_len = len(prompt_embeds) if prompt_embeds is not None else None

    if token_len is None and embeds_len is None:
        raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
    if token_len is None:
        return embeds_len
    if embeds_len is not None and embeds_len != token_len:
        raise ValueError(
            "Prompt token ids and prompt embeds had different lengths"
            f" prompt_token_ids={token_len}"
            f" prompt_embeds={embeds_len}"
        )
    return token_len
3655
3656
3657
3658
3659
3660
3661
3662
3663
3664
3665
3666
3667


@contextlib.contextmanager
def set_env_var(key, value):
    """Temporarily set environment variable `key` to `value`.

    On exit, whether normal or via an exception, the previous value is
    restored. If the variable was unset before entering, it is removed again.

    Args:
        key: Name of the environment variable.
        value: String value to assign while the context is active.
    """
    old = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if old is None:
            # The body may already have removed the variable. Do not let
            # cleanup raise KeyError, which would also mask any in-flight
            # exception.
            os.environ.pop(key, None)
        else:
            os.environ[key] = old
3668
3669
3670
3671
3672
3673
3674
3675
3676
3677
3678
3679
3680
3681
3682
3683
3684
3685
3686
3687


def unique_filepath(fn: Callable[[int], Path]) -> Path:
    """
    Return the first path `fn(i)`, for i = 0, 1, 2, ..., that does not exist.

    `fn` should embed the given integer at a fixed position in the path it
    returns.

    Note: the existence check and the caller's later file creation are not
    atomic (TOCTOU). Callers should create the file atomically (e.g. `open`
    with mode 'x') to be safe against concurrent writers.
    """
    index = 0
    candidate = fn(index)
    while candidate.exists():
        index += 1
        candidate = fn(index)
    return candidate