utils.py 69.3 KB
Newer Older
1
import argparse
2
import asyncio
3
import concurrent
4
import contextlib
5
import datetime
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import enum
7
import gc
8
import getpass
9
import importlib.metadata
10
import importlib.util
11
import inspect
12
import ipaddress
13
import multiprocessing
14
import os
15
import re
16
import resource
17
import signal
18
import socket
19
import subprocess
20
import sys
21
22
import tempfile
import threading
23
import time
24
import traceback
Zhuohan Li's avatar
Zhuohan Li committed
25
import uuid
26
import warnings
27
import weakref
28
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
29
from collections import OrderedDict, UserDict, defaultdict
30
from collections.abc import Hashable, Iterable, Mapping
31
from dataclasses import dataclass, field
32
from functools import lru_cache, partial, wraps
33
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
34
35
36
                    Dict, Generator, Generic, Iterator, List, Literal,
                    NamedTuple, Optional, Tuple, Type, TypeVar, Union,
                    overload)
37
from uuid import uuid4
Zhuohan Li's avatar
Zhuohan Li committed
38

39
import numpy as np
40
import numpy.typing as npt
41
import psutil
Zhuohan Li's avatar
Zhuohan Li committed
42
import torch
43
import torch.types
44
import yaml
45
46
import zmq
import zmq.asyncio
47
from packaging.version import Version
48
from torch.library import Library
49
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
50

51
import vllm.envs as envs
52
from vllm.logger import enable_trace_function_call, init_logger
53

54
55
56
if TYPE_CHECKING:
    from vllm.config import VllmConfig

57
58
logger = init_logger(__name__)

59
60
# Exception strings for non-implemented encoder/decoder scenarios

61
# Reminder: Please update docs/source/features/compatibility_matrix.md
62
63
# if the feature combo becomes valid

64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Human-readable error messages, one per feature that is rejected for
# encoder/decoder models.
STR_NOT_IMPL_ENC_DEC_SWA = ("Sliding window attention for encoder/decoder "
                            "models is not currently supported.")

STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = ("Prefix caching for encoder/decoder "
                                     "models is not currently supported.")

STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = ("Chunked prefill for encoder/decoder "
                                        "models is not currently supported.")

STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
    "Models with logits_soft_cap "
    "require FlashInfer backend, which is "
    "currently not supported for encoder/decoder "
    "models.")

# Previously read "is currently not currently supported".
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
                             "supported with encoder/decoder "
                             "models.")

STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
                           "currently supported with "
                           "encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
                           "supported with encoder/decoder "
                           "models.")

STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                 "currently supported with encoder/"
                                 "decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
                                "backends currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                       "currently supported with encoder/"
                                       "decoder models.")

# Efficiently import all enc/dec error strings
# rather than having to import all of the above.
# Each key is the name of the constant it maps to.
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
    "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
    "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}

# Constants related to forcing the attention backend selection

# Name of the setting (an environment variable, going by the constant's
# name) which may be set in order to force the attention backend that
# would otherwise be auto-selected by the Attention wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible string values of STR_BACKEND_ENV_VAR,
# corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
# NOTE(review): presumably a placeholder for an unrecognized value;
# its consumers are not visible here - confirm.
STR_INVALID_VAL: str = "INVALID"

138
139
140
GB_bytes = 10**9
"""The number of bytes in one gigabyte (GB)."""

GiB_bytes = 2**30
"""The number of bytes in one gibibyte (GiB)."""

144
145
146
147
# Maps dtype names to torch dtypes; every fp8 variant maps to torch.uint8.
STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    bfloat16=torch.bfloat16,
    float=torch.float,
    fp8=torch.uint8,
    fp8_e4m3=torch.uint8,
    fp8_e5m2=torch.uint8,
)
Zhuohan Li's avatar
Zhuohan Li committed
152

153
154
155
156
157
158
159
160
161
# torch dtype -> numpy dtype of the same name, for the dtypes listed below.
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    getattr(torch, dtype_name): getattr(np, dtype_name)
    for dtype_name in (
        "float16",
        "float32",
        "float64",
        "uint8",
        "int32",
        "int64",
    )
}

162
163
P = ParamSpec('P')
T = TypeVar("T")
164
U = TypeVar("U")
165

166
167
168
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")

Woosuk Kwon's avatar
Woosuk Kwon committed
169

170
171
172
173
174
175
176
class _Sentinel:
    """Marker type; instances are only ever compared by identity."""


# Unique object used by LRUCache.remove_oldest to signal that every
# cached key is pinned.
ALL_PINNED_SENTINEL = _Sentinel()


Woosuk Kwon's avatar
Woosuk Kwon committed
177
178
179
180
181
class Device(enum.Enum):
    """Device kinds; the values are the ones `enum.auto()` would assign."""
    GPU = 1
    CPU = 2


182
183
184
185
186
class LayerBlockType(enum.Enum):
    """Kinds of layer block; each member's value is its lowercase name."""
    attention = "attention"
    mamba = "mamba"


Woosuk Kwon's avatar
Woosuk Kwon committed
187
188
189
190
191
class Counter:
    """Monotonic integer counter advanced with `next()`."""

    def __init__(self, start: int = 0) -> None:
        # Value that the next call to `next()` will return.
        self.counter = start

    def __next__(self) -> int:
        """Return the current value, then advance by one."""
        current = self.counter
        self.counter = current + 1
        return current

    def reset(self) -> None:
        """Rewind the counter to 0 (not to the `start` it was built with)."""
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
199

200

201
202
203
204
205
206
207
208
209
210
211
212
class CacheInfo(NamedTuple):
    """Lookup statistics of a cache: `hits` out of `total` lookups."""
    hits: int
    total: int

    @property
    def hit_ratio(self) -> float:
        """Return hits / total, or 0 when no lookup was recorded."""
        return self.hits / self.total if self.total != 0 else 0


213
class LRUCache(Generic[_K, _V]):
    """Size-bounded mapping that evicts the least recently used entry.

    Entries may be pinned, which exempts them from LRU eviction.

    Note: This class is not thread safe!
    """

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        # Ordered from least recently used (front) to most recent (back).
        self.cache = OrderedDict[_K, _V]()
        self.pinned_items = set[_K]()

        # Lookup statistics; only `get` updates them.
        self._total = 0
        self._hits = 0

    def __contains__(self, key: _K) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: _K) -> _V:
        # A missing key raises KeyError before any reordering happens.
        found = self.cache[key]
        self.cache.move_to_end(key)
        return found

    def __setitem__(self, key: _K, value: _V) -> None:
        self.put(key, value)

    def __delitem__(self, key: _K) -> None:
        self.pop(key)

    def stat(self) -> CacheInfo:
        """Return the hit/total counters accumulated by `get`."""
        return CacheInfo(hits=self._hits, total=self._total)

    def touch(self, key: _K) -> None:
        """Mark `key` as most recently used; KeyError if it is absent."""
        self.cache.move_to_end(key)

    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
        """Return the value for `key` (refreshing it) or `default`."""
        is_hit = key in self.cache
        if is_hit:
            self.cache.move_to_end(key)
            self._hits += 1
        self._total += 1

        result: Optional[_V] = self.cache[key] if is_hit else default
        return result

    def put(self, key: _K, value: _V) -> None:
        """Store `value` as the most recent entry, evicting if over capacity."""
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: _K) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key in self.cache:
            self.pinned_items.add(key)
            return
        raise ValueError(f"Cannot pin key: {key} not in cache.")

    def _unpin(self, key: _K) -> None:
        self.pinned_items.remove(key)

    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
        # Hook called by `pop` after an existing entry was removed.
        pass

    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
        """Evict the least recently used entry.

        Pinned entries are skipped unless `remove_pinned` is set; if
        every entry is pinned a RuntimeError is raised instead.
        """
        if len(self.cache) == 0:
            return

        if remove_pinned:
            victim = next(iter(self.cache))
        else:
            # Oldest-first scan for the first key that is not pinned.
            for victim in self.cache:
                if victim not in self.pinned_items:
                    break
            else:
                raise RuntimeError("All items are pinned, "
                                   "cannot remove oldest from the cache.")
        self.pop(victim)

    def _remove_old_if_needed(self) -> None:
        while self.capacity < len(self.cache):
            self.remove_oldest()

    def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
        """Remove `key`, returning its value or `default` if absent."""
        existed = key in self.cache
        removed = self.cache.pop(key, default)

        # A removed entry must not stay pinned.
        if key in self.pinned_items:
            self._unpin(key)

        # The hook only fires when something was actually removed.
        if existed:
            self._on_remove(key, removed)
        return removed

    def clear(self) -> None:
        """Remove every entry, pinned or not, via `remove_oldest`."""
        while self.cache:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()


