# utils.py 52.3 KB
# Newer Older
1
import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
import getpass
import importlib.util
import inspect
import ipaddress
import os
import socket
import subprocess
import sys
import tempfile
import threading
import time
import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                    Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
                    Type, TypeVar, Union, overload)
from uuid import uuid4

import numpy as np
import numpy.typing as npt
import psutil
import torch
import torch.types
import yaml
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.platforms import current_platform

# Module-level logger used by the helpers in this file.
logger = init_logger(__name__)

46
47
# Exception strings for non-implemented encoder/decoder scenarios

# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# if the feature combo becomes valid

STR_NOT_IMPL_ENC_DEC_SWA = (
    "Sliding window attention for encoder/decoder models "
    "is not currently supported.")

STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = (
    "Prefix caching for encoder/decoder models "
    "is not currently supported.")

STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = (
    "Chunked prefill for encoder/decoder models "
    "is not currently supported.")

STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
    "Models with logits_soft_cap "
    "require FlashInfer backend, which is "
    "currently not supported for encoder/decoder "
    "models.")

STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
                             "supported with encoder/decoder "
                             "models.")

STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
                           "currently supported with "
                           "encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
                           "supported with encoder/decoder "
                           "models.")

STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                 "currently supported with encoder/"
                                 "decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
                                "backends currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                       "currently supported with encoder/"
                                       "decoder models.")

# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
    "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
    "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}

# Constants related to forcing the attention backend selection

# Name of the environment variable that may be set to force the backend
# otherwise auto-selected by the Attention wrapper.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible string values of the STR_BACKEND_ENV_VAR environment variable.
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

GB_bytes = 10**9
"""The number of bytes in one gigabyte (GB)."""

GiB_bytes = 2**30
"""The number of bytes in one gibibyte (GiB)."""


# Maps dtype names (as used for model / KV-cache dtype options in this file)
# to torch dtypes. Every fp8 spelling maps to torch.uint8, the storage dtype
# used for fp8 KV caches in get_kv_cache_torch_dtype.
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

# torch -> NumPy dtype map. make_tensor_with_pad builds its padded array in
# NumPy with the mapped dtype before converting it via torch.from_numpy, so
# only dtypes listed here are supported there. torch.bfloat16 is absent
# because NumPy has no native bfloat16 type.
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.int64,
    torch.bool: np.bool_,
}

P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
152
U = TypeVar("U")
153

Woosuk Kwon's avatar
Woosuk Kwon committed
154

155
156
157
158
159
160
161
class _Sentinel:
    """Private marker type; its instance means "no candidate was found"."""


# Default returned by LRUCache.remove_oldest's search for an unpinned key;
# seeing it means every cached key is pinned.
ALL_PINNED_SENTINEL = _Sentinel()


class Device(enum.Enum):
    """Device kinds: GPU (value 1) and CPU (value 2)."""

    GPU = 1
    CPU = 2


class Counter:
    """Integer counter driven by ``next()``.

    Each ``next(counter)`` returns the current value and then advances it
    by one.
    """

    def __init__(self, start: int = 0) -> None:
        # Remember the initial value so reset() can return to it.
        self._start = start
        self.counter = start

    def __next__(self) -> int:
        value = self.counter
        self.counter += 1
        return value

    def reset(self) -> None:
        """Restart counting from the ``start`` given to the constructor."""
        self.counter = self._start


class LRUCache(Generic[T]):
    """Capacity-bounded mapping with least-recently-used eviction.

    Reads via ``[]``/``get`` and writes via ``put`` mark an entry as most
    recently used. Whenever ``put`` grows the cache beyond ``capacity``, the
    oldest *unpinned* entries are evicted; if every entry is pinned, a
    RuntimeError is raised instead. Subclasses may override ``_on_remove``
    to be notified when an existing entry is removed.
    """

    def __init__(self, capacity: int):
        # Iteration order of `cache` is recency order: oldest entry first.
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        # Keys that capacity-driven eviction must skip.
        self.pinned_items: Set[Hashable] = set()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> T:
        # Indexing first lets a missing key raise KeyError before reordering.
        hit = self.cache[key]
        self.cache.move_to_end(key)
        return hit

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        """Mark ``key`` as most recently used (KeyError if it is absent)."""
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Return the value for ``key`` and refresh its recency, or
        ``default_value`` if the key is not cached."""
        if key not in self.cache:
            return default_value
        hit = self.cache[key]
        self.cache.move_to_end(key)
        return hit

    def put(self, key: Hashable, value: T) -> None:
        """Insert or overwrite ``key`` as the newest entry, then evict
        down to ``capacity``."""
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: Hashable) -> None:
        """
        Exempt a cached key from LRU eviction until it is popped.

        Raises:
            ValueError: if ``key`` is not currently in the cache.
        """
        if key not in self.cache:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: Hashable) -> None:
        self.pinned_items.remove(key)

    def _on_remove(self, key: Hashable, value: Optional[T]):
        """Hook invoked after an existing entry is removed; no-op here."""
        pass

    def remove_oldest(self, remove_pinned=False):
        """Evict the least recently used entry (no-op on an empty cache).

        Pinned entries are skipped unless ``remove_pinned`` is true.

        Raises:
            RuntimeError: if ``remove_pinned`` is false and every entry is
                pinned.
        """
        if not self.cache:
            return

        if remove_pinned:
            victim = next(iter(self.cache))
        else:
            # Oldest-first scan for the first key that is not pinned.
            unpinned = (k for k in self.cache if k not in self.pinned_items)
            victim = next(unpinned, ALL_PINNED_SENTINEL)
            if victim is ALL_PINNED_SENTINEL:
                raise RuntimeError("All items are pinned, "
                                   "cannot remove oldest from the cache.")
        self.pop(victim)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Remove ``key`` and return its value, or ``default_value`` if absent.

        The key is also unpinned; ``_on_remove`` runs only when the key was
        actually cached.
        """
        was_cached = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        # remove from pinned items
        if key in self.pinned_items:
            self._unpin(key)
        if was_cached:
            self._on_remove(key, value)
        return value

    def clear(self):
        """Remove every entry, pinned or not, through ``pop`` so that
        ``_on_remove`` fires for each one."""
        while self.cache:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()