316
class PyObjectCache:
    """Pool of pre-built python objects, handed out one by one, so that
    scheduler iterations can reuse them instead of allocating new ones.
    """

    def __init__(self, obj_builder):
        self._obj_builder = obj_builder
        # Position of the next object to hand out.
        self._index = 0

        # Start with 128 pre-built objects.
        self._obj_cache = [self._obj_builder() for _ in range(128)]

    def _grow_cache(self):
        # Double the pool by building as many new objects as it holds now.
        current_size = len(self._obj_cache)
        self._obj_cache.extend(self._obj_builder()
                               for _ in range(current_size))

    def get_object(self):
        """Hand out the next pre-allocated object, doubling the pool
        first when all existing objects have already been handed out.
        """
        if len(self._obj_cache) <= self._index:
            self._grow_cache()
            assert self._index < len(self._obj_cache)

        position = self._index
        self._index = position + 1
        return self._obj_cache[position]

    def reset(self):
        """Makes all cached-objects available for the next scheduler iteration.
        """
        self._index = 0


354
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    The result is memoized per `gpu` index.
    """
    from vllm import _custom_ops as ops

    shared_mem_limit = ops.get_max_shared_memory_per_block_device_attribute(
        gpu)
    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
    # will fail
    assert shared_mem_limit > 0, "max_shared_mem can not be zero"
    return int(shared_mem_limit)


366
def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    memory_stats = psutil.virtual_memory()
    return memory_stats.total
Zhuohan Li's avatar
Zhuohan Li committed
369
370
371
372


def random_uuid() -> str:
    """Return a random UUID4 as a 32-character hexadecimal string."""
    # `.hex` is already a str, so no further conversion is needed.
    return uuid.uuid4().hex
373

374

375
376
377
378
def make_async(
    func: Callable[P, T],
    executor: Optional[concurrent.futures.Executor] = None
) -> Callable[P, Awaitable[T]]:
    """Wrap a blocking function so that calling it runs it in an executor.

    The returned wrapper hands `func` to `loop.run_in_executor`, so the
    blocking call does not block the asyncio event loop.
    The code in `func` needs to be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        # Bind the arguments now; run_in_executor takes a bare callable.
        bound_call = partial(func, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(executor=executor,
                                                        func=bound_call)

    return _async_wrapper


394
395
396
397
398
399
def _next_task(iterator: AsyncGenerator[T, None],
               loop: AbstractEventLoop) -> Task:
    """Schedule the retrieval of `iterator`'s next item as a task on `loop`."""
    # Can use anext() in python >= 3.10
    next_item = iterator.__anext__()
    return loop.create_task(next_item)  # type: ignore[arg-type]


async def merge_async_iterators(
    *iterators: AsyncGenerator[T,
                               None], ) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    Iterators may finish at different times; a finished one simply stops
    contributing. Every yielded value is a tuple (i, item) where i is the
    index of the iterator that produced item.
    """
    loop = asyncio.get_running_loop()

    # In-flight "next item" task -> (index, iterator) it belongs to.
    pending = {}
    for source in enumerate(iterators):
        pending[_next_task(source[1], loop)] = source

    try:
        while pending:
            finished, _ = await asyncio.wait(pending.keys(),
                                             return_when=FIRST_COMPLETED)
            for task in finished:
                source = pending.pop(task)
                try:
                    item = await task
                    index, iterator = source
                    # Ask the same iterator for its next item before
                    # handing this one to the consumer.
                    pending[_next_task(iterator, loop)] = source
                    yield index, item
                except StopAsyncIteration:
                    # This iterator is exhausted; it is not re-scheduled.
                    pass
    finally:
        # Cancel any remaining iterators
        for task, (_, iterator) in pending.items():
            with contextlib.suppress(BaseException):
                task.cancel()
                await iterator.aclose()
432
433


434
435
436
437
438
439
440
441
442
async def collect_from_async_generator(
        iterator: AsyncGenerator[T, None]) -> List[T]:
    """Drain an async generator and return its items, in order, as a list."""
    return [item async for item in iterator]


443
def get_ip() -> str:
    """Return the IP address this process should use.

    `envs.VLLM_HOST_IP` wins when it is set. Otherwise the address is
    taken from the local end of a UDP socket connected towards a public
    DNS server (IPv4 first, then IPv6); connecting a UDP socket does not
    require the peer to be reachable. If both probes fail, a warning is
    emitted and "0.0.0.0" is returned.
    """
    host_ip = envs.VLLM_HOST_IP
    if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
        logger.warning(
            "The environment variable HOST_IP is deprecated and ignored, as"
            " it is often used by Docker and other software to"
            " interact with the container's network stack. Please"
            " use VLLM_HOST_IP instead to set the IP address for vLLM processes"
            " to communicate with each other.")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface.
    # The targets are Google's public DNS servers, see
    # https://developers.google.com/speed/public-dns/docs/using#addresses
    probes = (
        (socket.AF_INET, "8.8.8.8"),
        (socket.AF_INET6, "2001:4860:4860::8888"),
    )
    for family, target in probes:
        try:
            # The context manager closes the probe socket on every path.
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect((target, 80))  # Doesn't need to be reachable
                return s.getsockname()[0]
        except Exception:
            # Best effort: fall through to the next address family.
            continue

    # HOST_IP is not mentioned here because it is ignored (see above).
    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default."
        " The value can be set by the environment variable"
        " VLLM_HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"
481
482


483
484
485
486
487
488
489
490
def is_valid_ipv6_address(address: str) -> bool:
    """Return True if `address` parses as an IPv6 address, else False."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    else:
        return True


491
def get_distributed_init_method(ip: str, port: int) -> str:
    """Build the `tcp://` URL for `ip` and `port`.

    An address containing ':' is wrapped in square brackets; any other
    address is used bare.
    """
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    if ":" in ip:
        return f"tcp://[{ip}]:{port}"
    return f"tcp://{ip}:{port}"
495
496


497
498
499
500
501
502
503
def get_open_zmq_ipc_path() -> str:
    """Return a fresh `ipc://` path: a random UUID under VLLM_RPC_BASE_PATH."""
    return f"ipc://{envs.VLLM_RPC_BASE_PATH}/{uuid4()}"


def get_open_port() -> int:
    """Return a TCP port number that could be bound at the time of the call.

    When `envs.VLLM_PORT` is set, that port is tried first and then each
    following port number until one can be bound. Otherwise a port is
    obtained by binding to port 0, on IPv4 first and on IPv6 if that fails.
    """
    candidate = envs.VLLM_PORT
    if candidate is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET,
                                   socket.SOCK_STREAM) as probe:
                    probe.bind(("", candidate))
                    return candidate
            except OSError:
                # Already in use: move on to the next port number.
                logger.info("Port %d is already in use, trying port %d",
                            candidate, candidate + 1)
                candidate += 1

    # No explicit port requested: bind to port 0 and report what we got.
    try:
        # try ipv4
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            _, assigned_port = probe.getsockname()[:2]
            return assigned_port
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]
524
525


526
def find_process_using_port(port: int) -> Optional[psutil.Process]:
    """Return the process that owns a connection on local `port`, if known.

    Returns None on macOS, when no connection uses the port, when the
    owner's pid cannot be determined, or when the owning process no
    longer exists.
    """
    # TODO: We can not check for running processes with network
    # port on macOS. Therefore, we can not have a full graceful shutdown
    # of vLLM. For now, let's not look for processes in this case.
    # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
    if sys.platform.startswith("darwin"):
        return None

    for conn in psutil.net_connections():
        # Skip connections without a local address (defensive) and those
        # bound to a different port.
        if not conn.laddr or conn.laddr.port != port:
            continue
        # psutil reports pid=None when the owner is not retrievable (e.g.
        # insufficient privileges). psutil.Process(None) would silently
        # return the *current* process, so such entries must be skipped.
        if conn.pid is None:
            continue
        try:
            return psutil.Process(conn.pid)
        except psutil.NoSuchProcess:
            return None
    return None


543
544
def update_environment_variables(envs: Dict[str, str]):
    """Copy every entry of `envs` into `os.environ`.

    A warning is logged whenever an existing variable is replaced by a
    different value.
    """
    for name, new_value in envs.items():
        old_value = os.environ.get(name)
        if old_value is not None and old_value != new_value:
            logger.warning(
                "Overwriting environment variable %s "
                "from '%s' to '%s'", name, os.environ[name], new_value)
        os.environ[name] = new_value
550
551


552
def chunk_list(lst: List[T], chunk_size: int):
    """Lazily yield consecutive slices of `lst`, each at most
    `chunk_size` items long (only the last one may be shorter)."""
    total = len(lst)
    for start in range(0, total, chunk_size):
        stop = start + chunk_size
        yield lst[start:stop]
556
557
558
559
560
561
562


def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer not less than a / b."""
    quotient, remainder = divmod(a, b)
    # Floor division rounds down, so bump the quotient when it is inexact.
    return quotient + 1 if remainder else quotient


563
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill `tensor` in place with fp8 data converted from uniform samples.

    A float16 tensor of the same shape is filled with `uniform_(low, high)`
    and then converted into `tensor` with `ops.convert_fp8`.
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops

    half_samples = torch.empty_like(tensor, dtype=torch.float16)
    half_samples.uniform_(low, high)
    ops.convert_fp8(tensor, half_samples)
    # Drop the temporary explicitly once it has been converted.
    del half_samples


583
584
585
def get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    """Resolve a KV-cache dtype specification to a `torch.dtype`.

    A `torch.dtype` is returned unchanged. "auto" follows `model_dtype`
    (itself a dtype or a name in STR_DTYPE_TO_TORCH_DTYPE). "half",
    "bfloat16" and "float" are looked up by name, and "fp8" maps to
    `torch.uint8`. Anything else raises ValueError.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype == "auto":
        # Follow the model's dtype.
        if isinstance(model_dtype, str):
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        raise ValueError(f"Invalid model dtype: {model_dtype}")

    if cache_dtype in ("half", "bfloat16", "float"):
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    if cache_dtype == "fp8":
        return torch.uint8
    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create randomly filled key/value caches, one pair per layer.

    For each layer a single tensor of shape
    (num_blocks, 2, block_size, num_heads, head_size) is allocated and
    filled with values in [-scale, scale], scale = head_size ** -0.5.
    The returned key and value caches are the `[:, 0]` and `[:, 1]`
    views of that tensor. Raises ValueError for an unsupported
    `cache_dtype`.
    """
    from vllm.platforms import current_platform
    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    scale = head_size**-0.5
    kv_shape = (num_blocks, 2, block_size, num_heads, head_size)

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        kv_cache = torch.empty(size=kv_shape,
                               dtype=torch_dtype,
                               device=device)
        if cache_dtype == 'fp8':
            _generate_random_fp8(kv_cache, -scale, scale)
        elif cache_dtype in ("auto", "half", "bfloat16", "float"):
            kv_cache.uniform_(-scale, scale)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        # Keys and values are views into the same per-layer tensor.
        key_caches.append(kv_cache[:, 0])
        value_caches.append(kv_cache[:, 1])
    return key_caches, value_caches


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create randomly filled key caches and value caches, one per layer.

    Key caches have shape
    (num_blocks, num_heads, head_size // x, block_size, x), where x is
    16 divided by the element size of the resolved dtype. Value caches
    have shape (num_blocks, num_heads, head_size, block_size). All key
    caches are generated before any value cache. Raises ValueError for
    an unsupported `cache_dtype`, or for "fp8" with a `head_size` that
    is not a multiple of 16.
    """
    if cache_dtype == "fp8" and head_size % 16 != 0:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from vllm.platforms import current_platform
    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    float_kinds = ("auto", "half", "bfloat16", "float")
    scale = head_size**-0.5

    # Number of elements that fit into 16 bytes.
    elems_per_16_bytes = 16 // torch.tensor([],
                                            dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // elems_per_16_bytes,
                       block_size, elems_per_16_bytes)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        layer_keys = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
                                 device=device)
        if cache_dtype in float_kinds:
            layer_keys.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(layer_keys, -scale, scale)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        key_caches.append(layer_keys)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        layer_values = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
                                   device=device)
        if cache_dtype in float_kinds:
            layer_values.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(layer_values, -scale, scale)
        else:
            raise ValueError(
                f"Does not support value cache of type {cache_dtype}")
        value_caches.append(layer_values)
    return key_caches, value_caches
697
698


699
700
@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    """Return the current platform's answer to `is_pin_memory_available()`.

    The answer is computed once and memoized.
    """
    from vllm.platforms import current_platform as platform
    return platform.is_pin_memory_available()
703
704


705
class DeviceMemoryProfiler:
    """Context manager measuring device memory consumed inside its block.

    After the `with` block, `consumed_memory` holds the difference
    between the memory usage sampled on exit (`final_memory`) and on
    entry (`initial_memory`), in bytes.
    """

    def __init__(self, device: Optional[torch.types.Device] = None):
        self.device = device

    def current_memory_usage(self) -> float:
        """Return the memory usage of `self.device` in bytes.

        The peak-memory statistics are reset first and the peak is read
        immediately afterwards.

        Raises:
            NotImplementedError: if the current platform is neither
                CUDA-like nor XPU. Previously this case failed with an
                UnboundLocalError.
        """
        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            torch.cuda.reset_peak_memory_stats(self.device)
            return torch.cuda.max_memory_allocated(self.device)
        if current_platform.is_xpu():
            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
            return torch.xpu.max_memory_allocated(self.device)  # type: ignore
        raise NotImplementedError(
            "DeviceMemoryProfiler only supports CUDA-like and XPU platforms.")

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()
732
733


734
735
736
737
738
739
740
741
742
def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Build a (len(x), max_len) array from a list of rows.

    Each row is copied to the start of its array row and the remainder
    is filled with `pad`. `max_len` defaults to the longest row length
    (0 for empty input); no row may be longer than `max_len`.
    """
    if max_len is None:
        # Unlike for most functions, map is faster than a genexpr over `len`
        row_lengths = map(len, x)
        max_len = max(row_lengths, default=0)

    # Start fully padded, then overwrite the leading part of each row.
    result = np.full((len(x), max_len), pad, dtype=dtype)
    for row_idx, row in enumerate(x):
        row_len = len(row)
        assert row_len <= max_len
        result[row_idx, :row_len] = row

    return result


def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Build a padded 2D tensor from a list of rows.

    The rows are padded at the end with `pad` up to `max_len` by
    `make_ndarray_with_pad`, using the numpy dtype that
    TORCH_DTYPE_TO_NUMPY_DTYPE gives for `dtype`. The result is moved to
    `device` and, if requested, pinned afterwards.
    """
    numpy_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded_rows = make_ndarray_with_pad(x, pad, numpy_dtype, max_len=max_len)

    result = torch.from_numpy(padded_rows).to(device)
    return result.pin_memory() if pin_memory else result
782
783
784
785
786
787
788
789
790
791
792
793
794


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from `data` and copy it to `target_device`
    with a non-blocking transfer."""
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               pin_memory=pin_memory,
                               device="cpu")
    device_tensor = host_tensor.to(device=target_device, non_blocking=True)
    return device_tensor


795
796
797
798
799
def get_dtype_size(dtype: torch.dtype) -> int:
    """Return how many bytes a single element of `dtype` occupies."""
    # An empty tensor is enough to ask torch for the element size.
    probe = torch.empty(0, dtype=dtype)
    return probe.element_size()


800
801
802
# `collections` helpers
def is_list_of(
    value: object,
    typ: Union[type[T], tuple[type[T], ...]],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
    """Return whether `value` is a list whose items are instances of `typ`.

    With check="first" only the first item is inspected (an empty list
    passes); with check="all" every item is inspected.
    """
    if not isinstance(value, list):
        return False

    if check == "first":
        if len(value) == 0:
            return True
        return isinstance(value[0], typ)
    if check == "all":
        for item in value:
            if not isinstance(item, typ):
                return False
        return True

    assert_never(check)


818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
# Recursive alias: a tree is a dict (string keys), a list or a tuple of
# subtrees, or a leaf of type T. The quoted "JSONTree[T]" entries are
# forward references to this alias itself.
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                 Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: JSONTree[T],
) -> JSONTree[U]:
    ...


def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
    """Apply `func` to every leaf of a nested dict/list/tuple structure.

    Dicts, lists and tuples are rebuilt as the same kind of container
    with their children mapped recursively. Any other value is treated
    as a leaf and replaced by `func(value)`.
    """
    if isinstance(value, dict):
        return {
            key: json_map_leaves(func, child)
            for key, child in value.items()
        }
    if isinstance(value, list):
        return [json_map_leaves(func, child) for child in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, child) for child in value)
    return func(value)


866
867
868
869
870
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Concatenate the inner lists, in order, into one new list."""
    flattened: List[T] = []
    for inner in lists:
        flattened.extend(inner)
    return flattened