class PyObjectCache:
    """Pool of pre-built Python objects that is reused across scheduler
    iterations to avoid allocating new objects each time.

    ``get_object`` hands out pooled objects in order; ``reset`` makes the
    whole pool available again.
    """

    def __init__(self, obj_builder):
        self._obj_builder = obj_builder
        self._index = 0
        # The pool starts with 128 pre-built objects.
        self._obj_cache = [self._obj_builder() for _ in range(128)]

    def _grow_cache(self):
        # Double the pool: build as many new objects as it currently holds.
        current_size = len(self._obj_cache)
        self._obj_cache.extend(
            self._obj_builder() for _ in range(current_size))

    def get_object(self):
        """Return the next unused pooled object, doubling the pool first if
        every object has already been handed out.
        """
        if self._index >= len(self._obj_cache):
            self._grow_cache()
            assert self._index < len(self._obj_cache)

        pooled = self._obj_cache[self._index]
        self._index += 1
        return pooled

    def reset(self):
        """Make every pooled object available for the next scheduler
        iteration."""
        self._index = 0


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Return the per-thread-block shared-memory limit (bytes) of device
    index ``gpu``, as reported by ``vllm._custom_ops``. Cached per index."""
    from vllm import _custom_ops as ops

    limit = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # A zero limit would make MAX_SEQ_LEN negative and break
    # test_attention.py, so fail loudly instead.
    assert limit > 0, "max_shared_mem can not be zero"
    return int(limit)


def get_cpu_memory() -> int:
    """Total physical memory of this node in bytes, as reported by psutil."""
    vm_stats = psutil.virtual_memory()
    return vm_stats.total


def random_uuid() -> str:
    """Return a random UUID4 as a 32-character lowercase hex string."""
    return uuid.uuid4().hex


@lru_cache(maxsize=None)
def get_vllm_instance_id() -> str:
    """
    Return the id of this vLLM instance.

    When VLLM_INSTANCE_ID is set to a non-empty value it is used as-is;
    otherwise an id of the form ``vllm-instance-<random hex>`` is generated.
    All processes belonging to one instance should share the same id. The
    result is cached, so a process always sees the same value.
    """
    configured_id = envs.VLLM_INSTANCE_ID
    if configured_id:
        return configured_id
    return f"vllm-instance-{random_uuid()}"


@lru_cache(maxsize=None)
def in_wsl() -> bool:
    """Detect Windows Subsystem for Linux: True when "microsoft" appears in
    the (lower-cased) platform uname fields. Cached after the first call."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    system_info = " ".join(uname()).lower()
    return "microsoft" in system_info


def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
    """Take a blocking function, and run it in an executor thread.

    This function prevents the blocking function from blocking the
    asyncio event loop. ``func`` runs on the event loop's default
    executor (``executor=None``), so its code needs to be thread safe.

    The returned wrapper carries ``func``'s metadata (``__name__``,
    ``__qualname__``, ``__doc__``, ``__wrapped__``) via ``functools.wraps``,
    so logs, tracebacks and introspection refer to the original function.
    """

    @wraps(func)
    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=None, func=p_func)

    return _async_wrapper


def _next_task(iterator: AsyncGenerator[T, None],
               loop: AbstractEventLoop) -> Task:
    """Schedule ``iterator.__anext__()`` as a task on ``loop``.

    Awaiting the task yields the next item, or raises StopAsyncIteration
    once the iterator is exhausted.
    """
    # anext() could be used instead on Python >= 3.10.
    next_item = iterator.__anext__()
    return loop.create_task(next_item)  # type: ignore[arg-type]