871
872
873
874
875
876
877
878
879
880
881
882
883
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
    """
    Group `values` by `key(value)` and return the (key, values) items.

    Unlike :class:`itertools.groupby`, groups are not broken by
    non-contiguous data: all values sharing a key end up in one list.
    Groups appear in order of first occurrence of their key.
    """
    grouped: defaultdict[_K, list[_V]] = defaultdict(list)

    for item in values:
        group_key = key(item)
        grouped[group_key].append(item)

    return grouped.items()


884
885
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
886
def init_cached_hf_modules() -> None:
    """
    Initialize the Hugging Face modules on demand.

    `transformers` is imported only when this function is called.
    """
    import transformers.dynamic_module_utils as hf_dynamic_modules
    hf_dynamic_modules.init_hf_modules()
892
893


894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Resolve a library file name to its full path.

    `lib_name` is the full filename, with both prefix and suffix.
    The `/sbin/ldconfig -p` listing is searched first; only if it has no
    match are the directories of `LD_LIBRARY_PATH` checked. The first
    match is returned, and ValueError is raised if there is none.
    Results are memoized per `lib_name`.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    ldconfig_listing = subprocess.check_output(["/sbin/ldconfig",
                                                "-p"]).decode()
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    # so the path is the last whitespace-separated field.
    matches = [
        entry.split()[-1] for entry in ldconfig_listing.splitlines()
        if lib_name in entry
    ]

    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    user_search_path = envs.LD_LIBRARY_PATH
    if not matches and user_search_path:
        matches = []
        for directory in user_search_path.split(":"):
            candidate = os.path.join(directory, lib_name)
            if os.path.exists(candidate):
                matches.append(candidate)

    if not matches:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return matches[0]


922
def find_nccl_library() -> str:
    """
    Return the NCCL/RCCL shared library to load.

    The `VLLM_NCCL_SO_PATH` environment variable wins when it is set.
    Otherwise the library shipped with PyTorch is used: after importing
    `torch`, `libnccl.so.2` or `librccl.so.1` can be found by `ctypes`
    automatically.

    Raises:
        ValueError: if torch was built for neither CUDA nor ROCm.
    """
    so_file = envs.VLLM_NCCL_SO_PATH

    # An explicit path from the environment takes priority.
    if so_file:
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
            so_file)
        return so_file

    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")

    logger.info("Found nccl from library %s", so_file)
    return so_file
945
946


youkaichao's avatar
youkaichao committed
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
# Keep a handle on the original `torch.cuda.set_stream` so that the patched
# version below can still delegate to it.
prev_set_stream = torch.cuda.set_stream

# Most recent stream recorded for this process; `None` until something
# assigns it.
_current_stream = None


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    """Record `stream` in `_current_stream`, then call the original
    `torch.cuda.set_stream` with it."""
    global _current_stream
    _current_stream = stream
    prev_set_stream(stream)


# Module-level side effect: from import time on, every call to
# `torch.cuda.set_stream` goes through `_patched_set_stream`.
torch.cuda.set_stream = _patched_set_stream


def current_stream() -> torch.cuda.Stream:
    """
    Drop-in replacement for `torch.cuda.current_stream()`.

    `torch.cuda.current_stream()` is quite expensive because it constructs
    a new stream object on every call. `torch.cuda.set_stream` is patched
    in this module to remember the stream that was set, so the remembered
    object can be returned directly here.

    This relies on the assumption that `torch._C._cuda_setStream` is not
    called from C/C++ code.
    """
    global _current_stream

    if _current_stream is not None:
        return _current_stream

    # No stream has been recorded yet: ask torch once for the current
    # (default) stream and remember it.
    _current_stream = torch.cuda.current_stream()
    return _current_stream


980
def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
    """Turn on function tracing for the current thread.

    Does nothing unless the VLLM_TRACE_FUNCTION environment variable is
    enabled. The log is written to
    ``<tmp>/<user>/vllm/vllm-instance-<instance_id>/<file>``, where the
    file name contains the process id, the thread id and a timestamp.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    # add username to tmp_dir to avoid permission issues
    user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())

    raw_filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                    f"_thread_{threading.get_ident()}_"
                    f"at_{datetime.datetime.now()}.log")
    # The timestamp contains a space; keep the file name free of spaces.
    filename = raw_filename.replace(" ", "_")

    instance_dir = f"vllm-instance-{vllm_config.instance_id}"
    log_path = os.path.join(user_tmp_dir, "vllm", instance_dir, filename)

    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    enable_trace_function_call(log_path)
997
998


999
# `functools` helpers
1000
1001
def identity(value: T, **kwargs) -> T:
    """Return ``value`` unchanged.

    Any keyword arguments are accepted and ignored.
    """
    return value


# Type variable for "any callable", used so that decorators are typed as
# returning the same callable type they receive.
F = TypeVar('F', bound=Callable[..., Any])


1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
def deprecate_args(
    start_index: int,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Build a decorator that warns when deprecated positional args are used.

    Args:
        start_index: Index of the first positional parameter that is
            deprecated; every positional argument passed at or after this
            index triggers the warning.
        is_deprecated: Either a flag, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Extra text appended to the warning message.

    Returns:
        A decorator; the decorated function emits a `DeprecationWarning`
        and then calls the original function with unchanged arguments.
    """
    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize a plain flag into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    def wrapper(fn: F) -> F:

        # Names of the parameters that can be filled positionally, in order.
        positional_names = [
            name for name, param in inspect.signature(fn).parameters.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            if should_warn():
                deprecated_args = positional_names[start_index:len(args)]
                if deprecated_args:
                    msg = (
                        f"The positional arguments {deprecated_args} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper


1051
def deprecate_kwargs(
    *kws: str,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Build a decorator that warns when deprecated keyword args are used.

    Args:
        *kws: Names of the deprecated keyword arguments.
        is_deprecated: Either a flag, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Extra text appended to the warning message.

    Returns:
        A decorator; the decorated function emits a `DeprecationWarning`
        when any of the listed keywords is passed, and then calls the
        original function with unchanged arguments.
    """
    deprecated_names = set(kws)

    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize a plain flag into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if should_warn():
                deprecated_kwargs = kwargs.keys() & deprecated_names
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
        cuda_visible_devices: Optional[str] = None) -> int:
    """Return the device count, cached per `cuda_visible_devices` value.

    `cuda_visible_devices` is not read in the body; it is only a parameter
    so that it becomes part of the LRU cache key.
    """
    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    from vllm.platforms import current_platform

    if not torch.cuda._is_compiled():
        return 0

    if current_platform.is_rocm():
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        if hasattr(torch.cuda, "_device_count_amdsmi"):
            raw_count = torch.cuda._device_count_amdsmi()
        else:
            raw_count = -1
    else:
        raw_count = torch.cuda._device_count_nvml()

    # A negative count means the query above gave no answer; fall back to
    # the runtime's own device count.
    if raw_count < 0:
        return torch._C._cuda_getDeviceCount()
    return raw_count


def cuda_device_count_stateless() -> int:
    """Return the number of CUDA devices, cached on the value that
    CUDA_VISIBLE_DEVICES has at the time of the call.

    Use this instead of torch.cuda.device_count() unless
    CUDA_VISIBLE_DEVICES has already been set to the desired value.
    """
    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    visible_devices = envs.CUDA_VISIBLE_DEVICES
    return _cuda_device_count_stateless(visible_devices)
1124
1125


1126
1127
1128
1129
1130
1131
1132
def cuda_is_initialized() -> bool:
    """Return whether CUDA has been initialized.

    Always ``False`` when torch reports it was not compiled with CUDA.
    """
    if torch.cuda._is_compiled():
        return torch.cuda.is_initialized()
    return False


1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
    """Make an instance method that weakly references
    its associated instance and no-ops once that
    instance is collected.

    Args:
        bound_method: A bound method; its ``__self__`` is held only through
            a weak reference and its ``__func__`` is called directly.

    Returns:
        A callable that forwards its arguments to the method while the
        instance is alive. It always returns ``None``.
    """
    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
    unbound = bound_method.__func__  # type: ignore[attr-defined]

    def weak_bound(*args, **kwargs) -> None:
        inst = ref()
        # Compare against None explicitly: a live instance may itself be
        # falsy (e.g. it defines __len__ or __bool__), and must still be
        # called.
        if inst is not None:
            unbound(inst, *args, **kwargs)

    return weak_bound


1147
# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Wrap ``f`` so that only the first call actually runs it.

    The state is kept in the ``has_run`` attribute of the returned wrapper;
    it is set before ``f`` is invoked. Later calls return ``None`` without
    calling ``f``.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if wrapper.has_run:  # type: ignore[attr-defined]
            return None

        wrapper.has_run = True  # type: ignore[attr-defined]
        return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
1157
1158


1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
class StoreBoolean(argparse.Action):
    """argparse action that stores ``True``/``False`` parsed from the
    case-insensitive strings "true"/"false"."""

    def __call__(self, parser, namespace, values, option_string=None):
        normalized = values.lower()

        if normalized not in ("true", "false"):
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")

        setattr(namespace, self.dest, normalized == "true")


1171
1172
1173
1174
1175
class SortedHelpFormatter(argparse.HelpFormatter):
    """Help formatter that lists arguments ordered by their option strings."""

    def add_arguments(self, actions):
        # Sort a copy so the caller's sequence is left untouched.
        ordered = list(actions)
        ordered.sort(key=lambda action: action.option_strings)
        super().add_arguments(ordered)
1177
1178


1179
1180
1181
class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names.

    On top of the standard parser it:

    * defaults ``formatter_class`` to ``SortedHelpFormatter``;
    * rewrites ``--some_flag`` to ``--some-flag`` before parsing;
    * accepts ``-O<value>`` / ``-O=<value>`` without a space;
    * expands ``--config <file.yaml>`` into ordinary command-line options.
    """

    def __init__(self, *args, **kwargs):
        # Set the default 'formatter_class' to SortedHelpFormatter
        if 'formatter_class' not in kwargs:
            kwargs['formatter_class'] = SortedHelpFormatter
        super().__init__(*args, **kwargs)

    def parse_args(self, args=None, namespace=None):
        """Normalize ``args`` (default: ``sys.argv[1:]``) and parse them."""
        if args is None:
            args = sys.argv[1:]

        if '--config' in args:
            args = self._pull_args_from_config(args)

        # Convert underscores to dashes in argument names. Only the option
        # name is rewritten; a value given as `--key=value` is left alone.
        processed_args = []
        for arg in args:
            if arg.startswith('--'):
                if '=' in arg:
                    key, value = arg.split('=', 1)
                    key = '--' + key[len('--'):].replace('_', '-')
                    processed_args.append(f'{key}={value}')
                else:
                    processed_args.append('--' +
                                          arg[len('--'):].replace('_', '-'))
            elif arg.startswith('-O') and arg != '-O':
                # allow -O flag to be used without space, e.g. -O3
                # (the previous `len(arg) == 2` check could never be true
                # together with `arg != '-O'`, so this branch was dead)
                value = arg[2:]
                if value.startswith('='):
                    # keep `-O=3` meaning the same as `-O 3`
                    value = value[1:]
                processed_args.append('-O')
                processed_args.append(value)
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)

    def _pull_args_from_config(self, args: List[str]) -> List[str]:
        """Method to pull arguments specified in the config file
        into the command-line args variable.

        The arguments in config file will be inserted between
        the argument list.

        example:
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        ```python
        $: vllm {serve,chat,complete} "facebook/opt-12B"
            --config config.yaml -tp 2
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--config', 'config.yaml',
            '-tp', '2'
        ]
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
            '-tp', '2'
            ]
        ```

        Please note how the config args are inserted after the sub command.
        this way the order of priorities is maintained when these are args
        parsed by super().

        Raises:
            ValueError: if `--config` is the last argument (no file given),
                or if `serve` is not followed by a model tag.
        """
        assert args.count(
            '--config') <= 1, "More than one config file specified!"

        index = args.index('--config')
        if index == len(args) - 1:
            raise ValueError("No config file specified! "
                             "Please check your command-line arguments.")

        file_path = args[index + 1]

        config_args = self._load_config_file(file_path)

        # 0th index is for {serve,chat,complete}
        # followed by model_tag (only for serve)
        # followed by config args
        # followed by rest of cli args.
        # maintaining this order will enforce the precedence
        # of cli > config > defaults
        if args[0] == "serve":
            if index == 1:
                raise ValueError(
                    "No model_tag specified! Please check your command-line"
                    " arguments.")
            args = [args[0]] + [
                args[1]
            ] + config_args + args[2:index] + args[index + 2:]
        else:
            args = [args[0]] + config_args + args[1:index] + args[index + 2:]

        return args

    def _load_config_file(self, file_path: str) -> List[str]:
        """Loads a yaml file and returns the key value pairs as a
        flattened list with argparse like pattern
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        returns:
            processed_args: list[str] = [
                '--port', '12323',
                '--tensor-parallel-size', '4'
            ]

        Boolean values become a bare `--key` flag when true and are omitted
        when false, unless the option uses the `StoreBoolean` action, in
        which case the value is passed explicitly.

        Raises:
            ValueError: if the file extension is not yaml/yml, or the file
                does not contain a mapping at the top level.
        """

        extension: str = file_path.split('.')[-1]
        if extension not in ('yaml', 'yml'):
            raise ValueError("Config file must be of a yaml/yml type. "
                             f"{extension} supplied")

        # only expecting a flat dictionary of atomic types
        processed_args: List[str] = []

        config: Dict[str, Union[int, str]] = {}
        try:
            with open(file_path) as config_file:
                config = yaml.safe_load(config_file)
        except Exception:
            logger.error(
                "Unable to read the config file at %s. "
                "Make sure path is correct", file_path)
            raise

        # An empty YAML document loads as None; treat it as "no options".
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError(
                "Config file must contain a mapping of option names to "
                f"values, got {type(config).__name__}")

        store_boolean_arguments = [
            action.dest for action in self._actions
            if isinstance(action, StoreBoolean)
        ]

        for key, value in config.items():
            if isinstance(value, bool) and key not in store_boolean_arguments:
                if value:
                    processed_args.append('--' + key)
            else:
                processed_args.append('--' + key)
                processed_args.append(str(value))

        return processed_args

1329
1330
1331
1332
1333
1334

async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Await ``task(*args, **kwargs)`` while holding ``lock``.

    The lock is released again whether the task returns or raises; the
    task's result is returned.
    """
    await lock.acquire()
    try:
        return await task(*args, **kwargs)
    finally:
        lock.release()
1335
1336


1337
1338
1339
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check whether ``kw_name`` can be passed to ``callable`` by keyword.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: Name of the keyword to check.
        requires_kw_only: If set, a named parameter only counts when it is
            keyword-only; names that can also be filled positionally are
            rejected.
        allow_var_kwargs: If set, a name that is not matched by a named
            parameter is still accepted when the last parameter is a
            ``**kwargs`` whose own name differs from ``kw_name``.

    Returns:
        ``True`` if the keyword is considered supported, else ``False``.
        A callable without any parameters supports nothing.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param = params.get(kw_name)
    if param is not None:
        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            return True

        # NOTE(review): positional-only parameters are treated here like
        # positional-or-keyword ones, as in the original logic - confirm
        # this is intended.
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD):
            return not requires_kw_only

        # Otherwise kw_name names *args / **kwargs itself; fall through.

    if not allow_var_kwargs:
        return False

    # Parameters are stored in definition order, so a var-keyword
    # parameter, if present, is the last one.
    *_, last_param = params.values()
    return (last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name)


def resolve_mm_processor_kwargs(
    init_kwargs: Optional[Mapping[str, object]],
    inference_kwargs: Optional[Mapping[str, object]],
    callable: Callable[..., object],
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
    """Filter and merge init-time and inference-time mm_processor_kwargs.

    Both mappings are filtered down to the names that `callable` accepts
    as keywords, then merged into one dict. On a name collision the value
    from `inference_kwargs` wins.

    When no overrides are provided the result is an empty dict, so it can
    still be kwarg-expanded into the callable later on.

    If allow_var_kwargs=True, names that can be absorbed by the callable's
    var-kwargs are kept as well, as long as they do not collide with the
    var-kwargs parameter name or with potential positional arguments.
    """
    filter_overrides = partial(
        get_allowed_kwarg_only_overrides,
        callable,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    )

    # Inference-time kwargs are filtered first, then init-time kwargs.
    runtime_mm_kwargs = filter_overrides(overrides=inference_kwargs)
    init_mm_kwargs = filter_overrides(overrides=init_kwargs)

    # Start from the init-time values and let inference-time values
    # overwrite them.
    merged_kwargs = dict(init_mm_kwargs)
    merged_kwargs.update(runtime_mm_kwargs)
    return merged_kwargs
1421
1422


1423
1424
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Optional[Mapping[str, object]],
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
    """
    Given a callable and a dict mapping param names to values, drop the
    values that cannot be kwarg expanded into the callable. This is used in
    a few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable whose signature decides which overrides are
            kept.
        overrides: Potential overrides to be used when invoking the
            callable. `None` or an empty mapping yields an empty dict.
        requires_kw_only: Only keep overrides that match keyword-only
            parameters of the callable.
        allow_var_kwargs: Allows overrides that are expandable for var
            kwargs.

    Returns:
        Dictionary containing the kwargs that can be passed to the callable.
        Dropped keys are reported with a warning log.
    """
    if not overrides:
        return {}

    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless they can fit into a var_kwargs param
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if supports_kw(callable,
                       kwarg_name,
                       requires_kw_only=requires_kw_only,
                       allow_var_kwargs=allow_var_kwargs)
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        if requires_kw_only:
            logger.warning(
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s", dropped_keys)
        else:
            logger.warning(
                "The following intended overrides are not keyword args "
                "and will be dropped: %s", dropped_keys)

    return filtered_overrides


1478
1479
1480
1481
1482
1483
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
    """Return whether the installed torch base version is at least 2.4.0."""
    installed_version = Version(torch.__version__)
    # Compare on the base version so that suffixes of the installed
    # version string do not take part in the comparison.
    base_version = Version(installed_version.base_version)
    minimum_version = Version("2.4.0")
    return base_version >= minimum_version
1484
1485


1486
1487
1488
1489
1490
1491
# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    """Return whether `torch.library` exposes `custom_op`."""
    torch_library = torch.library
    return hasattr(torch_library, "custom_op")


1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
class AtomicCounter:
    """A counter whose updates are serialized by a lock."""

    def __init__(self, initial=0):
        """Create a counter that starts at ``initial``."""
        self._lock = threading.Lock()
        self._value = initial

    def inc(self, num=1):
        """Add ``num`` to the counter under the lock; return the new value."""
        with self._lock:
            updated = self._value + num
            self._value = updated
        return updated

    def dec(self, num=1):
        """Subtract ``num`` under the lock; return the new value."""
        with self._lock:
            updated = self._value - num
            self._value = updated
        return updated

    @property
    def value(self):
        """Current value, read without taking the lock."""
        return self._value
1515
1516
1517


# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping[str, T], Generic[T]):
    """Read-mostly mapping whose values are built on first access.

    Each key maps to a zero-argument factory. The factory is called the
    first time the key is looked up and its result is cached for later
    lookups. Iteration and length reflect the registered factories, not
    the values built so far.
    """

    def __init__(self, factory: Dict[str, Callable[[], T]]):
        # Factories by key; the dict passed in is used as-is (not copied).
        self._factory = factory
        # Values that have already been built.
        self._dict: Dict[str, T] = {}

    def __getitem__(self, key: str) -> T:
        """Return the value for ``key``, building and caching it if needed.

        Raises:
            KeyError: if no factory is registered for ``key``.
        """
        if key not in self._dict:
            if key not in self._factory:
                raise KeyError(key)
            self._dict[key] = self._factory[key]()
        return self._dict[key]

    def __setitem__(self, key: str, value: Callable[[], T]):
        """Register (or replace) the factory for ``key``."""
        self._factory[key] = value
        # Drop any value built by the previous factory; otherwise lookups
        # would keep returning the stale cached value and the new factory
        # would never run.
        self._dict.pop(key, None)

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
1539
1540


1541
class ClassRegistry(UserDict[Type[T], _V]):
    """Dict keyed by classes where lookups also match base classes.

    A lookup walks the key's MRO and uses the first class that has an
    entry, so a value registered for a base class is found for its
    subclasses too.
    """

    def __getitem__(self, key: Type[T]) -> _V:
        registered = self.data
        match = next((cls for cls in key.mro() if cls in registered), None)
        if match is None:
            raise KeyError(key)
        return registered[match]

    def __contains__(self, key: object) -> bool:
        return self.contains(key)

    def contains(self, key: object, *, strict: bool = False) -> bool:
        """Return whether ``key`` has an entry.

        Non-class keys are never contained. With ``strict=True`` only an
        entry for ``key`` itself counts; otherwise an entry for any class
        in its MRO counts.
        """
        if not isinstance(key, type):
            return False

        candidates = (key, ) if strict else key.mro()
        return any(cls in self.data for cls in candidates)


1563
1564
1565
1566
1567
1568
1569
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return a weak reference to `tensor`, built by the
    `torch.ops._C.weak_ref_tensor` op.

    The returned tensor shares the same data as the original tensor,
    but does not keep the original tensor alive.
    """
    make_weak_ref = torch.ops._C.weak_ref_tensor
    return make_weak_ref(tensor)
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585


def weak_ref_tensors(
    tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
    """
    Create weak references for a single tensor, a list of tensors or a
    tuple of tensors; the result has the same kind of container as the
    input.

    Raises:
        ValueError: if `tensors` is none of the supported types.
    """
    if not isinstance(tensors, (torch.Tensor, list, tuple)):
        raise ValueError("Invalid type for tensors")

    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)

    container = list if isinstance(tensors, list) else tuple
    return container(weak_ref_tensor(item) for item in tensors)
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595


def is_in_doc_build() -> bool:
    """Return whether `torch` is a Sphinx autodoc mock module.

    Returns ``False`` when the Sphinx mock module cannot be imported.
    """
    try:
        from sphinx.ext.autodoc.mock import _MockModule
    except ModuleNotFoundError:
        return False

    return isinstance(torch, _MockModule)


1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
    """
    Load the Python file at `file_path` as module `module_name`.

    The module is registered in `sys.modules` before it is executed, then
    returned.

    Based on the official recipe:
    https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly

    Raises:
        ModuleNotFoundError: if no import spec can be created for the file.
    """
    module_spec = importlib.util.spec_from_file_location(
        module_name, file_path)
    if module_spec is None:
        raise ModuleNotFoundError(f"No module named '{module_name}'")

    loader = module_spec.loader
    assert loader is not None

    loaded_module = importlib.util.module_from_spec(module_spec)
    sys.modules[module_name] = loaded_module
    loader.exec_module(loaded_module)
    return loaded_module


1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
@lru_cache(maxsize=None)
def get_vllm_optional_dependencies():
    """Map each extra declared by the installed `vllm` distribution to the
    requirement names listed for it.

    The names are taken from the `Requires-Dist` metadata entries whose
    marker ends with ``extra == "<name>"``; everything from the first
    ``;``, ``>=``, ``<=`` or ``==`` onwards is cut off. The result is
    cached.
    """
    dist_metadata = importlib.metadata.metadata("vllm")
    all_requirements = dist_metadata.get_all("Requires-Dist", [])
    extra_names = dist_metadata.get_all("Provides-Extra", [])

    deps_by_extra = {}
    for extra in extra_names:
        marker = f'extra == "{extra}"'
        deps_by_extra[extra] = [
            re.split(r";|>=|<=|==", requirement)[0]
            for requirement in all_requirements
            if requirement.endswith(marker)
        ]
    return deps_by_extra


1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
class _PlaceholderBase:
    """
    Disallows downstream usage of placeholder modules.

    Every special method defined here forwards to :meth:`__getattr__` with
    its own name, so that using the placeholder through an operator, a
    call, indexing, a ``with`` block, etc. ends up in the same handler as a
    plain attribute access.

    We need to explicitly override each dunder method because
    :meth:`__getattr__` is not called when they are accessed.

    See also:
        [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
    """

    def __getattr__(self, key: str) -> Never:
        """
        The main class should implement this to throw an error
        for attribute accesses representing downstream usage.
        """
        raise NotImplementedError

    # [Basic customization]

    def __lt__(self, other: object):
        return self.__getattr__("__lt__")

    def __le__(self, other: object):
        return self.__getattr__("__le__")

    def __eq__(self, other: object):
        return self.__getattr__("__eq__")

    def __ne__(self, other: object):
        return self.__getattr__("__ne__")

    def __gt__(self, other: object):
        return self.__getattr__("__gt__")

    def __ge__(self, other: object):
        return self.__getattr__("__ge__")

    def __hash__(self):
        return self.__getattr__("__hash__")

    def __bool__(self):
        return self.__getattr__("__bool__")

    # [Callable objects]

    def __call__(self, *args: object, **kwargs: object):
        return self.__getattr__("__call__")

    # [Container types]

    def __len__(self):
        return self.__getattr__("__len__")

    def __getitem__(self, key: object):
        return self.__getattr__("__getitem__")

    def __setitem__(self, key: object, value: object):
        return self.__getattr__("__setitem__")

    def __delitem__(self, key: object):
        return self.__getattr__("__delitem__")

    # __missing__ is optional according to __getitem__ specification,
    # so it is skipped

    # __iter__ and __reversed__ have a default implementation
    # based on __len__ and __getitem__, so they are skipped.

    # [Numeric Types]

    def __add__(self, other: object):
        return self.__getattr__("__add__")

    def __sub__(self, other: object):
        return self.__getattr__("__sub__")

    def __mul__(self, other: object):
        return self.__getattr__("__mul__")

    def __matmul__(self, other: object):
        return self.__getattr__("__matmul__")

    def __truediv__(self, other: object):
        return self.__getattr__("__truediv__")

    def __floordiv__(self, other: object):
        return self.__getattr__("__floordiv__")

    def __mod__(self, other: object):
        return self.__getattr__("__mod__")

    def __divmod__(self, other: object):
        return self.__getattr__("__divmod__")

    def __pow__(self, other: object, modulo: object = ...):
        return self.__getattr__("__pow__")

    def __lshift__(self, other: object):
        return self.__getattr__("__lshift__")

    def __rshift__(self, other: object):
        return self.__getattr__("__rshift__")

    def __and__(self, other: object):
        return self.__getattr__("__and__")

    def __xor__(self, other: object):
        return self.__getattr__("__xor__")

    def __or__(self, other: object):
        return self.__getattr__("__or__")

    # r* and i* methods have lower priority than
    # the methods for left operand so they are skipped

    def __neg__(self):
        return self.__getattr__("__neg__")

    def __pos__(self):
        return self.__getattr__("__pos__")

    def __abs__(self):
        return self.__getattr__("__abs__")

    def __invert__(self):
        return self.__getattr__("__invert__")

    # __complex__, __int__ and __float__ have a default implementation
    # based on __index__, so they are skipped.

    def __index__(self):
        return self.__getattr__("__index__")

    def __round__(self, ndigits: object = ...):
        return self.__getattr__("__round__")

    def __trunc__(self):
        return self.__getattr__("__trunc__")

    def __floor__(self):
        return self.__getattr__("__floor__")

    def __ceil__(self):
        return self.__getattr__("__ceil__")

    # [Context managers]

    def __enter__(self):
        return self.__getattr__("__enter__")

    def __exit__(self, *args: object, **kwargs: object):
        return self.__getattr__("__exit__")


class PlaceholderModule(_PlaceholderBase):
    """
    A placeholder object to use when a module does not exist.

    Accessing any attribute re-attempts the import, so the caller gets an
    informative `ImportError` (naming the matching `vllm[...]` extra when
    the module belongs to one) instead of an opaque failure.
    """

    def __init__(self, name: str) -> None:
        super().__init__()

        # The double-underscore prefix triggers name mangling, so this
        # never collides with an attribute of the module being stood in for
        self.__name = name

    def placeholder_attr(self, attr_path: str):
        """Create a placeholder for the attribute `attr_path` of this module."""
        return _PlaceholderModuleAttr(self, attr_path)

    def __getattr__(self, key: str):
        """
        Always raise: `ImportError` if the module still cannot be imported,
        `AssertionError` if it can (the placeholder should not be in use).
        """
        module_name = self.__name

        try:
            importlib.import_module(module_name)
        except ImportError as import_error:
            optional_deps = get_vllm_optional_dependencies()
            for extra, names in optional_deps.items():
                if module_name not in names:
                    continue

                # Point the user at the extra that provides this module
                msg = f"Please install vllm[{extra}] for {extra} support"
                raise ImportError(msg) from import_error

            raise import_error

        raise AssertionError("PlaceholderModule should not be used "
                             "when the original module can be imported")


class _PlaceholderModuleAttr(_PlaceholderBase):
    """
    A placeholder for a (possibly nested) attribute of a
    `PlaceholderModule`, identified by its dotted `attr_path`.
    """

    def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
        super().__init__()

        # The double-underscore prefix triggers name mangling, so these
        # never collide with attributes of the module being stood in for
        self.__module = module
        self.__attr_path = attr_path

    def placeholder_attr(self, attr_path: str):
        """Create a placeholder for `attr_path` nested under this attribute."""
        owner = self.__module
        nested_path = f"{self.__attr_path}.{attr_path}"
        return _PlaceholderModuleAttr(owner, nested_path)

    def __getattr__(self, key: str):
        # Delegate to the owning module placeholder using the full dotted
        # path; `PlaceholderModule.__getattr__` always raises, so this
        # lookup is expected to raise before the assertion below
        owner = self.__module
        getattr(owner, f"{self.__attr_path}.{key}")

        raise AssertionError("PlaceholderModule should not be used "
                             "when the original module can be imported")


# Create a torch `Library` ("FRAGMENT" kind) in the "vllm" namespace to hold
# the custom ops. `direct_register_custom_op` registers into it whenever no
# `target_lib` is passed.
vllm_lib = Library("vllm", "FRAGMENT")  # noqa


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op on a library and binds its
    implementation to a single dispatch key (`"CUDA"` by default).
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    Args:
        op_name: Name under which the op is defined in the library.
        op_func: Implementation of the op; its schema is inferred from it.
        mutates_args: Names of the arguments the op mutates, passed to the
            schema inference.
        fake_impl: Optional fake implementation, registered with the
            library's `_register_fake` when given.
        target_lib: Library to register into; falls back to `vllm_lib`.
        dispatch_key: Dispatch key the implementation is registered for.

    Nothing is registered during a documentation build, or when custom ops
    are not supported (in which case a CUDA-alike platform fails the
    assertion below instead of silently skipping).
    """
    if is_in_doc_build():
        return

    if not supports_custom_op():
        from vllm.platforms import current_platform
        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies.")
        return

    import torch.library
    if not hasattr(torch.library, "infer_schema"):
        # for pytorch 2.4
        import torch._custom_op.impl
        schema = torch._custom_op.impl.infer_schema(op_func, mutates_args)
    else:
        schema = torch.library.infer_schema(op_func,
                                            mutates_args=mutates_args)

    library = target_lib or vllm_lib
    library.define(op_name + schema)
    library.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        library._register_fake(op_name, fake_impl)


def resolve_obj_by_qualname(qualname: str) -> Any:
    """
    Resolve an object by its fully qualified name.

    Args:
        qualname: Dotted path of the form `<module>.<object>`. Everything
            before the last dot is imported as a module and the last
            component is looked up on it.

    Returns:
        The attribute of the imported module named by the last component.

    Raises:
        ValueError: If `qualname` contains no dot, so there is no module
            part to import.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_name, separator, obj_name = qualname.rpartition(".")
    if not separator:
        # Without this check the failure would surface as an opaque
        # tuple-unpacking error that does not mention the offending name
        raise ValueError(
            "Expected a fully qualified name of the form "
            f"'<module>.<object>', but got {qualname!r}")

    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


def kill_process_tree(pid: int):
    """
    Sends SIGKILL to all descendant processes of the given pid, and then
    to the process itself.

    Returns without doing anything if no process with this pid exists.
    Processes that are already gone by the time they are signalled are
    skipped.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Descendants (collected recursively) go first, the parent goes last
    targets = [proc.pid for proc in root.children(recursive=True)]
    targets.append(pid)

    for target in targets:
        with contextlib.suppress(ProcessLookupError):
            os.kill(target, signal.SIGKILL)


@dataclass
class MemorySnapshot:
    """Point-in-time reading of torch's CUDA memory counters."""
    torch_peak_in_bytes: int = 0
    torch_memory_in_bytes: int = 0
    timestamp: float = 0.0

    def measure(self):
        """Overwrite all fields with the current readings."""
        cuda = torch.cuda
        self.torch_peak_in_bytes = cuda.max_memory_reserved()
        # memory_reserved() is how many bytes
        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
        self.torch_memory_in_bytes = cuda.memory_reserved()
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
        """Return the field-wise difference `self - other`."""
        peak_delta = self.torch_peak_in_bytes - other.torch_peak_in_bytes
        memory_delta = (self.torch_memory_in_bytes -
                        other.torch_memory_in_bytes)
        elapsed = self.timestamp - other.timestamp
        return MemorySnapshot(torch_peak_in_bytes=peak_delta,
                              torch_memory_in_bytes=memory_delta,
                              timestamp=elapsed)


@dataclass
class MemoryProfilingResult:
    """Memory profiling result.

    Populated by `memory_profiling`: the two input sizes and `before_profile`
    are recorded before the profiled code runs; `after_profile` and the
    derived fields are computed once it has finished.
    """  # noqa
    # Passed in by the caller: memory used by everything other than the
    # current vLLM instance
    baseline_memory_in_bytes: int = 0
    # non_torch_increase + torch_peak_increase + weights_memory
    non_kv_cache_memory_in_bytes: int = 0
    # Difference in `torch_peak_in_bytes` between the two snapshots
    torch_peak_increase_in_bytes: int = 0
    # Device memory in use after profiling, minus the baseline, the weights
    # and the growth of `torch_memory_in_bytes` between the two snapshots
    non_torch_increase_in_bytes: int = 0
    # Passed in by the caller: memory used for the model weights.
    # NOTE(review): annotated `float` while `memory_profiling` takes an `int`
    # and the other byte counters are `int` - confirm which is intended
    weights_memory_in_bytes: float = 0
    # Snapshots taken right before / right after the profiled code
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    # Difference in `timestamp` between the two snapshots
    profile_time: float = 0.0


@contextlib.contextmanager
def memory_profiling(
    baseline_memory_in_bytes: int, weights_memory_in_bytes: int
) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.

    Yields a `MemoryProfilingResult`; its derived fields are only filled in
    after the `with` block has finished without raising.

    baseline_memory_in_bytes: memory used by all the components other than
        the current vLLM instance. It contains: memory used by other processes, memory
        used by another vLLM instance in the same process, etc. It is usually measured
        before the current vLLM instance initializes the device. And we assume it is
        constant during the profiling of the current vLLM instance.
    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
        Note that, before loading the model weights, we also initialize the device
        and distributed environment, which may consume some memory. This part is not
        included in the weights_memory_in_bytes because PyTorch does not control it.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.

    The increase of `torch.cuda.max_memory_reserved()` between the snapshots taken before and after profiling
    (the peak statistics are reset on entry) gives (b.).

    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
    """ # noqa
    torch.cuda.reset_peak_memory_stats()

    result = MemoryProfilingResult()

    result.baseline_memory_in_bytes = baseline_memory_in_bytes
    # the part of memory used for holding the model weights
    result.weights_memory_in_bytes = weights_memory_in_bytes

    result.before_profile.measure()

    yield result

    # Drop garbage and cached blocks before the second snapshot
    gc.collect()
    torch.cuda.empty_cache()

    result.after_profile.measure()

    diff = result.after_profile - result.before_profile
    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes

    # Query the driver once so that free and total come from one reading
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    current_cuda_memory_bytes = total_bytes - free_bytes
    result.non_torch_increase_in_bytes = (current_cuda_memory_bytes -
                                          baseline_memory_in_bytes -
                                          weights_memory_in_bytes -
                                          diff.torch_memory_in_bytes)
    result.profile_time = diff.timestamp
    result.non_kv_cache_memory_in_bytes = (
        result.non_torch_increase_in_bytes +
        result.torch_peak_increase_in_bytes +
        result.weights_memory_in_bytes)


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
    """
    Best-effort raise of the soft limit on open file descriptors.

    If the current soft `RLIMIT_NOFILE` is below `target_soft_limit`, try to
    raise it to that value while keeping the hard limit unchanged. A
    `ValueError` from `resource.setrlimit` is logged as a warning instead of
    being raised.

    Args:
        target_soft_limit: Minimum soft limit to aim for.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit may be reported as a negative sentinel, which
    # would otherwise compare as "too low" and end up being lowered
    if current_soft == resource.RLIM_INFINITY:
        return

    if current_soft >= target_soft_limit:
        return

    try:
        resource.setrlimit(resource_type, (target_soft_limit, current_hard))
    except ValueError as e:
        logger.warning(
            "Found ulimit of %s and failed to automatically increase "
            "with error %s. This can cause fd limit errors like "
            "`OSError: [Errno 24] Too many open files`. Consider "
            "increasing with ulimit -n", current_soft, e)


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
    """Return the formatted traceback of the exception being handled."""
    exc_info = sys.exc_info()
    return "".join(traceback.format_exception(*exc_info))


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: Union[zmq.asyncio.Context, zmq.Context],  # type: ignore[name-defined]
    path: str,
    type: Any,
) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    PULL sockets connect to `path`, PUSH sockets bind to it. In both cases
    the high-water mark is set to 0 and the kernel buffer size is chosen
    from the amount of system memory.

    Args:
        ctx: ZMQ context (sync or asyncio) to create the socket from.
        path: Address to connect or bind to.
        type: ZMQ socket type; only PULL and PUSH are supported.

    Returns:
        The configured socket.

    Raises:
        ValueError: If `type` is neither PULL nor PUSH.

    The socket is closed before any error raised during its setup is
    propagated, so a failed call does not leak it.
    """

    mem = psutil.virtual_memory()

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * 1024**3)  # 0.5GB in bytes
    else:
        buf_size = -1  # Use system default buffer size

    # Named `sock` so that the stdlib `socket` module imported by this file
    # is not shadowed
    sock = ctx.socket(type)

    try:
        if type == zmq.constants.PULL:
            sock.setsockopt(zmq.constants.RCVHWM, 0)
            sock.setsockopt(zmq.constants.RCVBUF, buf_size)
            sock.connect(path)
        elif type == zmq.constants.PUSH:
            sock.setsockopt(zmq.constants.SNDHWM, 0)
            sock.setsockopt(zmq.constants.SNDBUF, buf_size)
            sock.bind(path)
        else:
            raise ValueError(f"Unknown Socket Type: {type}")
    except Exception:
        # Do not leak the half-configured socket to the caller's context
        sock.close(linger=0)
        raise

    return sock


@contextlib.contextmanager
def zmq_socket_ctx(
        path: str,
        type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Context manager for a ZMQ socket

    Creates a dedicated ZMQ context, yields a socket made by
    `make_zmq_socket`, and destroys the context (with `linger=0`) on exit.
    A `KeyboardInterrupt` raised inside the `with` block is logged at debug
    level and not re-raised.
    """

    context = zmq.Context(io_threads=2)  # type: ignore[attr-defined]
    try:
        sock = make_zmq_socket(context, path, type)
        yield sock

    except KeyboardInterrupt:
        logger.debug("Got Keyboard Interrupt.")

    finally:
        context.destroy(linger=0)


def _check_multiproc_method():
    """
    Switch `VLLM_WORKER_MULTIPROC_METHOD` to `spawn` if CUDA has already
    been initialized and the variable is not set to `spawn` yet, logging a
    warning when doing so.
    """
    if not cuda_is_initialized():
        return

    env_key = "VLLM_WORKER_MULTIPROC_METHOD"
    if os.environ.get(env_key) == "spawn":
        return

    logger.warning("CUDA was previously initialized. We must use "
                   "the `spawn` multiprocessing start method. Setting "
                   "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
                   "See https://docs.vllm.ai/en/latest/getting_started/"
                   "troubleshooting.html#python-multiprocessing "
                   "for more information.")
    os.environ[env_key] = "spawn"


def get_mp_context():
    """
    Return the `multiprocessing` context for the start method given by
    `envs.VLLM_WORKER_MULTIPROC_METHOD`, after `_check_multiproc_method`
    has had the chance to adjust it.
    """
    _check_multiproc_method()
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)