async def iterate_with_cancellation(
    iterator: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    """Convert an async iterator into one that also polls ``is_cancelled``
    for client cancellation.

    ``is_cancelled`` is awaited at most once per second. Because each wait
    for the next item times out after 1.5 s, it is also polled roughly
    every 1.5 s while the source produces nothing. When it returns True,
    ``asyncio.CancelledError("client cancelled")`` is raised.

    On every exit path (exhaustion, client cancellation, an error, or the
    consumer closing this generator early) the in-flight fetch task is
    cancelled and ``iterator`` is closed, so no task is left running in the
    background.
    """

    loop = asyncio.get_running_loop()

    awaits: List[Future[T]] = [_next_task(iterator, loop)]
    next_cancel_check: float = 0
    try:
        while True:
            done, _ = await asyncio.wait(awaits, timeout=1.5)

            # Check for cancellation at most once per second. A monotonic
            # clock keeps wall-clock adjustments from postponing checks.
            time_now = time.monotonic()
            if time_now >= next_cancel_check:
                if await is_cancelled():
                    raise asyncio.CancelledError("client cancelled")
                next_cancel_check = time_now + 1

            if done:
                try:
                    item = await awaits[0]
                except StopAsyncIteration:
                    # we are done
                    return
                awaits[0] = _next_task(iterator, loop)
                yield item
    finally:
        # Best-effort cleanup: stop the pending fetch and close the source.
        # Errors here (e.g. the source is still running its cancelled
        # __anext__) are deliberately ignored.
        with contextlib.suppress(BaseException):
            awaits[0].cancel()
            await iterator.aclose()


async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yielded the item.

    If ``is_cancelled`` is given, it is awaited at most once per second.
    Waits time out after 1.5 s, so it is also polled roughly every 1.5 s
    while no iterator produces output. When it returns True,
    ``asyncio.CancelledError("client cancelled")`` is raised. On any exit,
    every still-running iterator is cancelled and closed.
    """

    loop = asyncio.get_running_loop()

    # In-flight "fetch next item" task -> (index, iterator).
    awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
    timeout = None if is_cancelled is None else 1.5
    next_cancel_check: float = 0
    try:
        while awaits:
            done, _ = await asyncio.wait(awaits.keys(),
                                         return_when=FIRST_COMPLETED,
                                         timeout=timeout)
            if is_cancelled is not None:
                # Check for cancellation at most once per second. A monotonic
                # clock keeps wall-clock jumps from postponing the check.
                time_now = time.monotonic()
                if time_now >= next_cancel_check:
                    if await is_cancelled():
                        raise asyncio.CancelledError("client cancelled")
                    next_cancel_check = time_now + 1

            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    awaits[_next_task(it, loop)] = pair
                    yield i, item
                except StopAsyncIteration:
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()


async def collect_from_async_generator(
        iterator: AsyncGenerator[T, None]) -> List[T]:
    """Drain an async generator and return all of its items, in order."""
    return [item async for item in iterator]


def get_ip() -> str:
    """Return an IP address for this host.

    ``VLLM_HOST_IP`` (read through ``vllm.envs``) is returned when set.
    Otherwise the local address is discovered by connecting a UDP socket
    towards a public DNS server, trying IPv4 first and then IPv6. If both
    probes fail, a warning is issued and ``"0.0.0.0"`` is returned.
    """
    host_ip = envs.VLLM_HOST_IP
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface.
    # The targets do not need to be reachable: connect() on a UDP socket
    # only selects the local route/interface and sends no packet.
    probes = (
        (socket.AF_INET, "8.8.8.8"),
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        (socket.AF_INET6, "2001:4860:4860::8888"),
    )
    for family, target in probes:
        try:
            # `with` closes the probe socket on every path.
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect((target, 80))
                return s.getsockname()[0]
        except OSError:
            # Best effort: fall through to the next address family.
            pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable"
        " VLLM_HOST_IP or HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"


def is_valid_ipv6_address(address: str) -> bool:
    """Return True iff ``address`` parses as an IPv6 address."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    return True


def get_distributed_init_method(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` init URL.

    Addresses containing ':' (IPv6) are wrapped in brackets; IPv4 addresses
    must not be, see https://github.com/python/cpython/issues/103848
    """
    host = f"[{ip}]" if ":" in ip else ip
    return f"tcp://{host}:{port}"


def get_open_zmq_ipc_path() -> str:
    """Return a new ``ipc://`` endpoint under VLLM_RPC_BASE_PATH, named by a
    random UUID."""
    base_path = envs.VLLM_RPC_BASE_PATH
    endpoint_id = uuid4()
    return f"ipc://{base_path}/{endpoint_id}"


def get_open_port() -> int:
    """Return a TCP port that could be bound at the time of the call.

    If ``VLLM_PORT`` is set, ports are probed upward from it until one
    binds. Otherwise the OS picks an ephemeral port, trying IPv4 first and
    falling back to IPv6. The probe socket is closed before returning, so
    the port is not reserved.
    """
    candidate = envs.VLLM_PORT
    if candidate is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET,
                                   socket.SOCK_STREAM) as probe:
                    probe.bind(("", candidate))
                    return candidate
            except OSError:
                candidate += 1  # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d",
                            candidate - 1, candidate)

    # No fixed port requested: binding port 0 lets the OS choose one.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]
    except OSError:
        # IPv4 unavailable: try IPv6.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]


def find_process_using_port(port: int) -> Optional[psutil.Process]:
    """Return a process owning an inet connection whose local port is ``port``.

    Connections for which psutil cannot report the owning PID
    (``conn.pid is None``, e.g. without sufficient privileges) are skipped:
    ``psutil.Process(None)`` would silently refer to the *current* process.
    Returns None if no matching connection with a known PID exists, or if
    the first matching process no longer exists.
    """
    for conn in psutil.net_connections():
        if conn.laddr.port != port or conn.pid is None:
            continue
        try:
            return psutil.Process(conn.pid)
        except psutil.NoSuchProcess:
            return None
    return None


def update_environment_variables(envs: Dict[str, str]):
    """Copy every entry of ``envs`` into ``os.environ``.

    A warning is logged for each variable that already exists with a
    different value before it is overwritten. (Inside this function the
    ``envs`` parameter shadows the module-level ``vllm.envs`` import.)
    """
    for name, new_value in envs.items():
        old_value = os.environ.get(name)
        if old_value is not None and old_value != new_value:
            logger.warning(
                "Overwriting environment variable %s "
                "from '%s' to '%s'", name, old_value, new_value)
        os.environ[name] = new_value


def chunk_list(lst: List[T], chunk_size: int):
    """Yield consecutive slices of ``lst`` holding ``chunk_size`` items each;
    the last slice may be shorter."""
    offsets = range(0, len(lst), chunk_size)
    yield from (lst[offset:offset + chunk_size] for offset in offsets)


def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer >= a / b, exact for ints."""
    quotient, remainder = divmod(a, b)
    # divmod floors, so any non-zero remainder means we must round up.
    return quotient + (remainder != 0)


def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill the fp8-storage ``tensor`` in place with random values.

    Values are first drawn uniformly from ``[low, high)`` in float16 and then
    handed to ``ops.convert_fp8``. This avoids producing fp8 Inf/NaN bit
    patterns, which could happen with raw random integers (see table below).
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops

    # uniform_ returns the tensor itself, so allocation and sampling chain.
    staging = torch.empty_like(tensor, dtype=torch.float16).uniform_(low, high)
    # Presumably convert_fp8(dst, src) encodes src into dst's fp8 format
    # (TODO confirm against vllm._custom_ops).
    ops.convert_fp8(tensor, staging)
    del staging


def get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    """Resolve the torch dtype used to store the KV cache.

    Args:
        cache_dtype: One of the following.
            - A ``torch.dtype``, returned unchanged.
            - ``"auto"``, meaning use ``model_dtype``.
            - A key of ``STR_DTYPE_TO_TORCH_DTYPE``: "half", "bfloat16",
              "float", "fp8", "fp8_e4m3" or "fp8_e5m2".
        model_dtype: Consulted only when ``cache_dtype == "auto"``. Either a
            ``STR_DTYPE_TO_TORCH_DTYPE`` key or a ``torch.dtype``.

    Returns:
        The resolved dtype. All fp8 spellings map to ``torch.uint8``.

    Raises:
        ValueError: In either of these cases.
            - ``cache_dtype`` is neither a known string nor a torch.dtype.
            - ``cache_dtype == "auto"`` but ``model_dtype`` is neither a
              string nor a torch.dtype.
        KeyError: If ``model_dtype`` is a string missing from the mapping.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype == "auto":
        if isinstance(model_dtype, str):
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        raise ValueError(f"Invalid model dtype: {model_dtype}")

    # Use the shared mapping so every spelling it defines (including the
    # explicit fp8 variants) is accepted consistently.
    if cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Allocate randomly initialised per-layer KV caches in the flash layout.

    Each layer gets a single tensor shaped
    ``(num_blocks, 2, block_size, num_heads, head_size)``. The returned key
    and value caches are its ``[:, 0]`` and ``[:, 1]`` views.

    Non-fp8 caches are drawn uniformly from ``[-s, s)`` with
    ``s = head_size**-0.5``; fp8 caches are filled by
    ``_generate_random_fp8`` with the same bounds. RNGs are seeded with
    ``seed`` first.

    Raises:
        ValueError: if ``cache_dtype`` is not "auto", "half", "bfloat16",
            "float" or "fp8".
    """
    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    layer_shape = (num_blocks, 2, block_size, num_heads, head_size)
    bound = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        kv = torch.empty(size=layer_shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            kv.uniform_(-bound, bound)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(kv, -bound, bound)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        # Keys and values are views into the same per-layer allocation.
        key_caches.append(kv[:, 0])
        value_caches.append(kv[:, 1])
    return key_caches, value_caches


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Allocate randomly initialised per-layer key and value caches.

    With ``x = 16 // element_size`` (so ``x`` elements span 16 bytes), the
    shapes are:
      * key:   ``(num_blocks, num_heads, head_size // x, block_size, x)``
      * value: ``(num_blocks, num_heads, head_size, block_size)``

    Non-fp8 caches are drawn uniformly from ``[-s, s)`` with
    ``s = head_size**-0.5``; fp8 caches go through ``_generate_random_fp8``.
    RNGs are seeded with ``seed``, and every key cache is generated before
    any value cache.

    Raises:
        ValueError: in either of these cases.
            - ``cache_dtype`` is "fp8" and ``head_size`` is not a multiple
              of 16.
            - ``cache_dtype`` is not "auto", "half", "bfloat16", "float"
              or "fp8".
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    scale = head_size**-0.5

    def _randomize(cache: torch.Tensor) -> bool:
        # Fill `cache` in place; report False for an unsupported cache_dtype.
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(cache, -scale, scale)
        else:
            return False
        return True

    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=torch_dtype,
                                device=device)
        if not _randomize(key_cache):
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=torch_dtype,
                                  device=device)
        if not _randomize(value_cache):
            raise ValueError(
                f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches


@lru_cache
def print_warning_once(msg: str) -> None:
    """Log ``msg`` as a warning only the first time it is passed in.

    Deduplication relies on ``lru_cache`` with its default ``maxsize`` (128),
    so a message evicted from the cache can be logged again later.
    """
    # stacklevel=2 so the log record points at the caller's line.
    logger.warning(msg, stacklevel=2)


@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    """Whether pinned host memory should be used on this platform.

    False under WSL and on XPU, Neuron and HPU (each with a one-time
    warning), and on CPU or OpenVINO platforms (silently). True otherwise.
    The result is cached for the process.
    """
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return False
    if current_platform.is_xpu():
        print_warning_once("Pin memory is not supported on XPU.")
        return False
    if current_platform.is_neuron():
        print_warning_once("Pin memory is not supported on Neuron.")
        return False
    if current_platform.is_hpu():
        print_warning_once("Pin memory is not supported on HPU.")
        return False
    # CPU and OpenVINO: pinning is disabled without a warning.
    return not (current_platform.is_cpu() or current_platform.is_openvino())


class DeviceMemoryProfiler:
    """Context manager measuring device memory allocated across its body.

    After the ``with`` block, ``initial_memory`` and ``final_memory`` hold the
    allocated bytes at entry and exit, and ``consumed_memory`` their
    difference. ``gc.collect()`` runs on exit, after the measurement.
    """

    def __init__(self, device: Optional[torch.types.Device] = None):
        self.device = device

    def current_memory_usage(self) -> float:
        """Return the memory currently allocated on ``self.device`` in bytes.

        The peak-allocation counter is reset first, so the "max allocated"
        value read right afterwards reflects the current allocation.

        Raises:
            NotImplementedError: on platforms that are neither CUDA-like
                nor XPU.
        """
        if current_platform.is_cuda_alike():
            torch.cuda.reset_peak_memory_stats(self.device)
            return torch.cuda.max_memory_allocated(self.device)
        if current_platform.is_xpu():
            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
            return torch.xpu.max_memory_allocated(self.device)  # type: ignore
        raise NotImplementedError(
            "DeviceMemoryProfiler only supports CUDA-like and XPU platforms.")

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()


def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Build a 2D NumPy array from ragged rows.

    Each row is right-padded with ``pad`` up to ``max_len``, which defaults
    to the longest row (0 for empty ``x``). A row longer than ``max_len``
    trips an ``assert``.
    """
    if max_len is None:
        # Unlike for most functions, map is faster than a genexpr over `len`
        max_len = max(map(len, x), default=0)

    out = np.full((len(x), max_len), pad, dtype=dtype)
    for row_idx, row in enumerate(x):
        row_len = len(row)
        assert row_len <= max_len
        out[row_idx, :row_len] = row
    return out


def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Build a 2D tensor from ragged rows, right-padding each with ``pad`` up to
    ``max_len``.

    The padding happens in NumPy via ``make_ndarray_with_pad``, so ``dtype``
    must be a key of ``TORCH_DTYPE_TO_NUMPY_DTYPE`` (KeyError otherwise).
    The result is moved to ``device`` and then pinned if ``pin_memory``.
    """
    padded = make_ndarray_with_pad(x,
                                   pad,
                                   TORCH_DTYPE_TO_NUMPY_DTYPE[dtype],
                                   max_len=max_len)
    result = torch.from_numpy(padded).to(device)
    return result.pin_memory() if pin_memory else result


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from ``data`` (optionally pinned), then issue a
    ``non_blocking`` copy of it to ``target_device``.

    Per the PyTorch docs, a host-to-device copy can only overlap with host
    work when the source is pinned.
    """
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               pin_memory=pin_memory,
                               device="cpu")
    return host_tensor.to(device=target_device, non_blocking=True)


def get_dtype_size(dtype: torch.dtype) -> int:
    """Return the number of bytes a single element of ``dtype`` occupies."""
    empty_probe = torch.tensor([], dtype=dtype)
    return empty_probe.element_size()


# `collections` helpers
def is_list_of(
    value: object,
    typ: Type[T],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
    """Check that ``value`` is a list whose elements are ``typ`` instances.

    ``check="first"`` inspects only the first element (an empty list
    passes); ``check="all"`` inspects every element.
    """
    if not isinstance(value, list):
        return False

    if check == "first":
        return not value or isinstance(value[0], typ)
    if check == "all":
        return all(isinstance(element, typ) for element in value)

    assert_never(check)


JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                 Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: JSONTree[T],
) -> JSONTree[U]:
    ...


def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
    """Apply ``func`` to every leaf of ``value`` while keeping its structure.

    dict, list and tuple nodes (including their subclasses) are rebuilt as
    plain dict, list and tuple respectively; any other value is a leaf.
    """
    if isinstance(value, dict):
        return {
            key: json_map_leaves(func, child)
            for key, child in value.items()
        }
    if isinstance(value, list):
        return [json_map_leaves(func, child) for child in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, child) for child in value)
    return func(value)


def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Concatenate the inner lists of ``lists``, in order, into a new list."""
    flattened: List[T] = []
    for sublist in lists:
        flattened.extend(sublist)
    return flattened


902
903
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
904
def init_cached_hf_modules() -> None:
    """
    Initialize Hugging Face's dynamic-module cache on demand.

    The ``transformers`` import happens inside the function, so that package
    is only loaded when this is actually called.
    """
    import transformers.dynamic_module_utils as hf_dynamic_module_utils
    hf_dynamic_module_utils.init_hf_modules()
910
911


912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.

    `lib_name` is the full filename, with both prefix and suffix (e.g.
    ``libcuda.so.1``). This function resolves `lib_name` to the full path of
    the library.

    Lookup order:
      1. The dynamic-linker cache as listed by ``/sbin/ldconfig -p``.
      2. The directories in ``LD_LIBRARY_PATH``, if step 1 found nothing
         (including when ``ldconfig`` is missing or exits with an error).

    Raises:
        ValueError: if neither lookup finds the library.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems. It can still be
    # absent or fail (e.g. in minimal containers). In that case fall through
    # to the `LD_LIBRARY_PATH` search instead of crashing before it runs.
    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    except (OSError, subprocess.CalledProcessError):
        libs = ""
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = envs.LD_LIBRARY_PATH
    if not locs and env_ld_library_path:
        candidates = (os.path.join(directory, lib_name)
                      for directory in env_ld_library_path.split(":"))
        locs = [path for path in candidates if os.path.exists(path)]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]


940
def find_nccl_library() -> str:
    """
    Return the NCCL/RCCL shared-library file to load.

    A non-empty ``VLLM_NCCL_SO_PATH`` wins and is returned unchanged.
    Otherwise the name is chosen from the PyTorch build: ``libnccl.so.2`` for
    CUDA builds and ``librccl.so.1`` for ROCm (HIP) builds. After importing
    `torch`, those files can be found by `ctypes` automatically.

    Raises:
        ValueError: if no path is configured and PyTorch is built for neither
            CUDA nor ROCm.
    """
    env_so_file = envs.VLLM_NCCL_SO_PATH
    if env_so_file:
        # manually load the nccl library
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
            env_so_file)
        return env_so_file

    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info("Found nccl from library %s", so_file)
    return so_file
963
964
965
966
967
968
969


def enable_trace_function_call_for_thread() -> None:
    """Turn on function-call tracing for the calling thread, but only when
    the VLLM_TRACE_FUNCTION environment variable enables it.

    The log goes to ``<tempdir>/<user>/vllm/<vllm instance id>/``. Its file
    name encodes the process id, the thread id and the current time, with
    spaces replaced by underscores.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    # add username to tmp_dir to avoid permission issues
    user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
    filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                f"_thread_{threading.get_ident()}_"
                f"at_{datetime.datetime.now()}.log").replace(" ", "_")
    log_path = os.path.join(user_tmp_dir, "vllm", get_vllm_instance_id(),
                            filename)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    enable_trace_function_call(log_path)
981
982


983
# `functools` helpers
984
985
def identity(value: T, **kwargs) -> T:
    """Hand back ``value`` untouched; keyword arguments are accepted and
    ignored."""
    del kwargs  # unused
    return value


# Type variable bound to any callable; lets decorator factories such as
# `deprecate_args` be typed as `Callable[[F], F]` (same callable type in
# and out).
F = TypeVar('F', bound=Callable[..., Any])


992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
def deprecate_args(
    start_index: int,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator factory that warns when parameters from ``start_index``
    onward are supplied positionally.

    Args:
        start_index: Index, within the decorated function's positional
            parameters, of the first one whose positional use is deprecated.
        is_deprecated: A constant flag, or a zero-argument callable that is
            re-evaluated on every call; warnings are emitted only when it is
            true.
        additional_message: Optional text appended to the warning.
    """
    check_deprecated = (is_deprecated if callable(is_deprecated) else
                        partial(identity, is_deprecated))

    def wrapper(fn: F) -> F:
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        signature_params = inspect.signature(fn).parameters
        positional_names = [
            name for name, param in signature_params.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            if check_deprecated():
                flagged = positional_names[start_index:len(args)]
                if flagged:
                    message = (
                        f"The positional arguments {flagged} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        message += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(message),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper


1035
def deprecate_kwargs(
    *kws: str,
    is_deprecated: Union[bool, Callable[[], bool]] = True,
    additional_message: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator factory that emits a ``DeprecationWarning`` whenever one of
    the keyword arguments named in ``kws`` is passed to the decorated
    function.

    Args:
        *kws: Names of the deprecated keyword arguments.
        is_deprecated: A constant flag, or a zero-argument callable
            re-evaluated on every call, that gates the warning.
        additional_message: Optional text appended to the warning.
    """
    deprecated_kws = set(kws)
    check_deprecated = (is_deprecated if callable(is_deprecated) else
                        partial(identity, is_deprecated))

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if check_deprecated():
                passed = kwargs.keys() & deprecated_kws
                if passed:
                    message = (
                        f"The keyword arguments {passed} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        message += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(message),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
        cuda_visible_devices: Optional[str] = None) -> int:
    """Count GPUs through NVML (CUDA) or AMDSMI (ROCm).

    Falls back to ``torch._C._cuda_getDeviceCount()`` when that query yields
    a negative count. Returns 0 if PyTorch was built without CUDA support.

    ``cuda_visible_devices`` is not read by the body. It is only a parameter
    so that the LRU cache is keyed on it.
    """
    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    if not torch.cuda._is_compiled():
        return 0

    if not current_platform.is_rocm():
        raw_count = torch.cuda._device_count_nvml()
    elif hasattr(torch.cuda, "_device_count_amdsmi"):
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        raw_count = torch.cuda._device_count_amdsmi()
    else:
        raw_count = -1

    # A negative count (e.g. the -1 above) means the stateless query was not
    # usable, so ask the CUDA runtime instead.
    if raw_count < 0:
        return torch._C._cuda_getDeviceCount()
    return raw_count


def cuda_device_count_stateless() -> int:
    """Return the number of CUDA devices.

    The result is cached per value of CUDA_VISIBLE_DEVICES seen at call
    time. Prefer this over torch.cuda.device_count() unless
    CUDA_VISIBLE_DEVICES has already been set to its final value.
    """
    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    visible_devices = envs.CUDA_VISIBLE_DEVICES
    return _cuda_device_count_stateless(visible_devices)
1107
1108


1109
1110
1111
1112
1113
1114
1115
def cuda_is_initialized() -> bool:
    """Return True only if PyTorch was built with CUDA support and its CUDA
    state has already been initialized."""
    return torch.cuda._is_compiled() and torch.cuda.is_initialized()


1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
    """Make an instance method that weakly references
    its associated instance and no-ops once that
    instance is collected.

    Args:
        bound_method: A bound method. Only a weak reference to its
            ``__self__`` is kept.

    Returns:
        A function that forwards its arguments to the underlying method
        while the instance is alive, and does nothing after the instance has
        been garbage-collected. It always returns None; the method's own
        return value is discarded.
    """
    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
    unbound = bound_method.__func__  # type: ignore[attr-defined]

    def weak_bound(*args, **kwargs) -> None:
        # Test for None rather than truthiness. A live instance that defines
        # __bool__/__len__ (e.g. an empty container subclass) is falsy and
        # would otherwise be mistaken for a collected one.
        inst = ref()
        if inst is not None:
            unbound(inst, *args, **kwargs)

    return weak_bound


1130
#From: https://stackoverflow.com/a/4104188/2749989
1131
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Decorator that lets ``f`` execute only on its first call.

    The first call runs ``f`` and returns its result. Every later call does
    nothing and returns None. The flag lives on the wrapper as ``has_run``
    and is set *before* ``f`` runs, so a first call that raises is not
    retried.
    """

    # Preserve f's name, docstring and __wrapped__ on the returned wrapper.
    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
1140
1141


1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
class StoreBoolean(argparse.Action):
    """argparse action that turns a case-insensitive 'true'/'false' value
    into a bool stored on the namespace."""

    def __call__(self, parser, namespace, values, option_string=None):
        normalized = values.lower()
        if normalized not in ("true", "false"):
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
        setattr(namespace, self.dest, normalized == "true")


1154
1155
1156
1157
1158
class SortedHelpFormatter(argparse.HelpFormatter):
    """Help formatter that lists each section's arguments ordered by their
    option strings."""

    def add_arguments(self, actions):
        super().add_arguments(
            sorted(actions, key=lambda action: action.option_strings))
1160
1161


1162
1163
1164
class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names.

    It also understands ``--config <file.yaml>``: the entries of a flat YAML
    file are expanded into command-line arguments (see
    ``_pull_args_from_config``). ``formatter_class`` defaults to
    ``SortedHelpFormatter``.
    """

    def __init__(self, *args, **kwargs):
        # Set the default 'formatter_class' to SortedHelpFormatter
        if 'formatter_class' not in kwargs:
            kwargs['formatter_class'] = SortedHelpFormatter
        super().__init__(*args, **kwargs)

    def parse_args(self, args=None, namespace=None):
        """Parse ``args`` (default: ``sys.argv[1:]``).

        Any ``--config`` file is expanded first. Then underscores in option
        names are converted to dashes (``--foo_bar`` -> ``--foo-bar``). For
        ``--key=value`` only the key is converted; the value is kept as-is.
        """
        if args is None:
            args = sys.argv[1:]
        # Accept any iterable (tuple, generator, ...): the membership test and
        # the list concatenation in the config expansion need a real list.
        args = list(args)

        if '--config' in args:
            args = self._pull_args_from_config(args)

        # Normalize underscores in option names to dashes
        processed_args = []
        for arg in args:
            if arg.startswith('--'):
                if '=' in arg:
                    key, value = arg.split('=', 1)
                    key = '--' + key[len('--'):].replace('_', '-')
                    processed_args.append(f'{key}={value}')
                else:
                    processed_args.append('--' +
                                          arg[len('--'):].replace('_', '-'))
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)

    def _pull_args_from_config(self, args: List[str]) -> List[str]:
        """Replace ``--config <file>`` with the arguments stored in the file.

        The config arguments are inserted right after the sub-command (and,
        for ``serve``, after the model tag), i.e. *before* the remaining
        command-line arguments. Later occurrences override earlier ones when
        parsed by ``super()``, which gives the precedence
        cli > config > defaults.

        Example, with ``config.yaml``::

            port: 12323
            tensor-parallel-size: 4

        ``vllm serve facebook/opt-12B --config config.yaml -tp 2`` gives::

            args = ["serve", "facebook/opt-12B",
                    "--config", "config.yaml", "-tp", "2"]
            # becomes
            args = ["serve", "facebook/opt-12B",
                    "--port", "12323", "--tensor-parallel-size", "4",
                    "-tp", "2"]

        Raises:
            ValueError: if ``--config`` is the last argument, or if the
                sub-command is ``serve`` and no model tag precedes
                ``--config``.
        """
        assert args.count(
            '--config') <= 1, "More than one config file specified!"

        index = args.index('--config')
        if index == len(args) - 1:
            raise ValueError("No config file specified! "
                             "Please check your command-line arguments.")

        file_path = args[index + 1]

        config_args = self._load_config_file(file_path)

        # 0th index is for {serve,chat,complete}
        # followed by model_tag (only for serve)
        # followed by config args
        # followed by rest of cli args.
        # maintaining this order will enforce the precedence
        # of cli > config > defaults
        if args[0] == "serve":
            if index == 1:
                raise ValueError(
                    "No model_tag specified! Please check your command-line"
                    " arguments.")
            args = [args[0]] + [
                args[1]
            ] + config_args + args[2:index] + args[index + 2:]
        else:
            args = [args[0]] + config_args + args[1:index] + args[index + 2:]

        return args

    def _load_config_file(self, file_path: str) -> List[str]:
        """Load a flat YAML file and return its entries as argparse-style
        arguments.

        For example ``port: 12323`` and ``tensor-parallel-size: 4`` become
        ``['--port', '12323', '--tensor-parallel-size', '4']``.

        Boolean values are handled specially. For options using
        ``StoreBoolean`` the value is passed on as ``'True'``/``'False'``.
        For any other option, ``true`` emits the bare flag and ``false``
        emits nothing. An empty file yields no arguments.

        Raises:
            ValueError: if the extension is not ``yaml``/``yml`` or the file
                does not contain a mapping.
        """
        extension: str = file_path.split('.')[-1]
        if extension not in ('yaml', 'yml'):
            raise ValueError(
                f"Config file must be of a yaml/yml type. {extension} "
                "supplied")

        # only expecting a flat dictionary of atomic types
        processed_args: List[str] = []

        config: Dict[str, Union[int, str]] = {}
        try:
            with open(file_path) as config_file:
                config = yaml.safe_load(config_file)
        except Exception:
            logger.error(
                "Unable to read the config file at %s. "
                "Make sure path is correct", file_path)
            raise

        # An empty YAML document loads as None; treat it as "no arguments".
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {file_path} must contain a mapping of argument "
                f"names to values, got {type(config).__name__}.")

        # argparse stores `dest` with underscores, while config keys may use
        # dashes, so compare using the underscore form of the key.
        store_boolean_arguments = {
            action.dest
            for action in self._actions if isinstance(action, StoreBoolean)
        }

        for key, value in config.items():
            is_store_boolean = (str(key).replace('-', '_')
                                in store_boolean_arguments)
            if isinstance(value, bool) and not is_store_boolean:
                if value:
                    processed_args.append('--' + key)
            else:
                processed_args.append('--' + key)
                processed_args.append(str(value))

        return processed_args

1308
1309
1310
1311
1312
1313

async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Utility function to run async task in a lock"""
    async with lock:
        return await task(*args, **kwargs)
1314
1315


1316
1317
1318
1319
1320
1321
1322
1323
1324
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check if a keyword is a valid kwarg for a callable; if requires_kw_only
    disallows kwargs names that can also be positional arguments.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: The keyword name to check.
        requires_kw_only: If True, only keyword-only parameters count as
            explicit support. A named parameter that can also be filled
            positionally makes the result False.
        allow_var_kwargs: If True, a trailing ``**kwargs`` parameter accepts
            any name other than its own.

    Returns:
        Whether ``kw_name`` is considered a supported keyword under the rules
        above. Positional-only parameters are never bindable by keyword
        themselves; their name is only supported if ``**kwargs`` absorbs it.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param_val = params.get(kw_name)

    if param_val is not None:
        kind = param_val.kind
        if requires_kw_only:
            if kind == inspect.Parameter.KEYWORD_ONLY:
                return True
            # We want kwargs only, but this is passable as a positional arg
            if kind in (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD):
                return False
        elif kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY):
            return True
        # Remaining cases (variadic params, or positional-only params, which
        # Python never binds by keyword) fall through: the name may still be
        # absorbed by **kwargs below.

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (last_param.kind == inspect.Parameter.VAR_KEYWORD
                and last_param.name != kw_name)
    return False


def resolve_mm_processor_kwargs(
    init_kwargs: Optional[Dict[str, Any]],
    inference_kwargs: Optional[Dict[str, Any]],
    callable: Callable[..., object],
    allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
    """Combine init-time and inference-time multimodal processor kwargs.

    Each dict is first filtered through ``get_allowed_kwarg_only_overrides``
    against ``callable``, so only names it accepts as keyword-only survive
    (plus names absorbable by ``**kwargs`` when ``allow_var_kwargs`` is
    True). The filtered dicts are then merged, and inference-time values
    win on key collisions.

    The result is always a dict (possibly empty), so it can be
    ``**``-expanded into the callable later on.
    """

    def _filter(overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        return get_allowed_kwarg_only_overrides(
            callable, overrides=overrides, allow_var_kwargs=allow_var_kwargs)

    # Inference-time kwargs are filtered first, then init-time ones.
    runtime_mm_kwargs = _filter(inference_kwargs)
    init_mm_kwargs = _filter(init_kwargs)

    # Later entries take priority: inference time over initialization time.
    return {**init_mm_kwargs, **runtime_mm_kwargs}
1391
1392


1393
1394
1395
def get_allowed_kwarg_only_overrides(
    callable: Optional[Callable[..., object]],
    overrides: Optional[Dict[str, Any]],
    allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
    mapping param names to values, drop values that cannot be kwarg
    expanded to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
                  If None is provided, all overrides names are allowed.
        overrides: Potential overrides to be used when invoking the callable.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable. Always a new dict; ``overrides`` is not modified.
    """
    if not overrides:
        return {}

    # Documented contract: with no callable there is nothing to validate
    # against, so every override is kept.
    if callable is None:
        return dict(overrides)

    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless it can fit in a var_kwargs param
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if supports_kw(callable,
                       kwarg_name,
                       requires_kw_only=True,
                       allow_var_kwargs=allow_var_kwargs)
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        logger.warning(
            "The following intended overrides are not keyword-only args "
            "and will be dropped: %s", dropped_keys)

    return filtered_overrides


1441
1442
1443
1444
1445
1446
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
    """Return True when the installed PyTorch release is 2.4.0 or newer,
    ignoring pre-release/dev/local version suffixes."""
    release = Version(torch.__version__).base_version
    return Version(release) >= Version("2.4.0")
1447
1448


1449
1450
1451
1452
1453
1454
# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    """Whether this PyTorch build exposes ``torch.library.custom_op``
    (missing in releases older than 2.4.0)."""
    library = torch.library
    return hasattr(library, "custom_op")


1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
class AtomicCounter:
    """A counter whose increments and decrements are serialized by a
    ``threading.Lock`` so it can be shared between threads."""

    def __init__(self, initial=0):
        """Create the counter, starting at ``initial``."""
        self._lock = threading.Lock()
        self._value = initial

    def inc(self, num=1):
        """Add ``num`` while holding the lock and return the updated
        value."""
        with self._lock:
            self._value += num
            result = self._value
        return result

    def dec(self, num=1):
        """Subtract ``num`` while holding the lock and return the updated
        value."""
        with self._lock:
            self._value -= num
            result = self._value
        return result

    @property
    def value(self):
        """Current value (read without taking the lock)."""
        return self._value
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498


# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
    """Read-only mapping whose values are built on first access.

    ``factory`` maps each key to a zero-argument callable. That callable runs
    the first time its key is looked up, and the result is cached for later
    lookups. Iteration and ``len`` follow the factory keys, whether or not
    their values have been built yet.
    """

    def __init__(self, factory: Dict[str, Callable[[], T]]):
        self._factory = factory
        self._dict: Dict[str, T] = {}

    def __getitem__(self, key) -> T:
        if key in self._dict:
            return self._dict[key]
        if key not in self._factory:
            raise KeyError(key)
        value = self._factory[key]()
        self._dict[key] = value
        return value

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
1499
1500


1501
1502
1503
1504
1505
1506
1507
1508
1509
def combine_fx_passes(passes: List[Callable]) -> Callable:
    """Return one pass that applies each of ``passes`` to the graph in list
    order. The list itself is captured, not copied, and is read each time
    the combined pass runs."""

    def combined_fx(graph) -> None:
        for fx_pass in passes:
            fx_pass(graph)

    return combined_fx


1510
1511
1512
1513
1514
1515
1516
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return a tensor that shares ``tensor``'s data without keeping the
    original alive, using the custom ``torch.ops._C.weak_ref_tensor`` op.
    The original must therefore outlive any use of the returned tensor.
    """
    weak_ref_op = torch.ops._C.weak_ref_tensor
    return weak_ref_op(tensor)
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532


def weak_ref_tensors(
    tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
    """
    Apply ``weak_ref_tensor`` to a single tensor, or element-wise to a list
    or tuple of tensors, preserving the container kind (list or tuple).

    Raises:
        ValueError: for any other input type.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, (list, tuple)):
        refs = [weak_ref_tensor(t) for t in tensors]
        return refs if isinstance(tensors, list) else tuple(refs)
    raise ValueError("Invalid type for tensors")
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542


def is_in_doc_build() -> bool:
    """Return True when ``torch`` has been replaced by Sphinx autodoc's
    ``_MockModule`` (i.e. during a docs build). Return False when Sphinx is
    not installed."""
    try:
        from sphinx.ext.autodoc.mock import _MockModule
    except ModuleNotFoundError:
        return False
    return isinstance(torch, _MockModule)


1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
    """
    Import a Python file according to its file path.

    Based on the official recipe:
    https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly

    The new module is registered in ``sys.modules`` under ``module_name``
    before its code runs. If running it raises, that entry is removed again
    (as the regular import system does), so no half-initialized module is
    left behind.

    Raises:
        ModuleNotFoundError: if no import spec can be created for
            ``file_path`` (e.g. an unrecognized file suffix).
        Any exception raised while executing the module is propagated.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ModuleNotFoundError(f"No module named '{module_name}'")

    assert spec.loader is not None

    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Only remove the entry if it still refers to the module we inserted.
        if sys.modules.get(module_name) is module:
            del sys.modules[module_name]
        raise
    return module


1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
# create a library to hold the custom op
# `direct_register_custom_op` registers into this object when no
# `target_lib` is given. Per that function's docstring, an operator lives
# only as long as its Library object, so this one is kept at module level.
# NOTE(review): "FRAGMENT" presumably lets this object add ops to the "vllm"
# namespace without owning it exclusively; confirm against torch.library docs.
vllm_lib = Library("vllm", "FRAGMENT")  # noqa


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to a single backend
    (``dispatch_key``, CUDA by default).
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    Args:
        op_name: Name of the op inside the library's namespace.
        op_func: Implementation; its signature is used to infer the schema.
        mutates_args: Names of the arguments that ``op_func`` mutates.
        fake_impl: Optional fake (meta) implementation to register.
        target_lib: Library to register into; defaults to ``vllm_lib``.
        dispatch_key: Dispatch key the implementation is registered under,
            e.g. "CUDA" (default) or "CPU".

    Does nothing during a Sphinx documentation build.
    """
    if is_in_doc_build():
        return

    import torch.library
    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func,
                                                mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl
        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str)
    my_lib.impl(op_name, op_func, dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
1603
1604
1605
1606
1607
1608
1609
1610
1611


def resolve_obj_by_qualname(qualname: str) -> Any:
    """
    Resolve an object by its fully qualified name.

    ``qualname`` has the form ``"package.module.attr"``. Everything before
    the last dot is imported as a module, and ``attr`` is looked up on it
    (e.g. ``"os.path.join"``).

    Raises:
        ValueError: if ``qualname`` lacks a non-empty module part and
            attribute part separated by a dot.
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module_name, sep, obj_name = qualname.rpartition(".")
    if not sep or not module_name or not obj_name:
        raise ValueError("Expected a fully qualified name of the form "
                         f"'module.attribute', got {qualname!r}")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)